Main Page | Directories | File List | File Members | Related Pages

SurrogateSplits.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036 
00037     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00038         error("nodes does not have %s surrogate splits", maxsurr);
00039 
00040     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00041     jselect = S3get_variableID(S3get_primarysplit(node));
00042     y = S3get_nodeweights(VECTOR_ELT(node, 7));
00043 
00044     tweights = Calloc(nobs, double);
00045     dweights = REAL(weights);
00046     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00047     if (has_missings(inputs, jselect)) {
00048         thiswhichNA = get_missings(inputs, jselect);
00049         for (k = 0; k < LENGTH(thiswhichNA); k++)
00050             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00051     }
00052 
00053     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00054     C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00055     
00056     splitstat = REAL(get_splitstatistics(fitmem));
00057     /* <FIXME> extend `TreeFitMemory' to those as well ... */
00058     maxstat = Calloc(ninputs, double);
00059     cutpoint = Calloc(ninputs, double);
00060     order = Calloc(ninputs, int);
00061     /* <FIXME> */
00062     
00063     /* this is essentially an exhaustive search */
00064     /* <FIXME>: we don't want to do this for random forest like trees 
00065        </FIXME>
00066      */
00067     for (j = 0; j < ninputs; j++) {
00068     
00069          order[j] = j + 1;
00070          maxstat[j] = 0.0;
00071          cutpoint[j] = 0.0;
00072 
00073          /* ordered input variables only (for the moment) */
00074          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00075              continue;
00076 
00077          x = get_variable(inputs, j + 1);
00078 
00079          if (has_missings(inputs, j + 1)) {
00080 
00081              thisweights = REAL(get_weights(fitmem, j + 1));
00082              for (i = 0; i < nobs; i++) thisweights[i] = tweights[i];
00083              thiswhichNA = get_missings(inputs, j + 1);
00084              for (k = 0; k < LENGTH(thiswhichNA); k++)
00085                  thisweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00086                  
00087              C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00088              
00089              C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00090                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00091                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00092                      expcovinf, &cp, &ms, splitstat);
00093          } else {
00094          
00095              C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00096              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00097              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00098              expcovinf, &cp, &ms, splitstat);
00099          }
00100 
00101          maxstat[j] = -ms;
00102          cutpoint[j] = cp;
00103     }
00104 
00105     /* order with respect to maximal statistic */
00106     rsort_with_index(maxstat, order, ninputs);
00107     
00108     twotab = Calloc(4, double);
00109     
00110     /* the best `maxsurr' ones are implemented */
00111     for (j = 0; j < maxsurr; j++) {
00112 
00113         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00114         cut = cutpoint[order[j] - 1];
00115         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00116                        split = allocVector(VECSXP, SPLIT_LENGTH));
00117         C_init_orderedsplit(split, 0);
00118         S3set_variableID(split, order[j]);
00119         REAL(S3get_splitpoint(split))[0] = cut;
00120         dx = REAL(get_variable(inputs, order[j]));
00121         dy = REAL(y);
00122 
00123         /* OK, this is a dirty hack: determine if the split 
00124            goes left or right by the Pearson residual of a 2x2 table.
00125            I don't want to use the big caliber here 
00126         */
00127         for (i = 0; i < nobs; i++) {
00128             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00129             twotab[1] += (dy[i] == 1) * tweights[i];
00130             twotab[2] += (dx[i] <= cut) * tweights[i];
00131             twotab[3] += tweights[i];
00132         }
00133         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00134                      twotab[3]) > 0);
00135     }
00136     
00137     Free(maxstat);
00138     Free(cutpoint);
00139     Free(order);
00140     Free(tweights);
00141     Free(twotab);
00142 }
00143 
00154 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00155                   SEXP fitmem) {
00156 
00157     C_surrogates(node, learnsample, weights, controls, fitmem);
00158     return(S3get_surrogatesplits(node));
00159     
00160 }
00161 
00169 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00170 
00171     SEXP weights, split, surrsplit;
00172     SEXP inputs, whichNA;
00173     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00174     int *iwhichNA, k;
00175     int nobs, i, nna, ns;
00176                     
00177     weights = S3get_nodeweights(node);
00178     dweights = REAL(weights);
00179     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00180     nobs = get_nobs(learnsample);
00181             
00182     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00183     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00184     surrsplit = S3get_surrogatesplits(node);
00185 
00186     /* if the primary split has any missings */
00187     split = S3get_primarysplit(node);
00188     if (has_missings(inputs, S3get_variableID(split))) {
00189 
00190         /* where are the missings? */
00191         whichNA = get_missings(inputs, S3get_variableID(split));
00192         iwhichNA = INTEGER(whichNA);
00193         nna = LENGTH(whichNA);
00194 
00195         /* for all missing values ... */
00196         for (k = 0; k < nna; k++) {
00197             ns = 0;
00198             i = iwhichNA[k] - 1;
00199             if (dweights[i] == 0) continue;
00200             
00201             /* loop over surrogate splits until an appropriate one is found */
00202             while(TRUE) {
00203             
00204                 if (ns >= LENGTH(surrsplit)) break;
00205             
00206                 split = VECTOR_ELT(surrsplit, ns);
00207                 if (has_missings(inputs, S3get_variableID(split))) {
00208                     if (INTEGER(get_missings(inputs, 
00209                             S3get_variableID(split)))[i]) {
00210                         ns++;
00211                         continue;
00212                     }
00213                 }
00214 
00215                 cutpoint = REAL(S3get_splitpoint(split))[0];
00216                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00217 
00218                 if (S3get_toleft(split)) {
00219                     if (dx[i] <= cutpoint) {
00220                         leftweights[i] = dweights[i];
00221                         rightweights[i] = 0.0;
00222                     } else {
00223                         rightweights[i] = dweights[i];
00224                         leftweights[i] = 0.0;
00225                     }
00226                 } else {
00227                     if (dx[i] <= cutpoint) {
00228                         rightweights[i] = dweights[i];
00229                         leftweights[i] = 0.0;
00230                     } else {
00231                         leftweights[i] = dweights[i];
00232                         rightweights[i] = 0.0;
00233                     }
00234                 }
00235                 break;
00236             }
00237         }
00238     }
00239 }

Generated on Fri Aug 25 14:30:01 2006 for party by  doxygen 1.4.4