Main Page | Directories | File List | File Members | Related Pages

Predict.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022 
00023     SEXP weights, leftnode, rightnode, split;
00024     SEXP responses, inputs, whichNA;
00025     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026     double sleft = 0.0, sright = 0.0;
00027     int *ix, *levelset, *iwhichNA;
00028     int nobs, i, nna;
00029                     
00030     weights = S3get_nodeweights(node);
00031     dweights = REAL(weights);
00032     responses = GET_SLOT(learnsample, PL2_responsesSym);
00033     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034     nobs = get_nobs(learnsample);
00035             
00036     /* set up memory for the left daughter */
00037     SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038     C_init_node(leftnode, nobs, 
00039         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040         ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00041                       PL2_jointtransfSym)));
00042     leftweights = REAL(S3get_nodeweights(leftnode));
00043 
00044     /* set up memory for the right daughter */
00045     SET_VECTOR_ELT(node, S3_RIGHT, 
00046                    rightnode = allocVector(VECSXP, NODE_LENGTH));
00047     C_init_node(rightnode, nobs, 
00048         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00049         ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00050                       PL2_jointtransfSym)));
00051     rightweights = REAL(S3get_nodeweights(rightnode));
00052 
00053     /* split according to the primary split */
00054     split = S3get_primarysplit(node);
00055     if (has_missings(inputs, S3get_variableID(split))) {
00056         whichNA = get_missings(inputs, S3get_variableID(split));
00057         iwhichNA = INTEGER(whichNA);
00058         nna = LENGTH(whichNA);
00059     } else {
00060         nna = 0;
00061         whichNA = R_NilValue;
00062         iwhichNA = NULL;
00063     }
00064     
00065     if (S3is_ordered(split)) {
00066         cutpoint = REAL(S3get_splitpoint(split))[0];
00067         dx = REAL(get_variable(inputs, S3get_variableID(split)));
00068         for (i = 0; i < nobs; i++) {
00069             if (nna > 0) {
00070                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00071             }
00072             if (dx[i] <= cutpoint) 
00073                 leftweights[i] = dweights[i]; 
00074             else 
00075                 leftweights[i] = 0.0;
00076             rightweights[i] = dweights[i] - leftweights[i];
00077             sleft += leftweights[i];
00078             sright += rightweights[i];
00079         }
00080     } else {
00081         levelset = INTEGER(S3get_splitpoint(split));
00082         ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00083 
00084         for (i = 0; i < nobs; i++) {
00085             if (nna > 0) {
00086                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00087             }
00088             if (levelset[ix[i] - 1])
00089                 leftweights[i] = dweights[i];
00090             else 
00091                 leftweights[i] = 0.0;
00092             rightweights[i] = dweights[i] - leftweights[i];
00093             sleft += leftweights[i];
00094             sright += rightweights[i];
00095         }
00096     }
00097     
00098     /* for the moment: NA's go with majority */
00099     if (nna > 0) {
00100         for (i = 0; i < nna; i++) {
00101             if (sleft > sright) {
00102                 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00103                 rightweights[iwhichNA[i] - 1] = 0.0;
00104             } else {
00105                 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00106                 leftweights[iwhichNA[i] - 1] = 0.0;
00107             }
00108         }
00109     }
00110 }
00111 
00112 
00122 SEXP C_get_node(SEXP subtree, SEXP newinputs, 
00123                 double mincriterion, int numobs) {
00124 
00125     SEXP split, whichNA, weights, ssplit, surrsplit;
00126     double cutpoint, x, *dweights, swleft, swright;
00127     int level, *levelset, i, ns;
00128 
00129     if (S3get_nodeterminal(subtree) || 
00130         REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
00131         return(subtree);
00132     
00133     split = S3get_primarysplit(subtree);
00134 
00135     /* missing values. Maybe store the proportions left / 
00136        right in each node? */
00137     if (has_missings(newinputs, S3get_variableID(split))) {
00138         whichNA = get_missings(newinputs, S3get_variableID(split));
00139     
00140         if (C_i_in_set(numobs, whichNA)) {
00141         
00142             surrsplit = S3get_surrogatesplits(subtree);
00143             ns = 0;
00144             i = numobs;      
00145 
00146             /* try to find a surrogate split */
00147             while(TRUE) {
00148     
00149                 if (ns >= LENGTH(surrsplit)) break;
00150             
00151                 ssplit = VECTOR_ELT(surrsplit, ns);
00152                 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00153                     if (INTEGER(get_missings(newinputs, 
00154                                              S3get_variableID(ssplit)))[i]) {
00155                         ns++;
00156                         continue;
00157                     }
00158                 }
00159 
00160                 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00161                 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00162                      
00163                 if (S3get_toleft(ssplit)) {
00164                     if (x <= cutpoint) {
00165                         return(C_get_node(S3get_leftnode(subtree),
00166                                           newinputs, mincriterion, numobs));
00167                     } else {
00168                         return(C_get_node(S3get_rightnode(subtree),
00169                                newinputs, mincriterion, numobs));
00170                     }
00171                 } else {
00172                     if (x <= cutpoint) {
00173                         return(C_get_node(S3get_rightnode(subtree),
00174                                           newinputs, mincriterion, numobs));
00175                     } else {
00176                         return(C_get_node(S3get_leftnode(subtree),
00177                                newinputs, mincriterion, numobs));
00178                     }
00179                 }
00180                 break;
00181             }
00182 
00183             /* if this was not successful, we go with the majority */
00184             weights = S3get_nodeweights(S3get_leftnode(subtree));
00185             dweights = REAL(weights);
00186             swleft = 0.0;
00187             for (i = 0; i < LENGTH(weights); i++)
00188                 swleft += dweights[i];
00189             weights = S3get_nodeweights(S3get_rightnode(subtree));
00190             dweights = REAL(weights);
00191             swright = 0.0;
00192             for (i = 0; i < LENGTH(weights); i++)
00193                 swright += dweights[i];
00194             if (swleft > swright) {
00195                 return(C_get_node(S3get_leftnode(subtree), 
00196                                   newinputs, mincriterion, numobs));
00197             } else {
00198                 return(C_get_node(S3get_rightnode(subtree), 
00199                                   newinputs, mincriterion, numobs));
00200             }
00201         }
00202     }
00203     
00204     if (S3is_ordered(split)) {
00205         cutpoint = REAL(S3get_splitpoint(split))[0];
00206         x = REAL(get_variable(newinputs, 
00207                      S3get_variableID(split)))[numobs];
00208         if (x <= cutpoint) {
00209             return(C_get_node(S3get_leftnode(subtree), 
00210                               newinputs, mincriterion, numobs));
00211         } else {
00212             return(C_get_node(S3get_rightnode(subtree), 
00213                               newinputs, mincriterion, numobs));
00214         }
00215     } else {
00216         levelset = INTEGER(S3get_splitpoint(split));
00217         level = INTEGER(get_variable(newinputs, 
00218                             S3get_variableID(split)))[numobs];
00219         /* level is in 1, ..., K */
00220         if (levelset[level - 1]) {
00221             return(C_get_node(S3get_leftnode(subtree), newinputs, 
00222                               mincriterion, numobs));
00223         } else {
00224             return(C_get_node(S3get_rightnode(subtree), newinputs, 
00225                               mincriterion, numobs));
00226         }
00227     }
00228 }
00229 
00230 
00239 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
00240                 SEXP numobs) {
00241     return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00242                       INTEGER(numobs)[0] - 1));
00243 }
00244 
00245 
00252 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00253     
00254     if (nodenum == S3get_nodeID(subtree)) return(subtree);
00255 
00256     if (S3get_nodeterminal(subtree)) 
00257         error("no node with number %d\n", nodenum);
00258 
00259     if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00260         return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00261     } else {
00262         return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00263     }
00264 }
00265 
00266 
00273 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00274     return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00275 }
00276 
00277 
00286 SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
00287                       double mincriterion, int numobs) {
00288     return(S3get_prediction(C_get_node(subtree, newinputs, 
00289                             mincriterion, numobs)));
00290 }
00291 
00292 
00301 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
00302                        double mincriterion, int numobs) {
00303     return(S3get_nodeweights(C_get_node(subtree, newinputs, 
00304                              mincriterion, numobs)));
00305 }
00306 
00307 
00316 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00317                   double mincriterion, int numobs) {
00318      return(S3get_nodeID(C_get_node(subtree, newinputs, 
00319             mincriterion, numobs)));
00320 }
00321 
00322 
00330 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00331 
00332     SEXP ans;
00333     int nobs, i, *dans;
00334             
00335     nobs = get_nobs(newinputs);
00336     PROTECT(ans = allocVector(INTSXP, nobs));
00337     dans = INTEGER(ans);
00338     for (i = 0; i < nobs; i++)
00339          dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00340     UNPROTECT(1);
00341     return(ans);
00342 }
00343 
00344 
00353 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00354     
00355     int nobs, i;
00356     
00357     nobs = get_nobs(newinputs);    
00358     if (LENGTH(ans) != nobs) 
00359         error("ans is not of length %d\n", nobs);
00360         
00361     for (i = 0; i < nobs; i++)
00362         SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
00363                        mincriterion, i));
00364 }
00365 
00366 
00374 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00375 
00376     SEXP ans;
00377     int nobs;
00378     
00379     nobs = get_nobs(newinputs);
00380     PROTECT(ans = allocVector(VECSXP, nobs));
00381     C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00382     UNPROTECT(1);
00383     return(ans);
00384 }
00385 
00386 
00394 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00395 
00396     int nobs, i, *iwhere;
00397     
00398     nobs = LENGTH(where);
00399     iwhere = INTEGER(where);
00400     if (LENGTH(ans) != nobs)
00401         error("ans is not of length %d\n", nobs);
00402         
00403     for (i = 0; i < nobs; i++)
00404         SET_VECTOR_ELT(ans, i, S3get_prediction(
00405             C_get_nodebynum(tree, iwhere[i])));
00406 }
00407 
00408 
00415 SEXP R_getpredictions(SEXP tree, SEXP where) {
00416 
00417     SEXP ans;
00418     int nobs;
00419             
00420     nobs = LENGTH(where);
00421     PROTECT(ans = allocVector(VECSXP, nobs));
00422     C_getpredictions(tree, where, ans);
00423     UNPROTECT(1);
00424     return(ans);
00425 }                        
00426 
00427 
00435 void C_getweights(SEXP tree, SEXP where, SEXP ans) {
00436 
00437     int nobs, i, *iwhere;
00438     
00439     nobs = LENGTH(where);
00440     iwhere = INTEGER(where);
00441     if (LENGTH(ans) != nobs)
00442         error("ans is not of length %d\n", nobs);
00443         
00444     for (i = 0; i < nobs; i++)
00445         SET_VECTOR_ELT(ans, i, S3get_nodeweights(
00446             C_get_nodebynum(tree, iwhere[i])));
00447 }
00448 
00449 
00456 SEXP R_getweights(SEXP tree, SEXP where) {
00457 
00458     SEXP ans;
00459     int nobs;
00460             
00461     nobs = LENGTH(where);
00462     PROTECT(ans = allocVector(VECSXP, nobs));
00463     C_getweights(tree, where, ans);
00464     UNPROTECT(1);
00465     return(ans);
00466 }                        
00467 
00468 
00477 void C_weights(SEXP tree, SEXP newinputs, 
00478                double mincriterion, SEXP ans) {
00479     
00480     int nobs, i;
00481     
00482     nobs = get_nobs(newinputs);    
00483     if (LENGTH(ans) != nobs) 
00484         error("ans is not of length %d\n", nobs);
00485         
00486     for (i = 0; i < nobs; i++)
00487         SET_VECTOR_ELT(ans, i, C_get_nodeweights(tree, newinputs, 
00488                        mincriterion, i));
00489 }
00490 
00491 
00499 SEXP R_weights(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00500 
00501     SEXP ans;
00502     int nobs;
00503     
00504     nobs = get_nobs(newinputs);
00505     PROTECT(ans = allocVector(VECSXP, nobs));
00506     C_weights(tree, newinputs, REAL(mincriterion)[0], ans);
00507     UNPROTECT(1);
00508     return(ans);
00509 }
00510 
00511 
00520 SEXP R_predictRF(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00521 
00522     SEXP ans, tmp, tree;
00523     int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0;
00524     
00525     if (LOGICAL(oobpred)[0]) oob = 1;
00526     
00527     nobs = get_nobs(newinputs);
00528     ntrees = LENGTH(forest);
00529     q = LENGTH(S3get_prediction(
00530                    C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00531 
00532     if (oob) {
00533         if (LENGTH(S3get_nodeweights(
00534                        C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00535             error("number of observations don't match");
00536     }    
00537     
00538     PROTECT(ans = allocVector(VECSXP, nobs));
00539     
00540     for (i = 0; i < nobs; i++) {
00541         count = 0;
00542         SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00543         for (j = 0; j < q; j++)
00544                     REAL(VECTOR_ELT(ans, i))[j] = 0.0;
00545         for (b = 0; b < ntrees; b++) {
00546             tree = VECTOR_ELT(forest, b);
00547 
00548             if (oob && 
00549                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00550                 continue;
00551 
00552             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00553             tmp = S3get_prediction(C_get_nodebynum(tree, iwhere));
00554             for (j = 0; j < q; j++)
00555                 REAL(VECTOR_ELT(ans, i))[j] += REAL(tmp)[j];
00556             count++;
00557         }
00558         if (count == 0) 
00559             error("cannot compute out-of-bag predictions for obs ", i + 1);
00560         for (j = 0; j < q; j++)
00561             REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / count;
00562     }
00563     UNPROTECT(1);
00564     return(ans);
00565 }
00566 
00576 SEXP R_predictRF2(SEXP forest, SEXP response, SEXP newinputs, 
00577                   SEXP mincriterion, SEXP oobpred) {
00578 
00579     SEXP ans, tmp, tree, w;
00580     int ntrees, nobs, i, b, j, q, n, iwhere, oob = 0;
00581     double *dtmp, *dw, sumw = 0.0;
00582 
00583     if (LOGICAL(oobpred)[0]) oob = 1;
00584     
00585     nobs = get_nobs(newinputs);
00586     ntrees = LENGTH(forest);
00587     n = nrow(response);
00588     q = ncol(response);
00589 
00590     if (oob) {
00591         if (n != nobs)
00592             error("number of observations don't match");
00593     }    
00594     
00595     PROTECT(ans = allocVector(VECSXP, nobs));
00596     PROTECT(w = allocMatrix(REALSXP, 1, n));
00597     dw = REAL(w);
00598     
00599     for (i = 0; i < nobs; i++) {
00600 
00601         SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00602         for (j = 0; j < n; j++)
00603             dw[j] = 0.0;
00604 
00605         for (b = 0; b < ntrees; b++) {
00606             tree = VECTOR_ELT(forest, b);
00607 
00608             if (oob && 
00609                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00610                 continue;
00611 
00612             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00613             tmp = S3get_nodeweights(C_get_nodebynum(tree, iwhere));
00614             dtmp = REAL(tmp);
00615             
00616             for (j = 0; j < n; j++)
00617                 dw[j] += dtmp[j];
00618         }
00619         
00620         C_matprod(dw, 1, n, REAL(response), n, q, REAL(VECTOR_ELT(ans, i)));
00621 
00622         sumw = 0.0;
00623         for (j = 0; j < n; j++)
00624             sumw += dw[j];
00625 
00626         for (j = 0; j < q; j++)
00627             REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / sumw;
00628     }
00629     UNPROTECT(2);
00630     return(ans);
00631 }
00632 
00641 SEXP R_predictRF_weights(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00642 
00643     SEXP ans, tree, bw;
00644     int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00645     double *dtmp;
00646     
00647     if (LOGICAL(oobpred)[0]) oob = 1;
00648     
00649     nobs = get_nobs(newinputs);
00650     ntrees = LENGTH(forest);
00651     q = LENGTH(S3get_prediction(
00652                    C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00653 
00654     if (oob) {
00655         if (LENGTH(S3get_nodeweights(
00656                        C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00657             error("number of observations don't match");
00658     }    
00659     
00660     tree = VECTOR_ELT(forest, 0);
00661     ntrain = LENGTH(S3get_nodeweights(C_get_nodebynum(tree, 1)));
00662     
00663     PROTECT(ans = allocVector(VECSXP, nobs));
00664     
00665     for (i = 0; i < nobs; i++) {
00666         count = 0;
00667         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00668         for (j = 0; j < ntrain; j++)
00669             REAL(bw)[j] = 0.0;
00670         for (b = 0; b < ntrees; b++) {
00671             tree = VECTOR_ELT(forest, b);
00672 
00673             if (oob && 
00674                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00675                 continue;
00676 
00677             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00678             dtmp = REAL(S3get_nodeweights(C_get_nodebynum(tree, iwhere)));
00679             for (j = 0; j < ntrain; j++)
00680                 REAL(bw)[j] += dtmp[j];
00681             count++;
00682         }
00683         if (count == 0) 
00684             error("cannot compute out-of-bag predictions for obs ", i + 1);
00685     }
00686     UNPROTECT(1);
00687     return(ans);
00688 }

Generated on Fri Aug 25 14:30:01 2006 for party by  doxygen 1.4.4