Main Page | Directories | File List | File Members | Related Pages

Node.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
/**
 * Weighted column means of a response matrix stored column by column.
 *
 * For every column j the weighted sum over all n observations is formed
 * and then divided by `sweights`.
 *
 * @param y        values of length n * q; column j occupies
 *                 y[j * n] ... y[j * n + n - 1]
 * @param n        number of observations (rows of y)
 * @param q        number of columns of y
 * @param weights  case weights, one per observation (length n)
 * @param sweights divisor applied to each weighted column sum; it is not
 *                 checked against zero here
 * @param ans      output vector of length q, overwritten with
 *                 sum_i weights[i] * y[j * n + i] / sweights
 */
void C_prediction(const double *y, int n, int q, const double *weights, 
                  const double sweights, double *ans) {

    int col, obs;
    const double *ycol;
    double total;

    for (col = 0; col < q; col++) {

        /* start of the current column inside the flat array */
        ycol = y + col * n;

        total = 0.0;
        for (obs = 0; obs < n; obs++) {
            total += weights[obs] * ycol[obs];
        }

        ans[col] = total / sweights;
    }
}
00035 
00036 
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights, 
00049             SEXP fitmem, SEXP controls, int TERMINAL) {
00050     
00051     int nobs, ninputs, jselect, q, j, k, i;
00052     double mincriterion, sweights, *dprediction;
00053     double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054     double *standstat, *splitstat;
00055     SEXP responses, inputs, y, x, expcovinf, thisweights, linexpcov;
00056     SEXP varctrl, splitctrl, gtctrl, tgctrl, split, joint;
00057     double *dxtransf, *dweights;
00058     int *itable;
00059     
00060     nobs = get_nobs(learnsample);
00061     ninputs = get_ninputs(learnsample);
00062     varctrl = get_varctrl(controls);
00063     splitctrl = get_splitctrl(controls);
00064     gtctrl = get_gtctrl(controls);
00065     tgctrl = get_tgctrl(controls);
00066     mincriterion = get_mincriterion(gtctrl);
00067     responses = GET_SLOT(learnsample, PL2_responsesSym);
00068     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069     y = get_transformation(responses, 1);
00070     q = ncol(y);
00071     joint = GET_SLOT(responses, PL2_jointtransfSym);
00072 
00073     /* <FIXME> we compute C_GlobalTest even for TERMINAL nodes! </FIXME> */
00074 
00075     /* compute the test statistics and the node criteria for each input */        
00076     C_GlobalTest(learnsample, weights, fitmem, varctrl,
00077                  gtctrl, get_minsplit(splitctrl), 
00078                  REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00079     
00080     /* sum of weights: C_GlobalTest did nothing if sweights < mincriterion */
00081     sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym), 
00082                              PL2_sumweightsSym))[0];
00083 
00084     /* compute the prediction of this node */
00085     dprediction = REAL(S3get_prediction(node));
00086 
00087     /* <FIXME> feed raw numeric values OR dummy encoded factors as y 
00088        Problem: what happens for survival times ? */
00089     C_prediction(REAL(joint), nobs, ncol(joint), REAL(weights), 
00090                      sweights, dprediction);
00091     /* </FIXME> */
00092 
00093     teststat = REAL(S3get_teststat(node));
00094     pvalue = REAL(S3get_criterion(node));
00095 
00096     /* try the two out of ninputs best inputs variables */
00097     /* <FIXME> be more flexible and add a parameter controlling
00098                the number of inputs tried </FIXME> */
00099     for (j = 0; j < 2; j++) {
00100 
00101         smax = C_max(pvalue, ninputs);
00102         REAL(S3get_maxcriterion(node))[0] = smax;
00103     
00104         /* if the global null hypothesis was rejected */
00105         if (smax > mincriterion && !TERMINAL) {
00106 
00107             /* the input variable with largest association to the response */
00108             jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00109 
00110             /* get the raw numeric values or the codings of a factor */
00111             x = get_variable(inputs, jselect);
00112             if (has_missings(inputs, jselect)) {
00113                 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect), 
00114                                     PL2_expcovinfSym);
00115                 thisweights = get_weights(fitmem, jselect);
00116             } else {
00117                 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00118                 thisweights = weights;
00119             }
00120 
00121             /* <FIXME> handle ordered factors separatly??? </FIXME> */
00122             if (!is_nominal(inputs, jselect)) {
00123             
00124                 /* search for a split in a ordered variable x */
00125                 split = S3get_primarysplit(node);
00126                 
00127                 /* check if the n-vector of splitstatistics 
00128                    should be returned for each primary split */
00129                 if (get_savesplitstats(tgctrl)) {
00130                     C_init_orderedsplit(split, nobs);
00131                     splitstat = REAL(S3get_splitstatistics(split));
00132                 } else {
00133                     C_init_orderedsplit(split, 0);
00134                     splitstat = REAL(get_splitstatistics(fitmem));
00135                 }
00136 
00137                 C_split(REAL(x), 1, REAL(y), q, REAL(weights), nobs,
00138                         INTEGER(get_ordering(inputs, jselect)), splitctrl, 
00139                         GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00140                         expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00141                         splitstat);
00142                 S3set_variableID(split, jselect);
00143              } else {
00144            
00145                  /* search of a set of levels (split) in a numeric variable x */
00146                  split = S3get_primarysplit(node);
00147                  
00148                 /* check if the n-vector of splitstatistics 
00149                    should be returned for each primary split */
00150                 if (get_savesplitstats(tgctrl)) {
00151                     C_init_nominalsplit(split, 
00152                         LENGTH(get_levels(inputs, jselect)), 
00153                         nobs);
00154                     splitstat = REAL(S3get_splitstatistics(split));
00155                 } else {
00156                     C_init_nominalsplit(split, 
00157                         LENGTH(get_levels(inputs, jselect)), 
00158                         0);
00159                     splitstat = REAL(get_splitstatistics(fitmem));
00160                 }
00161           
00162                  linexpcov = get_varmemory(fitmem, jselect);
00163                  standstat = Calloc(get_dimension(linexpcov), double);
00164                  C_standardize(REAL(GET_SLOT(linexpcov, 
00165                                              PL2_linearstatisticSym)),
00166                                REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00167                                REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00168                                get_dimension(linexpcov), get_tol(splitctrl), 
00169                                standstat);
00170  
00171                  C_splitcategorical(INTEGER(x), 
00172                                     LENGTH(get_levels(inputs, jselect)), 
00173                                     REAL(y), q, REAL(weights), 
00174                                     nobs, standstat, splitctrl, 
00175                                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00176                                     expcovinf, &cutpoint, 
00177                                     INTEGER(S3get_splitpoint(split)),
00178                                     &maxstat, splitstat);
00179 
00180                  /* compute which levels of a factor are available in this node 
00181                     (for printing) later on. A real `table' for this node would
00182                     induce too much overhead here. Maybe later. */
00183                     
00184                  itable = INTEGER(S3get_table(split));
00185                  dxtransf = REAL(get_transformation(inputs, jselect));
00186                  dweights = REAL(thisweights);
00187                  for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00188                      itable[k] = 0;
00189                      for (i = 0; i < nobs; i++) {
00190                          if (dxtransf[k * nobs + i] * dweights[i] > 0) {
00191                              itable[k] = 1;
00192                              continue;
00193                          }
00194                      }
00195                  }
00196 
00197                  Free(standstat);
00198             }
00199             if (maxstat == 0) {
00200                 warning("no admissible split found\n");
00201             
00202                 if (j == 1) {          
00203                     S3set_nodeterminal(node);
00204                 } else {
00205                     /* <FIXME> why? </FIXME> */
00206                     pvalue[jselect - 1] = 0.0;
00207                 }
00208             } else {
00209                 S3set_variableID(split, jselect);
00210                 break;
00211             }
00212         } else {
00213             S3set_nodeterminal(node);
00214             break;
00215         }
00216     }
00217 }       
00218 
00219 
00228 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00229             
00230      SEXP ans;
00231      
00232      PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00233      C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample), 
00234                  get_maxsurrogate(get_splitctrl(controls)),
00235                  ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym), 
00236                       PL2_jointtransfSym)));
00237 
00238      C_Node(ans, learnsample, weights, fitmem, controls, 0);
00239      UNPROTECT(1);
00240      return(ans);
00241 }

Generated on Fri Aug 25 14:30:01 2006 for party by  doxygen 1.4.4