00001
00009 #include "party.h"
00010
00011
00021 SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00022
00023 SEXP nweights, tree, where, ans;
00024 double *dnweights, *dweights, sw = 0.0, *prob, fraction;
00025 int nobs, i, b, B , nodenum = 1, *iweights, *iweightstmp,
00026 *iwhere, replace;
00027
00028 B = get_ntree(controls);
00029 nobs = get_nobs(learnsample);
00030
00031 PROTECT(ans = allocVector(VECSXP, B));
00032
00033 iweights = Calloc(nobs, int);
00034 iweightstmp = Calloc(nobs, int);
00035 prob = Calloc(nobs, double);
00036 dweights = REAL(weights);
00037
00038 for (i = 0; i < nobs; i++)
00039 sw += dweights[i];
00040 for (i = 0; i < nobs; i++)
00041 prob[i] = dweights[i]/sw;
00042
00043 replace = get_replace(controls);
00044 fraction = get_fraction(controls) * nobs;
00045
00046 if (!replace) {
00047 if (fraction < 10)
00048 error("fraction of %f is too small", fraction);
00049 }
00050
00051 for (b = 0; b < B; b++) {
00052 SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
00053 SET_VECTOR_ELT(tree, NODE_LENGTH, where = allocVector(INTSXP, nobs));
00054 iwhere = INTEGER(where);
00055 for (i = 0; i < nobs; i++) iwhere[i] = 0;
00056
00057 C_init_node(tree, nobs, get_ninputs(learnsample),
00058 get_maxsurrogate(get_splitctrl(controls)),
00059 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00060 PL2_jointtransfSym)));
00061
00062
00063
00064 GetRNGstate();
00065
00066
00067 if (replace) {
00068
00069 rmultinom((int) sw, prob, nobs, iweights);
00070 } else {
00071
00072 C_SampleNoReplace(iweightstmp, nobs, nobs, iweights);
00073 for (i = 0; i < nobs; i++) {
00074 if (iweights[i] < fraction) {
00075 iweights[i] = 1;
00076 } else {
00077 iweights[i] = 0;
00078 }
00079 }
00080 }
00081 PutRNGstate();
00082
00083 nweights = S3get_nodeweights(tree);
00084 dnweights = REAL(nweights);
00085 for (i = 0; i < nobs; i++) dnweights[i] = (double) iweights[i];
00086
00087 C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1);
00088 nodenum = 1;
00089 }
00090
00091
00092 Free(prob); Free(iweights); Free(iweightstmp);
00093 UNPROTECT(1);
00094 return(ans);
00095 }