00001
00009 #include "party.h"
00010
00011
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022
00023 SEXP weights, leftnode, rightnode, split;
00024 SEXP responses, inputs, whichNA;
00025 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026 double sleft = 0.0, sright = 0.0;
00027 int *ix, *levelset, *iwhichNA;
00028 int nobs, i, nna;
00029
00030 weights = S3get_nodeweights(node);
00031 dweights = REAL(weights);
00032 responses = GET_SLOT(learnsample, PL2_responsesSym);
00033 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034 nobs = get_nobs(learnsample);
00035
00036
00037 SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038 C_init_node(leftnode, nobs,
00039 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00041 PL2_jointtransfSym)));
00042 leftweights = REAL(S3get_nodeweights(leftnode));
00043
00044
00045 SET_VECTOR_ELT(node, S3_RIGHT,
00046 rightnode = allocVector(VECSXP, NODE_LENGTH));
00047 C_init_node(rightnode, nobs,
00048 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00049 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00050 PL2_jointtransfSym)));
00051 rightweights = REAL(S3get_nodeweights(rightnode));
00052
00053
00054 split = S3get_primarysplit(node);
00055 if (has_missings(inputs, S3get_variableID(split))) {
00056 whichNA = get_missings(inputs, S3get_variableID(split));
00057 iwhichNA = INTEGER(whichNA);
00058 nna = LENGTH(whichNA);
00059 } else {
00060 nna = 0;
00061 whichNA = R_NilValue;
00062 iwhichNA = NULL;
00063 }
00064
00065 if (S3is_ordered(split)) {
00066 cutpoint = REAL(S3get_splitpoint(split))[0];
00067 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00068 for (i = 0; i < nobs; i++) {
00069 if (nna > 0) {
00070 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00071 }
00072 if (dx[i] <= cutpoint)
00073 leftweights[i] = dweights[i];
00074 else
00075 leftweights[i] = 0.0;
00076 rightweights[i] = dweights[i] - leftweights[i];
00077 sleft += leftweights[i];
00078 sright += rightweights[i];
00079 }
00080 } else {
00081 levelset = INTEGER(S3get_splitpoint(split));
00082 ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00083
00084 for (i = 0; i < nobs; i++) {
00085 if (nna > 0) {
00086 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00087 }
00088 if (levelset[ix[i] - 1])
00089 leftweights[i] = dweights[i];
00090 else
00091 leftweights[i] = 0.0;
00092 rightweights[i] = dweights[i] - leftweights[i];
00093 sleft += leftweights[i];
00094 sright += rightweights[i];
00095 }
00096 }
00097
00098
00099 if (nna > 0) {
00100 for (i = 0; i < nna; i++) {
00101 if (sleft > sright) {
00102 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00103 rightweights[iwhichNA[i] - 1] = 0.0;
00104 } else {
00105 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00106 leftweights[iwhichNA[i] - 1] = 0.0;
00107 }
00108 }
00109 }
00110 }
00111
00112
00122 SEXP C_get_node(SEXP subtree, SEXP newinputs,
00123 double mincriterion, int numobs) {
00124
00125 SEXP split, whichNA, weights, ssplit, surrsplit;
00126 double cutpoint, x, *dweights, swleft, swright;
00127 int level, *levelset, i, ns;
00128
00129 if (S3get_nodeterminal(subtree) ||
00130 REAL(S3get_maxcriterion(subtree))[0] < mincriterion)
00131 return(subtree);
00132
00133 split = S3get_primarysplit(subtree);
00134
00135
00136
00137 if (has_missings(newinputs, S3get_variableID(split))) {
00138 whichNA = get_missings(newinputs, S3get_variableID(split));
00139
00140 if (C_i_in_set(numobs, whichNA)) {
00141
00142 surrsplit = S3get_surrogatesplits(subtree);
00143 ns = 0;
00144 i = numobs;
00145
00146
00147 while(TRUE) {
00148
00149 if (ns >= LENGTH(surrsplit)) break;
00150
00151 ssplit = VECTOR_ELT(surrsplit, ns);
00152 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00153 if (INTEGER(get_missings(newinputs,
00154 S3get_variableID(ssplit)))[i]) {
00155 ns++;
00156 continue;
00157 }
00158 }
00159
00160 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00161 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00162
00163 if (S3get_toleft(ssplit)) {
00164 if (x <= cutpoint) {
00165 return(C_get_node(S3get_leftnode(subtree),
00166 newinputs, mincriterion, numobs));
00167 } else {
00168 return(C_get_node(S3get_rightnode(subtree),
00169 newinputs, mincriterion, numobs));
00170 }
00171 } else {
00172 if (x <= cutpoint) {
00173 return(C_get_node(S3get_rightnode(subtree),
00174 newinputs, mincriterion, numobs));
00175 } else {
00176 return(C_get_node(S3get_leftnode(subtree),
00177 newinputs, mincriterion, numobs));
00178 }
00179 }
00180 break;
00181 }
00182
00183
00184 weights = S3get_nodeweights(S3get_leftnode(subtree));
00185 dweights = REAL(weights);
00186 swleft = 0.0;
00187 for (i = 0; i < LENGTH(weights); i++)
00188 swleft += dweights[i];
00189 weights = S3get_nodeweights(S3get_rightnode(subtree));
00190 dweights = REAL(weights);
00191 swright = 0.0;
00192 for (i = 0; i < LENGTH(weights); i++)
00193 swright += dweights[i];
00194 if (swleft > swright) {
00195 return(C_get_node(S3get_leftnode(subtree),
00196 newinputs, mincriterion, numobs));
00197 } else {
00198 return(C_get_node(S3get_rightnode(subtree),
00199 newinputs, mincriterion, numobs));
00200 }
00201 }
00202 }
00203
00204 if (S3is_ordered(split)) {
00205 cutpoint = REAL(S3get_splitpoint(split))[0];
00206 x = REAL(get_variable(newinputs,
00207 S3get_variableID(split)))[numobs];
00208 if (x <= cutpoint) {
00209 return(C_get_node(S3get_leftnode(subtree),
00210 newinputs, mincriterion, numobs));
00211 } else {
00212 return(C_get_node(S3get_rightnode(subtree),
00213 newinputs, mincriterion, numobs));
00214 }
00215 } else {
00216 levelset = INTEGER(S3get_splitpoint(split));
00217 level = INTEGER(get_variable(newinputs,
00218 S3get_variableID(split)))[numobs];
00219
00220 if (levelset[level - 1]) {
00221 return(C_get_node(S3get_leftnode(subtree), newinputs,
00222 mincriterion, numobs));
00223 } else {
00224 return(C_get_node(S3get_rightnode(subtree), newinputs,
00225 mincriterion, numobs));
00226 }
00227 }
00228 }
00229
00230
00239 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion,
00240 SEXP numobs) {
00241 return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00242 INTEGER(numobs)[0] - 1));
00243 }
00244
00245
00252 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00253
00254 if (nodenum == S3get_nodeID(subtree)) return(subtree);
00255
00256 if (S3get_nodeterminal(subtree))
00257 error("no node with number %d\n", nodenum);
00258
00259 if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00260 return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00261 } else {
00262 return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00263 }
00264 }
00265
00266
00273 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00274 return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00275 }
00276
00277
00286 SEXP C_get_prediction(SEXP subtree, SEXP newinputs,
00287 double mincriterion, int numobs) {
00288 return(S3get_prediction(C_get_node(subtree, newinputs,
00289 mincriterion, numobs)));
00290 }
00291
00292
00301 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs,
00302 double mincriterion, int numobs) {
00303 return(S3get_nodeweights(C_get_node(subtree, newinputs,
00304 mincriterion, numobs)));
00305 }
00306
00307
00316 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00317 double mincriterion, int numobs) {
00318 return(S3get_nodeID(C_get_node(subtree, newinputs,
00319 mincriterion, numobs)));
00320 }
00321
00322
00330 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00331
00332 SEXP ans;
00333 int nobs, i, *dans;
00334
00335 nobs = get_nobs(newinputs);
00336 PROTECT(ans = allocVector(INTSXP, nobs));
00337 dans = INTEGER(ans);
00338 for (i = 0; i < nobs; i++)
00339 dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00340 UNPROTECT(1);
00341 return(ans);
00342 }
00343
00344
00353 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00354
00355 int nobs, i;
00356
00357 nobs = get_nobs(newinputs);
00358 if (LENGTH(ans) != nobs)
00359 error("ans is not of length %d\n", nobs);
00360
00361 for (i = 0; i < nobs; i++)
00362 SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs,
00363 mincriterion, i));
00364 }
00365
00366
00374 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00375
00376 SEXP ans;
00377 int nobs;
00378
00379 nobs = get_nobs(newinputs);
00380 PROTECT(ans = allocVector(VECSXP, nobs));
00381 C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00382 UNPROTECT(1);
00383 return(ans);
00384 }
00385
00386
00394 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00395
00396 int nobs, i, *iwhere;
00397
00398 nobs = LENGTH(where);
00399 iwhere = INTEGER(where);
00400 if (LENGTH(ans) != nobs)
00401 error("ans is not of length %d\n", nobs);
00402
00403 for (i = 0; i < nobs; i++)
00404 SET_VECTOR_ELT(ans, i, S3get_prediction(
00405 C_get_nodebynum(tree, iwhere[i])));
00406 }
00407
00408
00415 SEXP R_getpredictions(SEXP tree, SEXP where) {
00416
00417 SEXP ans;
00418 int nobs;
00419
00420 nobs = LENGTH(where);
00421 PROTECT(ans = allocVector(VECSXP, nobs));
00422 C_getpredictions(tree, where, ans);
00423 UNPROTECT(1);
00424 return(ans);
00425 }
00426
00427
00435 void C_getweights(SEXP tree, SEXP where, SEXP ans) {
00436
00437 int nobs, i, *iwhere;
00438
00439 nobs = LENGTH(where);
00440 iwhere = INTEGER(where);
00441 if (LENGTH(ans) != nobs)
00442 error("ans is not of length %d\n", nobs);
00443
00444 for (i = 0; i < nobs; i++)
00445 SET_VECTOR_ELT(ans, i, S3get_nodeweights(
00446 C_get_nodebynum(tree, iwhere[i])));
00447 }
00448
00449
00456 SEXP R_getweights(SEXP tree, SEXP where) {
00457
00458 SEXP ans;
00459 int nobs;
00460
00461 nobs = LENGTH(where);
00462 PROTECT(ans = allocVector(VECSXP, nobs));
00463 C_getweights(tree, where, ans);
00464 UNPROTECT(1);
00465 return(ans);
00466 }
00467
00468
00477 void C_weights(SEXP tree, SEXP newinputs,
00478 double mincriterion, SEXP ans) {
00479
00480 int nobs, i;
00481
00482 nobs = get_nobs(newinputs);
00483 if (LENGTH(ans) != nobs)
00484 error("ans is not of length %d\n", nobs);
00485
00486 for (i = 0; i < nobs; i++)
00487 SET_VECTOR_ELT(ans, i, C_get_nodeweights(tree, newinputs,
00488 mincriterion, i));
00489 }
00490
00491
00499 SEXP R_weights(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00500
00501 SEXP ans;
00502 int nobs;
00503
00504 nobs = get_nobs(newinputs);
00505 PROTECT(ans = allocVector(VECSXP, nobs));
00506 C_weights(tree, newinputs, REAL(mincriterion)[0], ans);
00507 UNPROTECT(1);
00508 return(ans);
00509 }
00510
00511
00520 SEXP R_predictRF(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00521
00522 SEXP ans, tmp, tree;
00523 int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0;
00524
00525 if (LOGICAL(oobpred)[0]) oob = 1;
00526
00527 nobs = get_nobs(newinputs);
00528 ntrees = LENGTH(forest);
00529 q = LENGTH(S3get_prediction(
00530 C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00531
00532 if (oob) {
00533 if (LENGTH(S3get_nodeweights(
00534 C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00535 error("number of observations don't match");
00536 }
00537
00538 PROTECT(ans = allocVector(VECSXP, nobs));
00539
00540 for (i = 0; i < nobs; i++) {
00541 count = 0;
00542 SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00543 for (j = 0; j < q; j++)
00544 REAL(VECTOR_ELT(ans, i))[j] = 0.0;
00545 for (b = 0; b < ntrees; b++) {
00546 tree = VECTOR_ELT(forest, b);
00547
00548 if (oob &&
00549 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00550 continue;
00551
00552 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00553 tmp = S3get_prediction(C_get_nodebynum(tree, iwhere));
00554 for (j = 0; j < q; j++)
00555 REAL(VECTOR_ELT(ans, i))[j] += REAL(tmp)[j];
00556 count++;
00557 }
00558 if (count == 0)
00559 error("cannot compute out-of-bag predictions for obs ", i + 1);
00560 for (j = 0; j < q; j++)
00561 REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / count;
00562 }
00563 UNPROTECT(1);
00564 return(ans);
00565 }
00566
00576 SEXP R_predictRF2(SEXP forest, SEXP response, SEXP newinputs,
00577 SEXP mincriterion, SEXP oobpred) {
00578
00579 SEXP ans, tmp, tree, w;
00580 int ntrees, nobs, i, b, j, q, n, iwhere, oob = 0;
00581 double *dtmp, *dw, sumw = 0.0;
00582
00583 if (LOGICAL(oobpred)[0]) oob = 1;
00584
00585 nobs = get_nobs(newinputs);
00586 ntrees = LENGTH(forest);
00587 n = nrow(response);
00588 q = ncol(response);
00589
00590 if (oob) {
00591 if (n != nobs)
00592 error("number of observations don't match");
00593 }
00594
00595 PROTECT(ans = allocVector(VECSXP, nobs));
00596 PROTECT(w = allocMatrix(REALSXP, 1, n));
00597 dw = REAL(w);
00598
00599 for (i = 0; i < nobs; i++) {
00600
00601 SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00602 for (j = 0; j < n; j++)
00603 dw[j] = 0.0;
00604
00605 for (b = 0; b < ntrees; b++) {
00606 tree = VECTOR_ELT(forest, b);
00607
00608 if (oob &&
00609 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00610 continue;
00611
00612 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00613 tmp = S3get_nodeweights(C_get_nodebynum(tree, iwhere));
00614 dtmp = REAL(tmp);
00615
00616 for (j = 0; j < n; j++)
00617 dw[j] += dtmp[j];
00618 }
00619
00620 C_matprod(dw, 1, n, REAL(response), n, q, REAL(VECTOR_ELT(ans, i)));
00621
00622 sumw = 0.0;
00623 for (j = 0; j < n; j++)
00624 sumw += dw[j];
00625
00626 for (j = 0; j < q; j++)
00627 REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / sumw;
00628 }
00629 UNPROTECT(2);
00630 return(ans);
00631 }
00632
00641 SEXP R_predictRF_weights(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00642
00643 SEXP ans, tree, bw;
00644 int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00645 double *dtmp;
00646
00647 if (LOGICAL(oobpred)[0]) oob = 1;
00648
00649 nobs = get_nobs(newinputs);
00650 ntrees = LENGTH(forest);
00651 q = LENGTH(S3get_prediction(
00652 C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00653
00654 if (oob) {
00655 if (LENGTH(S3get_nodeweights(
00656 C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00657 error("number of observations don't match");
00658 }
00659
00660 tree = VECTOR_ELT(forest, 0);
00661 ntrain = LENGTH(S3get_nodeweights(C_get_nodebynum(tree, 1)));
00662
00663 PROTECT(ans = allocVector(VECSXP, nobs));
00664
00665 for (i = 0; i < nobs; i++) {
00666 count = 0;
00667 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00668 for (j = 0; j < ntrain; j++)
00669 REAL(bw)[j] = 0.0;
00670 for (b = 0; b < ntrees; b++) {
00671 tree = VECTOR_ELT(forest, b);
00672
00673 if (oob &&
00674 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00675 continue;
00676
00677 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00678 dtmp = REAL(S3get_nodeweights(C_get_nodebynum(tree, iwhere)));
00679 for (j = 0; j < ntrain; j++)
00680 REAL(bw)[j] += dtmp[j];
00681 count++;
00682 }
00683 if (count == 0)
00684 error("cannot compute out-of-bag predictions for obs ", i + 1);
00685 }
00686 UNPROTECT(1);
00687 return(ans);
00688 }