This package allows you to marginalize arbitrary prediction functions using Monte Carlo integration. Since many prediction functions cannot be easily decomposed into a sum of low-dimensional components, marginalization can help make these functions interpretable.
`marginalPrediction` does this computation and then evaluates the marginalized function at a set of grid points, which can be created uniformly, subsampled from the training data, or explicitly specified as an argument.
The creation of a uniform grid is handled by the `uniformGrid` method. If `uniform = FALSE` and the `points` argument isn't used to specify which points to evaluate, a sample of size `n[1]` is taken from the data without replacement.
library(mmpf)
library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
data(swiss)
fit = randomForest(Fertility ~ ., swiss)
marginalPrediction(swiss[, -1], "Education", c(10, 25), fit)
## $prediction
## [1] 73.16719 74.55362 73.63688 60.89975 65.07203 63.87359 66.01385
## [8] 65.16879 57.31980 64.23191
##
## $points
## Education
## 1 1
## 2 7
## 3 13
## 4 18
## 5 24
## 6 30
## 7 36
## 8 41
## 9 47
## 10 53
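For instance (a sketch assuming the `fit` object from above; no output is shown because the evaluation points are sampled randomly), setting `uniform = FALSE` draws the grid from the observed data instead of constructing it uniformly:

```r
# Evaluate at 10 Education values sampled from the data without replacement,
# rather than on a uniform grid; the chosen points vary with the random seed.
marginalPrediction(swiss[, -1], "Education", c(10, 25), fit, uniform = FALSE)
```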
The output of `marginalPrediction` is always a list with two elements, `prediction` and `points`.
By default the Monte Carlo expectation is computed, which is set by the `aggregate.fun` argument's default value, the `mean` function. Substituting, say, the `median` would give a different output.
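A median version looks like this (a sketch assuming the `fit` object from above; output omitted since the exact values depend on the random integration sample):

```r
# Marginalize by taking the median over the integration points at each
# grid point, instead of the default mean.
marginalPrediction(swiss[, -1], "Education", c(10, 25), fit,
  aggregate.fun = median)
```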
marginalPrediction(swiss[, -1], "Education", c(10, 25), fit, aggregate.fun = identity)
## $prediction
## [,1] [,2] [,3] [,4] [,5] [,6] [,7]
## [1,] 74.34694 62.63832 67.28735 82.51371 72.28726 74.34694 62.63832
## [2,] 61.82900 80.82552 64.74449 66.55270 67.69200 61.82900 80.82552
## [3,] 70.87189 76.16749 65.51442 70.08724 74.97268 70.87189 76.16749
## [4,] 61.70607 75.46225 68.62768 60.03769 60.78013 61.70607 75.46225
## [5,] 74.35020 69.80041 72.00466 71.91332 67.71496 74.35020 69.80041
## [6,] 74.98962 63.32124 65.76482 54.41933 57.34750 74.98962 63.32124
## [7,] 58.53077 60.83246 54.46824 73.12793 56.42593 58.53077 60.83246
## [8,] 62.50785 66.56272 64.13558 68.06267 57.21571 62.50785 66.56272
## [9,] 55.45441 56.45508 58.79520 72.72766 65.71551 55.45441 56.45508
## [10,] 71.51427 67.10101 73.85122 68.93235 71.59276 71.51427 67.10101
## [,8] [,9] [,10] [,11] [,12] [,13] [,14]
## [1,] 67.28735 82.51371 72.28726 74.34694 62.63832 67.28735 82.51371
## [2,] 64.74449 66.55270 67.69200 61.82900 80.82552 64.74449 66.55270
## [3,] 65.51442 70.08724 74.97268 70.87189 76.16749 65.51442 70.08724
## [4,] 68.62768 60.03769 60.78013 61.70607 75.46225 68.62768 60.03769
## [5,] 72.00466 71.91332 67.71496 74.35020 69.80041 72.00466 71.91332
## [6,] 65.76482 54.41933 57.34750 74.98962 63.32124 65.76482 54.41933
## [7,] 54.46824 73.12793 56.42593 58.53077 60.83246 54.46824 73.12793
## [8,] 64.13558 68.06267 57.21571 62.50785 66.56272 64.13558 68.06267
## [9,] 58.79520 72.72766 65.71551 55.45441 56.45508 58.79520 72.72766
## [10,] 73.85122 68.93235 71.59276 71.51427 67.10101 73.85122 68.93235
## [,15] [,16] [,17] [,18] [,19] [,20] [,21]
## [1,] 72.28726 74.34694 62.63832 67.28735 82.51371 72.28726 74.34694
## [2,] 67.69200 61.82900 80.82552 64.74449 66.55270 67.69200 61.82900
## [3,] 74.97268 70.87189 76.16749 65.51442 70.08724 74.97268 70.87189
## [4,] 60.78013 61.70607 75.46225 68.62768 60.03769 60.78013 61.70607
## [5,] 67.71496 74.35020 69.80041 72.00466 71.91332 67.71496 74.35020
## [6,] 57.34750 74.98962 63.32124 65.76482 54.41933 57.34750 74.98962
## [7,] 56.42593 58.53077 60.83246 54.46824 73.12793 56.42593 58.53077
## [8,] 57.21571 62.50785 66.56272 64.13558 68.06267 57.21571 62.50785
## [9,] 65.71551 55.45441 56.45508 58.79520 72.72766 65.71551 55.45441
## [10,] 71.59276 71.51427 67.10101 73.85122 68.93235 71.59276 71.51427
## [,22] [,23] [,24] [,25]
## [1,] 62.63832 67.28735 82.51371 72.28726
## [2,] 80.82552 64.74449 66.55270 67.69200
## [3,] 76.16749 65.51442 70.08724 74.97268
## [4,] 75.46225 68.62768 60.03769 60.78013
## [5,] 69.80041 72.00466 71.91332 67.71496
## [6,] 63.32124 65.76482 54.41933 57.34750
## [7,] 60.83246 54.46824 73.12793 56.42593
## [8,] 66.56272 64.13558 68.06267 57.21571
## [9,] 56.45508 58.79520 72.72766 65.71551
## [10,] 67.10101 73.85122 68.93235 71.59276
##
## $points
## Education
## 1 1
## 2 7
## 3 13
## 4 18
## 5 24
## 6 30
## 7 36
## 8 41
## 9 47
## 10 53
By passing the `identity` function to `aggregate.fun`, which simply returns its input unchanged, the predictions at the integration points are returned directly, so that the `prediction` element of the result is a matrix with `n[1]` rows (one per grid point) and `n[2]` columns (one per integration point).
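As a check (a sketch assuming the objects from above; the values will differ slightly from the first call because a fresh Monte Carlo sample of integration points is drawn), averaging this matrix over its columns reproduces the default mean aggregation:

```r
# Row means of the identity-aggregated matrix give the Monte Carlo
# expectation at each grid point, i.e., the default behavior.
mp = marginalPrediction(swiss[, -1], "Education", c(10, 25), fit,
  aggregate.fun = identity)
rowMeans(mp$prediction)
```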
`marginalPrediction` can also handle cases in which predictions for a single data point are vector-valued: classification tasks where probabilities are output, as well as multivariate regression or classification. In these cases `aggregate.fun` is applied separately to each column of the prediction matrix.
data(iris)
fit = randomForest(Species ~ ., iris)
marginalPrediction(iris[, -ncol(iris)], "Sepal.Width", c(10, 25), fit,
predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
## $prediction
## [,1] [,2] [,3]
## [1,] 0.5784 0.4112 0.0104
## [2,] 0.1936 0.2580 0.5484
## [3,] 0.0000 0.3964 0.6036
## [4,] 0.1976 0.2604 0.5420
## [5,] 0.4016 0.3972 0.2012
## [6,] 0.6016 0.3980 0.0004
## [7,] 0.2108 0.2820 0.5072
## [8,] 0.0252 0.3968 0.5780
## [9,] 0.2092 0.2784 0.5124
## [10,] 0.4204 0.3780 0.2016
##
## $points
## Sepal.Width
## 1 2.000000
## 2 2.266667
## 3 2.533333
## 4 2.800000
## 5 3.066667
## 6 3.333333
## 7 3.600000
## 8 3.866667
## 9 4.133333
## 10 4.400000
In all of the aforementioned cases `vars` can include multiple variables. In this case the Cartesian product of each variable's grid is taken (however that grid is created), resulting in at most `n[1]^length(vars)` unique points. This number is reduced if one of the variables has fewer than `n[1]` unique values.
marginalPrediction(iris[, -ncol(iris)], c("Sepal.Width", "Sepal.Length"), c(5, 25), fit,
predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
## $prediction
## [,1] [,2] [,3]
## [1,] 0.004 0.396 0.600
## [2,] 0.974 0.024 0.002
## [3,] 1.000 0.000 0.000
## [4,] 0.026 0.264 0.710
## [5,] 1.000 0.000 0.000
## [6,] 0.000 0.950 0.050
## [7,] 0.000 0.086 0.914
## [8,] 0.026 0.144 0.830
## [9,] 1.000 0.000 0.000
## [10,] 0.026 0.164 0.810
## [11,] 0.000 0.960 0.040
## [12,] 0.000 0.972 0.028
## [13,] 0.000 0.020 0.980
## [14,] 0.056 0.932 0.012
## [15,] 0.984 0.016 0.000
## [16,] 0.920 0.078 0.002
## [17,] 0.000 0.000 1.000
## [18,] 0.920 0.078 0.002
## [19,] 0.000 0.000 1.000
## [20,] 0.048 0.938 0.014
## [21,] 0.000 0.008 0.992
## [22,] 0.920 0.068 0.012
## [23,] 0.000 0.886 0.114
## [24,] 0.966 0.026 0.008
## [25,] 0.000 0.000 1.000
##
## $points
## Sepal.Width Sepal.Length
## 1 2.0 4.3
## 2 2.6 4.3
## 3 3.2 4.3
## 4 3.8 4.3
## 5 4.4 4.3
## 6 2.0 5.2
## 7 2.6 5.2
## 8 3.2 5.2
## 9 3.8 5.2
## 10 4.4 5.2
## 11 2.0 6.1
## 12 2.6 6.1
## 13 3.2 6.1
## 14 3.8 6.1
## 15 4.4 6.1
## 16 2.0 7.0
## 17 2.6 7.0
## 18 3.2 7.0
## 19 3.8 7.0
## 20 4.4 7.0
## 21 2.0 7.9
## 22 2.6 7.9
## 23 3.2 7.9
## 24 3.8 7.9
## 25 4.4 7.9
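The structure of the `$points` element above can be reproduced with base R's `expand.grid` (an illustration of the Cartesian product construction, not mmpf's internals):

```r
# Each variable gets a uniform grid of n[1] = 5 points over its observed
# range; the grids are then crossed, giving 5^2 = 25 evaluation points.
sw = seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 5)
sl = seq(min(iris$Sepal.Length), max(iris$Sepal.Length), length.out = 5)
grid = expand.grid(Sepal.Width = sw, Sepal.Length = sl)
nrow(grid)
## [1] 25
```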
Permutation importance is a Monte Carlo method which estimates the importance of variables in determining predictions by repeatedly permuting the values of those variables and comparing the prediction error on the permuted data to the error on the unpermuted training data. `permutationImportance` can compute this type of importance under an arbitrary loss function (with respect to the observed target) and an arbitrary contrast function (between the loss on the unpermuted and permuted data).
permutationImportance(iris, "Sepal.Width", "Species", fit)
## [1] 0.01226667
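The quantity being estimated can be sketched manually (assuming the `fit` and `iris` objects from above; a single permutation is shown, whereas `permutationImportance` averages over many, so the value will differ and no output is shown):

```r
# Permute Sepal.Width once and contrast the misclassification rate on the
# permuted data with the rate on the original data.
set.seed(1)
permuted = iris
permuted$Sepal.Width = sample(permuted$Sepal.Width)
err = function(data) mean(predict(fit, data) != data$Species)
err(permuted) - err(iris)
```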
For models which generate predictions that are characters or unordered factors, the default loss function is the mean misclassification error. For all other types of predictions, mean squared error is used.
It is, for example, possible to compute the expected change in the mean misclassification rate by class. The two arguments to `loss.fun` are the permuted predictions and the target variable; in this case both are vectors of factors.
permutationImportance(iris, "Sepal.Width", "Species", fit,
loss.fun = function(x, y) {
mat = table(x, y)
n = colSums(mat)
diag(mat) = 0
rowSums(mat) / n
},
contrast.fun = function(x, y) x - y)
## setosa versicolor virginica
## 0.0000 0.0256 0.0126