00001
00009 #include "party.h"
00010
00011
/**
    Weighted column sums of a matrix, each divided by a common constant. \n
    The matrix is read column by column: column j occupies
    y[j * n], ..., y[j * n + n - 1].
    *\param y values of the n x q matrix
    *\param n number of rows (observations)
    *\param q number of columns
    *\param weights one weight per row (length n)
    *\param sweights divisor applied to every weighted column sum
    *\param ans return value; vector of length q
*/

void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    int col, obs;
    const double *ycol;
    double acc;

    for (col = 0; col < q; col++) {
        /* start of column `col' inside the flat array */
        ycol = y + col * n;

        acc = 0.0;
        for (obs = 0; obs < n; obs++) {
            acc += weights[obs] * ycol[obs];
        }

        ans[col] = acc / sweights;
    }
}
00035
00036
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights,
00049 SEXP fitmem, SEXP controls, int TERMINAL) {
00050
00051 int nobs, ninputs, jselect, q, j, k, i;
00052 double mincriterion, sweights, *dprediction;
00053 double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054 double *standstat, *splitstat;
00055 SEXP responses, inputs, y, x, expcovinf, thisweights, linexpcov;
00056 SEXP varctrl, splitctrl, gtctrl, tgctrl, split, joint;
00057 double *dxtransf, *dweights;
00058 int *itable;
00059
00060 nobs = get_nobs(learnsample);
00061 ninputs = get_ninputs(learnsample);
00062 varctrl = get_varctrl(controls);
00063 splitctrl = get_splitctrl(controls);
00064 gtctrl = get_gtctrl(controls);
00065 tgctrl = get_tgctrl(controls);
00066 mincriterion = get_mincriterion(gtctrl);
00067 responses = GET_SLOT(learnsample, PL2_responsesSym);
00068 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069 y = get_transformation(responses, 1);
00070 q = ncol(y);
00071 joint = GET_SLOT(responses, PL2_jointtransfSym);
00072
00073
00074
00075
00076 C_GlobalTest(learnsample, weights, fitmem, varctrl,
00077 gtctrl, get_minsplit(splitctrl),
00078 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00079
00080
00081 sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym),
00082 PL2_sumweightsSym))[0];
00083
00084
00085 dprediction = REAL(S3get_prediction(node));
00086
00087
00088
00089 C_prediction(REAL(joint), nobs, ncol(joint), REAL(weights),
00090 sweights, dprediction);
00091
00092
00093 teststat = REAL(S3get_teststat(node));
00094 pvalue = REAL(S3get_criterion(node));
00095
00096
00097
00098
00099 for (j = 0; j < 2; j++) {
00100
00101 smax = C_max(pvalue, ninputs);
00102 REAL(S3get_maxcriterion(node))[0] = smax;
00103
00104
00105 if (smax > mincriterion && !TERMINAL) {
00106
00107
00108 jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00109
00110
00111 x = get_variable(inputs, jselect);
00112 if (has_missings(inputs, jselect)) {
00113 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect),
00114 PL2_expcovinfSym);
00115 thisweights = get_weights(fitmem, jselect);
00116 } else {
00117 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00118 thisweights = weights;
00119 }
00120
00121
00122 if (!is_nominal(inputs, jselect)) {
00123
00124
00125 split = S3get_primarysplit(node);
00126
00127
00128
00129 if (get_savesplitstats(tgctrl)) {
00130 C_init_orderedsplit(split, nobs);
00131 splitstat = REAL(S3get_splitstatistics(split));
00132 } else {
00133 C_init_orderedsplit(split, 0);
00134 splitstat = REAL(get_splitstatistics(fitmem));
00135 }
00136
00137 C_split(REAL(x), 1, REAL(y), q, REAL(weights), nobs,
00138 INTEGER(get_ordering(inputs, jselect)), splitctrl,
00139 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00140 expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00141 splitstat);
00142 S3set_variableID(split, jselect);
00143 } else {
00144
00145
00146 split = S3get_primarysplit(node);
00147
00148
00149
00150 if (get_savesplitstats(tgctrl)) {
00151 C_init_nominalsplit(split,
00152 LENGTH(get_levels(inputs, jselect)),
00153 nobs);
00154 splitstat = REAL(S3get_splitstatistics(split));
00155 } else {
00156 C_init_nominalsplit(split,
00157 LENGTH(get_levels(inputs, jselect)),
00158 0);
00159 splitstat = REAL(get_splitstatistics(fitmem));
00160 }
00161
00162 linexpcov = get_varmemory(fitmem, jselect);
00163 standstat = Calloc(get_dimension(linexpcov), double);
00164 C_standardize(REAL(GET_SLOT(linexpcov,
00165 PL2_linearstatisticSym)),
00166 REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00167 REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00168 get_dimension(linexpcov), get_tol(splitctrl),
00169 standstat);
00170
00171 C_splitcategorical(INTEGER(x),
00172 LENGTH(get_levels(inputs, jselect)),
00173 REAL(y), q, REAL(weights),
00174 nobs, standstat, splitctrl,
00175 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00176 expcovinf, &cutpoint,
00177 INTEGER(S3get_splitpoint(split)),
00178 &maxstat, splitstat);
00179
00180
00181
00182
00183
00184 itable = INTEGER(S3get_table(split));
00185 dxtransf = REAL(get_transformation(inputs, jselect));
00186 dweights = REAL(thisweights);
00187 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00188 itable[k] = 0;
00189 for (i = 0; i < nobs; i++) {
00190 if (dxtransf[k * nobs + i] * dweights[i] > 0) {
00191 itable[k] = 1;
00192 continue;
00193 }
00194 }
00195 }
00196
00197 Free(standstat);
00198 }
00199 if (maxstat == 0) {
00200 warning("no admissible split found\n");
00201
00202 if (j == 1) {
00203 S3set_nodeterminal(node);
00204 } else {
00205
00206 pvalue[jselect - 1] = 0.0;
00207 }
00208 } else {
00209 S3set_variableID(split, jselect);
00210 break;
00211 }
00212 } else {
00213 S3set_nodeterminal(node);
00214 break;
00215 }
00216 }
00217 }
00218
00219
00228 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00229
00230 SEXP ans;
00231
00232 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00233 C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample),
00234 get_maxsurrogate(get_splitctrl(controls)),
00235 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00236 PL2_jointtransfSym)));
00237
00238 C_Node(ans, learnsample, weights, fitmem, controls, 0);
00239 UNPROTECT(1);
00240 return(ans);
00241 }