00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036
00037 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00038 error("nodes does not have %s surrogate splits", maxsurr);
00039
00040 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00041 jselect = S3get_variableID(S3get_primarysplit(node));
00042 y = S3get_nodeweights(VECTOR_ELT(node, 7));
00043
00044 tweights = Calloc(nobs, double);
00045 dweights = REAL(weights);
00046 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00047 if (has_missings(inputs, jselect)) {
00048 thiswhichNA = get_missings(inputs, jselect);
00049 for (k = 0; k < LENGTH(thiswhichNA); k++)
00050 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00051 }
00052
00053 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00054 C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00055
00056 splitstat = REAL(get_splitstatistics(fitmem));
00057
00058 maxstat = Calloc(ninputs, double);
00059 cutpoint = Calloc(ninputs, double);
00060 order = Calloc(ninputs, int);
00061
00062
00063
00064
00065
00066
00067 for (j = 0; j < ninputs; j++) {
00068
00069 order[j] = j + 1;
00070 maxstat[j] = 0.0;
00071 cutpoint[j] = 0.0;
00072
00073
00074 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00075 continue;
00076
00077 x = get_variable(inputs, j + 1);
00078
00079 if (has_missings(inputs, j + 1)) {
00080
00081 thisweights = REAL(get_weights(fitmem, j + 1));
00082 for (i = 0; i < nobs; i++) thisweights[i] = tweights[i];
00083 thiswhichNA = get_missings(inputs, j + 1);
00084 for (k = 0; k < LENGTH(thiswhichNA); k++)
00085 thisweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00086
00087 C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00088
00089 C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00090 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00091 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00092 expcovinf, &cp, &ms, splitstat);
00093 } else {
00094
00095 C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00096 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00097 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00098 expcovinf, &cp, &ms, splitstat);
00099 }
00100
00101 maxstat[j] = -ms;
00102 cutpoint[j] = cp;
00103 }
00104
00105
00106 rsort_with_index(maxstat, order, ninputs);
00107
00108 twotab = Calloc(4, double);
00109
00110
00111 for (j = 0; j < maxsurr; j++) {
00112
00113 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00114 cut = cutpoint[order[j] - 1];
00115 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00116 split = allocVector(VECSXP, SPLIT_LENGTH));
00117 C_init_orderedsplit(split, 0);
00118 S3set_variableID(split, order[j]);
00119 REAL(S3get_splitpoint(split))[0] = cut;
00120 dx = REAL(get_variable(inputs, order[j]));
00121 dy = REAL(y);
00122
00123
00124
00125
00126
00127 for (i = 0; i < nobs; i++) {
00128 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00129 twotab[1] += (dy[i] == 1) * tweights[i];
00130 twotab[2] += (dx[i] <= cut) * tweights[i];
00131 twotab[3] += tweights[i];
00132 }
00133 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00134 twotab[3]) > 0);
00135 }
00136
00137 Free(maxstat);
00138 Free(cutpoint);
00139 Free(order);
00140 Free(tweights);
00141 Free(twotab);
00142 }
00143
00154 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00155 SEXP fitmem) {
00156
00157 C_surrogates(node, learnsample, weights, controls, fitmem);
00158 return(S3get_surrogatesplits(node));
00159
00160 }
00161
00169 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00170
00171 SEXP weights, split, surrsplit;
00172 SEXP inputs, whichNA;
00173 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00174 int *iwhichNA, k;
00175 int nobs, i, nna, ns;
00176
00177 weights = S3get_nodeweights(node);
00178 dweights = REAL(weights);
00179 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00180 nobs = get_nobs(learnsample);
00181
00182 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00183 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00184 surrsplit = S3get_surrogatesplits(node);
00185
00186
00187 split = S3get_primarysplit(node);
00188 if (has_missings(inputs, S3get_variableID(split))) {
00189
00190
00191 whichNA = get_missings(inputs, S3get_variableID(split));
00192 iwhichNA = INTEGER(whichNA);
00193 nna = LENGTH(whichNA);
00194
00195
00196 for (k = 0; k < nna; k++) {
00197 ns = 0;
00198 i = iwhichNA[k] - 1;
00199 if (dweights[i] == 0) continue;
00200
00201
00202 while(TRUE) {
00203
00204 if (ns >= LENGTH(surrsplit)) break;
00205
00206 split = VECTOR_ELT(surrsplit, ns);
00207 if (has_missings(inputs, S3get_variableID(split))) {
00208 if (INTEGER(get_missings(inputs,
00209 S3get_variableID(split)))[i]) {
00210 ns++;
00211 continue;
00212 }
00213 }
00214
00215 cutpoint = REAL(S3get_splitpoint(split))[0];
00216 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00217
00218 if (S3get_toleft(split)) {
00219 if (dx[i] <= cutpoint) {
00220 leftweights[i] = dweights[i];
00221 rightweights[i] = 0.0;
00222 } else {
00223 rightweights[i] = dweights[i];
00224 leftweights[i] = 0.0;
00225 }
00226 } else {
00227 if (dx[i] <= cutpoint) {
00228 rightweights[i] = dweights[i];
00229 leftweights[i] = 0.0;
00230 } else {
00231 leftweights[i] = dweights[i];
00232 rightweights[i] = 0.0;
00233 }
00234 }
00235 break;
00236 }
00237 }
00238 }
00239 }