00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab, *ytmp, sumw = 0.0;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037 jselect = S3get_variableID(S3get_primarysplit(node));
00038
00039
00040 y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT));
00041 ytmp = Calloc(nobs, double);
00042 for (i = 0; i < nobs; i++) {
00043 ytmp[i] = REAL(y)[i];
00044 if (ytmp[i] > 1.0) ytmp[i] = 1.0;
00045 }
00046
00047 for (j = 0; j < ninputs; j++) {
00048 if (is_nominal(inputs, j + 1)) continue;
00049 nvar++;
00050 }
00051 nvar--;
00052
00053 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00054 error("nodes does not have %d surrogate splits", maxsurr);
00055 if (maxsurr > nvar)
00056 error("cannot set up %d surrogate splits with only %d ordered input variable(s)",
00057 maxsurr, nvar);
00058
00059 tweights = Calloc(nobs, double);
00060 dweights = REAL(weights);
00061 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00062 if (has_missings(inputs, jselect)) {
00063 thiswhichNA = get_missings(inputs, jselect);
00064 for (k = 0; k < LENGTH(thiswhichNA); k++)
00065 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00066 }
00067
00068
00069 sumw = 0.0;
00070 for (i = 0; i < nobs; i++) sumw += tweights[i];
00071 if (sumw < 2.0)
00072 error("can't implement surrogate splits, not enough observations available");
00073
00074 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00075 C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf);
00076
00077 splitstat = REAL(get_splitstatistics(fitmem));
00078
00079 maxstat = Calloc(ninputs, double);
00080 cutpoint = Calloc(ninputs, double);
00081 order = Calloc(ninputs, int);
00082
00083
00084
00085
00086
00087
00088 for (j = 0; j < ninputs; j++) {
00089
00090 order[j] = j + 1;
00091 maxstat[j] = 0.0;
00092 cutpoint[j] = 0.0;
00093
00094
00095 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00096 continue;
00097
00098 x = get_variable(inputs, j + 1);
00099
00100 if (has_missings(inputs, j + 1)) {
00101
00102 thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00103
00104
00105 sumw = 0.0;
00106 for (i = 0; i < nobs; i++) sumw += thisweights[i];
00107 if (sumw < 2.0) continue;
00108
00109 C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf);
00110
00111 C_split(REAL(x), 1, ytmp, 1, thisweights, nobs,
00112 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00113 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00114 expcovinf, &cp, &ms, splitstat);
00115 } else {
00116
00117 C_split(REAL(x), 1, ytmp, 1, tweights, nobs,
00118 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00119 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00120 expcovinf, &cp, &ms, splitstat);
00121 }
00122
00123 maxstat[j] = -ms;
00124 cutpoint[j] = cp;
00125 }
00126
00127
00128
00129
00130
00131
00132
00133
00134 rsort_with_index(maxstat, order, ninputs);
00135
00136 twotab = Calloc(4, double);
00137
00138
00139 for (j = 0; j < maxsurr; j++) {
00140
00141 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00142 cut = cutpoint[order[j] - 1];
00143 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00144 split = allocVector(VECSXP, SPLIT_LENGTH));
00145 C_init_orderedsplit(split, 0);
00146 S3set_variableID(split, order[j]);
00147 REAL(S3get_splitpoint(split))[0] = cut;
00148 dx = REAL(get_variable(inputs, order[j]));
00149 dy = REAL(y);
00150
00151
00152
00153
00154
00155 for (i = 0; i < nobs; i++) {
00156 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00157 twotab[1] += (dy[i] == 1) * tweights[i];
00158 twotab[2] += (dx[i] <= cut) * tweights[i];
00159 twotab[3] += tweights[i];
00160 }
00161 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00162 twotab[3]) > 0);
00163 }
00164
00165 Free(maxstat);
00166 Free(cutpoint);
00167 Free(order);
00168 Free(tweights);
00169 Free(twotab);
00170 Free(ytmp);
00171 }
00172
00183 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00184 SEXP fitmem) {
00185
00186 C_surrogates(node, learnsample, weights, controls, fitmem);
00187 return(S3get_surrogatesplits(node));
00188
00189 }
00190
00198 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00199
00200 SEXP weights, split, surrsplit;
00201 SEXP inputs, whichNA;
00202 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00203 int *iwhichNA, k;
00204 int nobs, i, nna, ns;
00205
00206 weights = S3get_nodeweights(node);
00207 dweights = REAL(weights);
00208 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00209 nobs = get_nobs(learnsample);
00210
00211 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00212 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00213 surrsplit = S3get_surrogatesplits(node);
00214
00215
00216 split = S3get_primarysplit(node);
00217 if (has_missings(inputs, S3get_variableID(split))) {
00218
00219
00220 whichNA = get_missings(inputs, S3get_variableID(split));
00221 iwhichNA = INTEGER(whichNA);
00222 nna = LENGTH(whichNA);
00223
00224
00225 for (k = 0; k < nna; k++) {
00226 ns = 0;
00227 i = iwhichNA[k] - 1;
00228 if (dweights[i] == 0) continue;
00229
00230
00231 while(TRUE) {
00232
00233 if (ns >= LENGTH(surrsplit)) break;
00234
00235 split = VECTOR_ELT(surrsplit, ns);
00236 if (has_missings(inputs, S3get_variableID(split))) {
00237 if (INTEGER(get_missings(inputs,
00238 S3get_variableID(split)))[i]) {
00239 ns++;
00240 continue;
00241 }
00242 }
00243
00244 cutpoint = REAL(S3get_splitpoint(split))[0];
00245 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00246
00247 if (S3get_toleft(split)) {
00248 if (dx[i] <= cutpoint) {
00249 leftweights[i] = dweights[i];
00250 rightweights[i] = 0.0;
00251 } else {
00252 rightweights[i] = dweights[i];
00253 leftweights[i] = 0.0;
00254 }
00255 } else {
00256 if (dx[i] <= cutpoint) {
00257 rightweights[i] = dweights[i];
00258 leftweights[i] = 0.0;
00259 } else {
00260 leftweights[i] = dweights[i];
00261 rightweights[i] = 0.0;
00262 }
00263 }
00264 break;
00265 }
00266 }
00267 }
00268 }