00001
00009 #include "party.h"
00010
00011
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022
00023 SEXP weights, leftnode, rightnode, split;
00024 SEXP responses, inputs, whichNA;
00025 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026 double sleft = 0.0, sright = 0.0;
00027 int *ix, *levelset, *iwhichNA;
00028 int nobs, i, nna;
00029
00030 weights = S3get_nodeweights(node);
00031 dweights = REAL(weights);
00032 responses = GET_SLOT(learnsample, PL2_responsesSym);
00033 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034 nobs = get_nobs(learnsample);
00035
00036
00037 SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038 C_init_node(leftnode, nobs,
00039 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00041 leftweights = REAL(S3get_nodeweights(leftnode));
00042
00043
00044 SET_VECTOR_ELT(node, S3_RIGHT,
00045 rightnode = allocVector(VECSXP, NODE_LENGTH));
00046 C_init_node(rightnode, nobs,
00047 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00049 rightweights = REAL(S3get_nodeweights(rightnode));
00050
00051
00052 split = S3get_primarysplit(node);
00053 if (has_missings(inputs, S3get_variableID(split))) {
00054 whichNA = get_missings(inputs, S3get_variableID(split));
00055 iwhichNA = INTEGER(whichNA);
00056 nna = LENGTH(whichNA);
00057 } else {
00058 nna = 0;
00059 whichNA = R_NilValue;
00060 iwhichNA = NULL;
00061 }
00062
00063 if (S3is_ordered(split)) {
00064 cutpoint = REAL(S3get_splitpoint(split))[0];
00065 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066 for (i = 0; i < nobs; i++) {
00067 if (nna > 0) {
00068 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069 }
00070 if (dx[i] <= cutpoint)
00071 leftweights[i] = dweights[i];
00072 else
00073 leftweights[i] = 0.0;
00074 rightweights[i] = dweights[i] - leftweights[i];
00075 sleft += leftweights[i];
00076 sright += rightweights[i];
00077 }
00078 } else {
00079 levelset = INTEGER(S3get_splitpoint(split));
00080 ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081
00082 for (i = 0; i < nobs; i++) {
00083 if (nna > 0) {
00084 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085 }
00086 if (levelset[ix[i] - 1])
00087 leftweights[i] = dweights[i];
00088 else
00089 leftweights[i] = 0.0;
00090 rightweights[i] = dweights[i] - leftweights[i];
00091 sleft += leftweights[i];
00092 sright += rightweights[i];
00093 }
00094 }
00095
00096
00097 if (nna > 0) {
00098 for (i = 0; i < nna; i++) {
00099 if (sleft > sright) {
00100 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101 rightweights[iwhichNA[i] - 1] = 0.0;
00102 } else {
00103 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104 leftweights[iwhichNA[i] - 1] = 0.0;
00105 }
00106 }
00107 }
00108 }
00109
00110
00120 SEXP C_get_node(SEXP subtree, SEXP newinputs,
00121 double mincriterion, int numobs) {
00122
00123 SEXP split, whichNA, weights, ssplit, surrsplit;
00124 double cutpoint, x, *dweights, swleft, swright;
00125 int level, *levelset, i, ns;
00126
00127 if (S3get_nodeterminal(subtree) ||
00128 REAL(S3get_maxcriterion(subtree))[0] < mincriterion)
00129 return(subtree);
00130
00131 split = S3get_primarysplit(subtree);
00132
00133
00134
00135 if (has_missings(newinputs, S3get_variableID(split))) {
00136 whichNA = get_missings(newinputs, S3get_variableID(split));
00137
00138
00139 if (C_i_in_set(numobs + 1, whichNA)) {
00140
00141 surrsplit = S3get_surrogatesplits(subtree);
00142 ns = 0;
00143 i = numobs;
00144
00145
00146 while(TRUE) {
00147
00148 if (ns >= LENGTH(surrsplit)) break;
00149
00150 ssplit = VECTOR_ELT(surrsplit, ns);
00151 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00152 if (INTEGER(get_missings(newinputs,
00153 S3get_variableID(ssplit)))[i]) {
00154 ns++;
00155 continue;
00156 }
00157 }
00158
00159 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00160 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00161
00162 if (S3get_toleft(ssplit)) {
00163 if (x <= cutpoint) {
00164 return(C_get_node(S3get_leftnode(subtree),
00165 newinputs, mincriterion, numobs));
00166 } else {
00167 return(C_get_node(S3get_rightnode(subtree),
00168 newinputs, mincriterion, numobs));
00169 }
00170 } else {
00171 if (x <= cutpoint) {
00172 return(C_get_node(S3get_rightnode(subtree),
00173 newinputs, mincriterion, numobs));
00174 } else {
00175 return(C_get_node(S3get_leftnode(subtree),
00176 newinputs, mincriterion, numobs));
00177 }
00178 }
00179 break;
00180 }
00181
00182
00183 swleft = S3get_sumweights(S3get_leftnode(subtree));
00184 swright = S3get_sumweights(S3get_rightnode(subtree));
00185 if (swleft > swright) {
00186 return(C_get_node(S3get_leftnode(subtree),
00187 newinputs, mincriterion, numobs));
00188 } else {
00189 return(C_get_node(S3get_rightnode(subtree),
00190 newinputs, mincriterion, numobs));
00191 }
00192 }
00193 }
00194
00195 if (S3is_ordered(split)) {
00196 cutpoint = REAL(S3get_splitpoint(split))[0];
00197 x = REAL(get_variable(newinputs,
00198 S3get_variableID(split)))[numobs];
00199 if (x <= cutpoint) {
00200 return(C_get_node(S3get_leftnode(subtree),
00201 newinputs, mincriterion, numobs));
00202 } else {
00203 return(C_get_node(S3get_rightnode(subtree),
00204 newinputs, mincriterion, numobs));
00205 }
00206 } else {
00207 levelset = INTEGER(S3get_splitpoint(split));
00208 level = INTEGER(get_variable(newinputs,
00209 S3get_variableID(split)))[numobs];
00210
00211 if (levelset[level - 1]) {
00212 return(C_get_node(S3get_leftnode(subtree), newinputs,
00213 mincriterion, numobs));
00214 } else {
00215 return(C_get_node(S3get_rightnode(subtree), newinputs,
00216 mincriterion, numobs));
00217 }
00218 }
00219 }
00220
00221
00230 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion,
00231 SEXP numobs) {
00232 return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00233 INTEGER(numobs)[0] - 1));
00234 }
00235
00236
00243 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00244
00245 if (nodenum == S3get_nodeID(subtree)) return(subtree);
00246
00247 if (S3get_nodeterminal(subtree))
00248 error("no node with number %d\n", nodenum);
00249
00250 if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00251 return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00252 } else {
00253 return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00254 }
00255 }
00256
00257
00264 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00265 return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00266 }
00267
00268
00277 SEXP C_get_prediction(SEXP subtree, SEXP newinputs,
00278 double mincriterion, int numobs) {
00279 return(S3get_prediction(C_get_node(subtree, newinputs,
00280 mincriterion, numobs)));
00281 }
00282
00283
00292 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs,
00293 double mincriterion, int numobs) {
00294 return(S3get_nodeweights(C_get_node(subtree, newinputs,
00295 mincriterion, numobs)));
00296 }
00297
00298
00307 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00308 double mincriterion, int numobs) {
00309 return(S3get_nodeID(C_get_node(subtree, newinputs,
00310 mincriterion, numobs)));
00311 }
00312
00313
00321 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00322
00323 SEXP ans;
00324 int nobs, i, *dans;
00325
00326 nobs = get_nobs(newinputs);
00327 PROTECT(ans = allocVector(INTSXP, nobs));
00328 dans = INTEGER(ans);
00329 for (i = 0; i < nobs; i++)
00330 dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00331 UNPROTECT(1);
00332 return(ans);
00333 }
00334
00335
00344 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00345
00346 int nobs, i;
00347
00348 nobs = get_nobs(newinputs);
00349 if (LENGTH(ans) != nobs)
00350 error("ans is not of length %d\n", nobs);
00351
00352 for (i = 0; i < nobs; i++)
00353 SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs,
00354 mincriterion, i));
00355 }
00356
00357
00365 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00366
00367 SEXP ans;
00368 int nobs;
00369
00370 nobs = get_nobs(newinputs);
00371 PROTECT(ans = allocVector(VECSXP, nobs));
00372 C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00373 UNPROTECT(1);
00374 return(ans);
00375 }
00376
00377
00385 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00386
00387 int nobs, i, *iwhere;
00388
00389 nobs = LENGTH(where);
00390 iwhere = INTEGER(where);
00391 if (LENGTH(ans) != nobs)
00392 error("ans is not of length %d\n", nobs);
00393
00394 for (i = 0; i < nobs; i++)
00395 SET_VECTOR_ELT(ans, i, S3get_prediction(
00396 C_get_nodebynum(tree, iwhere[i])));
00397 }
00398
00399
00406 SEXP R_getpredictions(SEXP tree, SEXP where) {
00407
00408 SEXP ans;
00409 int nobs;
00410
00411 nobs = LENGTH(where);
00412 PROTECT(ans = allocVector(VECSXP, nobs));
00413 C_getpredictions(tree, where, ans);
00414 UNPROTECT(1);
00415 return(ans);
00416 }
00417
00428 SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights,
00429 SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00430
00431 SEXP ans, tree, bw;
00432 int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00433
00434 if (LOGICAL(oobpred)[0]) oob = 1;
00435
00436 nobs = get_nobs(newinputs);
00437 ntrees = LENGTH(forest);
00438 q = LENGTH(S3get_prediction(
00439 C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00440
00441 if (oob) {
00442 if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
00443 error("number of observations don't match");
00444 }
00445
00446 tree = VECTOR_ELT(forest, 0);
00447 ntrain = LENGTH(VECTOR_ELT(weights, 0));
00448
00449 PROTECT(ans = allocVector(VECSXP, nobs));
00450
00451 for (i = 0; i < nobs; i++) {
00452 count = 0;
00453 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00454 for (j = 0; j < ntrain; j++)
00455 REAL(bw)[j] = 0.0;
00456 for (b = 0; b < ntrees; b++) {
00457 tree = VECTOR_ELT(forest, b);
00458
00459 if (oob &&
00460 REAL(VECTOR_ELT(weights, b))[i] > 0.0)
00461 continue;
00462
00463 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00464
00465 for (j = 0; j < ntrain; j++) {
00466 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00467 REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
00468 }
00469 count++;
00470 }
00471 if (count == 0)
00472 error("cannot compute out-of-bag predictions for obs ", i + 1);
00473 }
00474 UNPROTECT(1);
00475 return(ans);
00476 }
00477
00478
00484 SEXP R_proximity(SEXP where) {
00485
00486 SEXP ans, bw, bin;
00487 int ntrees, nobs, i, b, j, iwhere;
00488
00489 ntrees = LENGTH(where);
00490 nobs = LENGTH(VECTOR_ELT(where, 0));
00491
00492 PROTECT(ans = allocVector(VECSXP, nobs));
00493 PROTECT(bin = allocVector(INTSXP, nobs));
00494
00495 for (i = 0; i < nobs; i++) {
00496 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
00497 for (j = 0; j < nobs; j++) {
00498 REAL(bw)[j] = 0.0;
00499 INTEGER(bin)[j] = 0;
00500 }
00501 for (b = 0; b < ntrees; b++) {
00502
00503 if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
00504 continue;
00505 iwhere = INTEGER(VECTOR_ELT(where, b))[i];
00506 for (j = 0; j < nobs; j++) {
00507 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00508
00509 REAL(bw)[j]++;
00510 if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
00511
00512
00513 INTEGER(bin)[j]++;
00514 }
00515 }
00516 for (j = 0; j < nobs; j++)
00517 REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
00518 }
00519 UNPROTECT(2);
00520 return(ans);
00521 }