In the context of this package, an “Adjusted Prediction” is defined as:
The outcome predicted by a fitted model on a specified scale for a given combination of values of the predictor variables, such as their observed values, their means, or factor levels (a.k.a. “reference grid”).
Here, the word “Adjusted” simply means “model-derived” or “model-based.”
Using the type argument of the predictions() function, we can specify the “scale” on which to make predictions. This refers either to the scale used to estimate the model (i.e., the link scale) or to a more interpretable scale (e.g., the response scale). For example, when fitting a linear regression model using the lm() function, the link scale and the response scale are identical. An “Adjusted Prediction” computed on either scale will be expressed as the mean value of the response variable at the given values of the predictor variables.
On the other hand, when fitting a binary logistic regression model using the glm() function (which uses a binomial family and a logit link), the link scale and the response scale will be different. An “Adjusted Prediction” computed on the link scale will be expressed as the log odds of a “successful” response at the given values of the predictor variables, whereas an “Adjusted Prediction” computed on the response scale will be expressed as the probability that the response variable equals 1.
The default value of the type argument for most models is "response", which means that the predictions() function will compute predicted probabilities (binomial family), Poisson means (poisson family), etc.
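As a base-R illustration of the two scales (a sketch using only mtcars and the standard predict() method for glm objects, not a description of how predictions() works internally), the models below show how the link and response scales relate for three families. Note that base R's predict() for glm objects defaults to type = "link", so the type is set explicitly here.

```r
# Logistic regression (binomial family, logit link): the link scale is the log odds
logit_mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)
eta  <- predict(logit_mod, type = "link")      # log odds
prob <- predict(logit_mod, type = "response")  # probabilities that vs == 1
all.equal(prob, plogis(eta))                   # response = inverse logit of link

# Poisson regression (log link): the response scale is exp() of the link scale
pois_mod <- glm(carb ~ hp, data = mtcars, family = poisson)
all.equal(predict(pois_mod, type = "response"),
          exp(predict(pois_mod, type = "link")))

# Gaussian family with identity link (the model lm() fits): both scales coincide
gauss_mod <- glm(mpg ~ hp, data = mtcars, family = gaussian)
all.equal(predict(gauss_mod, type = "response"),
          predict(gauss_mod, type = "link"))
```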
To compute adjusted predictions, we must first specify the values of the predictors to consider: a “reference grid.” For example, if our model is a linear model fitted with the lm() function which relates the response variable Happiness to the predictor variables Age, Gender and Income, the reference grid could be a data.frame with values for Age, Gender and Income: Age = 40, Gender = Male, Income = 60000.
The “reference grid” may or may not correspond to actual observations in the dataset used to fit the model; the example values given above could match the means (or modes) of each variable, or they could represent a specific observed (or hypothetical) individual. The reference grid can include many different rows if we want to make predictions for different combinations of predictors. By default, the predictions() function uses the full original dataset as a reference grid, which means it will compute adjusted predictions for each of the individuals observed in the dataset that was used to fit the model.
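A reference grid is just a data frame of predictor values. The sketch below builds the one-row grid from the hypothetical Happiness example and a multi-row grid of combinations with base R's expand.grid(). No model is fitted, and the extra Age and Income values in the multi-row grid are illustrative assumptions.

```r
# Hypothetical reference grids for a Happiness ~ Age + Gender + Income model.
# These are only the predictor values we would pass as newdata.
gender_levels <- c("Female", "Male")

# A single-row grid: one "typical" or hypothetical individual
grid_one <- data.frame(
  Age = 40,
  Gender = factor("Male", levels = gender_levels),
  Income = 60000
)

# A multi-row grid: every combination of the listed values
grid_many <- expand.grid(
  Age = c(30, 40, 50),
  Gender = factor(gender_levels, levels = gender_levels),
  Income = c(40000, 60000)
)
nrow(grid_many)  # 3 * 2 * 2 = 12 rows
```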
predictions() function
By default, predictions() calculates the regression-adjusted predicted values for every observation in the original dataset:
library(marginaleffects)
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)
pred <- predictions(mod)
head(pred)
#> rowid type predicted std.error statistic p.value conf.low conf.high mpg hp cyl
#> 1 1 response 20.03819 1.2041405 16.64107 3.512623e-62 17.57162 22.50476 21.0 110 6
#> 2 2 response 20.03819 1.2041405 16.64107 3.512623e-62 17.57162 22.50476 21.0 110 6
#> 3 3 response 26.41451 0.9619738 27.45866 5.476301e-166 24.44399 28.38502 22.8 93 4
#> 4 4 response 20.03819 1.2041405 16.64107 3.512623e-62 17.57162 22.50476 21.4 110 6
#> 5 5 response 15.92247 0.9924560 16.04350 6.347069e-58 13.88952 17.95543 18.7 175 8
#> 6 6 response 20.15839 1.2186288 16.54186 1.832792e-61 17.66214 22.65463 18.1 105 6
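For a linear model, these point estimates are ordinary fitted values, so they can be cross-checked with base R's predict(). The sketch below assumes only base R. The interval bounds shown above appear consistent with a t-based interval (20.03819 - qt(0.975, 28) * 1.2041405 ≈ 17.5716), which is what predict.lm() reports with interval = "confidence".

```r
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)

# Point predictions and standard errors for every observed row
p <- predict(mod, se.fit = TRUE)
head(cbind(predicted = p$fit, std.error = p$se.fit))

# 95% confidence intervals; for lm(), predict() uses the t distribution
# with the residual degrees of freedom (32 rows - 4 coefficients = 28)
ci <- predict(mod, interval = "confidence")
head(ci)
```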
In many cases, this is too limiting, and researchers will want to specify a grid of “typical” values over which to compute adjusted predictions.
There are two main ways to select the reference grid over which we want to compute adjusted predictions. The first is the variables argument. The second is the newdata argument combined with the datagrid() function that we already introduced in the marginal effects vignette.
variables: Counterfactual predictions
The variables argument is a handy way to create, and make predictions on, counterfactual datasets. For example, the dataset that we used to fit the model has 32 rows. The counterfactual dataset with two distinct values of hp has 64 rows: each of the original rows appears twice, once with each of the values that we specified in the variables argument:
p <- predictions(mod, variables = list(hp = c(100, 120)))
head(p)
#> rowid rowidcf type predicted std.error statistic p.value conf.low conf.high mpg cyl hp
#> 1 1 1 response 20.27858 1.2377512 16.383405 2.512745e-60 17.74316 22.81400 21.0 6 100
#> 2 2 2 response 20.27858 1.2377512 16.383405 2.512745e-60 17.74316 22.81400 21.0 6 100
#> 3 3 3 response 26.24623 0.9856325 26.628826 3.148430e-156 24.22726 28.26521 22.8 4 100
#> 4 4 4 response 20.27858 1.2377512 16.383405 2.512745e-60 17.74316 22.81400 21.4 6 100
#> 5 5 5 response 17.72538 1.8811567 9.422599 4.400597e-21 13.87201 21.57876 18.7 8 100
#> 6 6 6 response 20.27858 1.2377512 16.383405 2.512745e-60 17.74316 22.81400 18.1 6 100
nrow(p)
#> [1] 64
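The same counterfactual dataset can be assembled by hand in base R, which makes the row doubling explicit. This is a sketch of the idea, not the package's internal code: it stacks one copy of mtcars per value of hp and predicts on the stacked data.

```r
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)

# Stack two copies of the original data: hp set to 100 in the first copy,
# to 120 in the second; every other column keeps its observed value
cf <- rbind(transform(mtcars, hp = 100), transform(mtcars, hp = 120))
nrow(cf)  # 64 = 2 x 32

cf$predicted <- predict(mod, newdata = cf)
head(cf[, c("mpg", "cyl", "hp", "predicted")])
```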
newdata and datagrid
A second strategy to construct grids of predictors for adjusted predictions is to combine the newdata argument and the datagrid function. Recall that this function creates a “typical” dataset with all variables at their means or modes, except those we explicitly define:
datagrid(cyl = c(4, 6, 8), model = mod)
#> hp cyl
#> 1: 146.6875 4
#> 2: 146.6875 6
#> 3: 146.6875 8
We can also use this datagrid function in a predictions call (omitting the model argument):
predictions(mod, newdata = datagrid())
#> rowid type predicted std.error statistic p.value conf.low conf.high hp cyl
#> 1 1 response 16.60307 1.278754 12.98379 1.512165e-38 13.98366 19.22248 146.6875 8
predictions(mod, newdata = datagrid(cyl = c(4, 6, 8)))
#> rowid type predicted std.error statistic p.value conf.low conf.high hp cyl
#> 1 1 response 25.12392 1.368888 18.35353 3.093502e-75 22.31988 27.92796 146.6875 4
#> 2 2 response 19.15627 1.247190 15.35955 3.057119e-53 16.60151 21.71102 146.6875 6
#> 3 3 response 16.60307 1.278754 12.98379 1.512165e-38 13.98366 19.22248 146.6875 8
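To see what these “typical” grids contain, the sketch below rebuilds them in base R: hp at its mean and cyl at its mode (8 is the most frequent value in mtcars) for the one-row grid, and hp at its mean with each cyl value for the three-row grid. It mirrors the outputs above rather than datagrid()'s actual implementation.

```r
mod <- lm(mpg ~ hp + factor(cyl), data = mtcars)

# "Typical" values: mean of the numeric predictor hp and mode of cyl
hp_mean  <- mean(mtcars$hp)                                  # 146.6875
cyl_mode <- as.numeric(names(which.max(table(mtcars$cyl))))  # 8

# One-row grid, like datagrid(): hp at its mean, cyl at its mode
predict(mod, newdata = data.frame(hp = hp_mean, cyl = cyl_mode))

# Three-row grid, like datagrid(cyl = c(4, 6, 8)): only cyl varies
grid <- data.frame(hp = hp_mean, cyl = c(4, 6, 8))
grid$predicted <- predict(mod, newdata = grid)
grid
```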
Users can change the function used to summarize each type of variable with the FUN.numeric, FUN.factor, and related arguments of datagrid(), for example substituting the median for the mean.
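The idea of swapping the summary function can be sketched with a small helper. typical_row() and its summary_fun argument are hypothetical names invented for this illustration and are not part of marginaleffects; the helper only shows how a mean-or-mode row changes when numeric columns are summarized with median() instead.

```r
# Hypothetical helper (not part of marginaleffects): build a one-row
# "typical" data frame. Numeric columns are summarized with `summary_fun`
# (mean by default); any other column is set to its most frequent value,
# returned as a character string unless the column is a factor.
typical_row <- function(data, summary_fun = mean) {
  mode_of <- function(x) {
    tab <- table(x)
    value <- names(tab)[which.max(tab)]
    # keep factor columns as factors with their original levels
    if (is.factor(x)) factor(value, levels = levels(x)) else value
  }
  summarized <- lapply(data, function(x) {
    if (is.numeric(x)) summary_fun(x) else mode_of(x)
  })
  as.data.frame(summarized, stringsAsFactors = FALSE)
}

dat <- data.frame(hp = mtcars$hp, cyl = factor(mtcars$cyl))
typical_row(dat)                        # hp = 146.6875, cyl = 8
typical_row(dat, summary_fun = median)  # hp = 123, cyl = 8
```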
The data.frame produced by predictions is “tidy”, which makes it easy to manipulate with other R packages and functions:
library(kableExtra)
library(tidyverse)
predictions(
  mod, newdata = datagrid(cyl = mtcars$cyl, hp = c(100, 110))) %>%
select(hp, cyl, predicted) %>%
pivot_wider(values_from = predicted, names_from = cyl) %>%
kbl(caption = "A table of Adjusted Predictions") %>%
kable_styling() %>%
add_header_above(header = c(" " = 1, "cyl" = 3))
A table of Adjusted Predictions (columns are values of cyl)
hp | cyl = 6 | cyl = 4 | cyl = 8 |
---|---|---|---|
100 | 20.27858 | 26.24623 | 17.72538 |
110 | 20.03819 | 26.00585 | 17.48500 |
counterfactual data grid
An alternative approach to constructing grids of predictors is to use the grid_type = "counterfactual" argument value of datagrid(). This duplicates the whole dataset, once for each of the values specified by the user.
For example, the mtcars dataset has 32 rows. This command produces a new dataset with 64 rows, with each row of the original dataset duplicated with the two values of the am variable supplied (0 and 1):
mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)
nd <- datagrid(model = mod, am = 0:1, grid_type = "counterfactual")
dim(nd)
#> [1] 64 4
Then, we can use this dataset and the predictions function to create interesting visualizations:
pred <- predictions(mod, newdata = datagrid(am = 0:1, grid_type = "counterfactual")) %>%
  select(am, predicted, rowidcf) %>%
  pivot_wider(id_cols = rowidcf,
              names_from = am,
              values_from = predicted)
ggplot(pred, aes(x = `0`, y = `1`)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1) +
  labs(x = "Predicted Pr(vs=1), when am = 0",
       y = "Predicted Pr(vs=1), when am = 1")
In this graph, each dot represents the predicted probability that vs=1 for one observation of the dataset, in the counterfactual worlds where am is either 0 or 1.
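In this particular model, the pattern of dots has a simple explanation. The base-R sketch below computes the same pairs of counterfactual predictions: because am enters the linear predictor through a single coefficient and plogis() is strictly increasing, every car lands on the same side of the 45-degree line. With a negative am coefficient (consistent with the averages reported further below), every dot falls below it.

```r
mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)

# Predicted Pr(vs = 1) for every car in the two counterfactual worlds
p0 <- predict(mod, newdata = transform(mtcars, am = 0), type = "response")
p1 <- predict(mod, newdata = transform(mtcars, am = 1), type = "response")

# The sign of the am coefficient decides, for every car at once,
# whether its point lies above or below the 45-degree line
coef(mod)[["am"]] < 0
all(p1 < p0)
```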
Some analysts may want to calculate an “Adjusted Prediction at the Mean” (APM), that is, the predicted outcome when all the regressors are held at their mean (or mode). To achieve this, we use the datagrid function. By default, this function produces a grid of data with regressors at their means or modes, so all we need to do to get the APM is:
predictions(mod, newdata = "mean")
#> rowid type predicted std.error conf.low conf.high hp am
#> 1 1 response 0.06308965 0.08662801 0.003794253 0.543491 146.6875 0.40625
This is equivalent to calling:
predictions(mod, newdata = datagrid())
#> rowid type predicted std.error conf.low conf.high hp am
#> 1 1 response 0.06308965 0.08662801 0.003794253 0.543491 146.6875 0.40625
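The APM above can be reproduced in base R by predicting at a single row of sample means. The sketch also makes one detail visible: am is stored as a 0/1 number, so its “mean” is 13/32 = 0.40625 (the value in the output above), which no actual car has.

```r
mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)

# Every regressor at its sample mean; am is numeric 0/1,
# so its mean is a fractional value rather than a real category
at_mean <- data.frame(hp = mean(mtcars$hp), am = mean(mtcars$am))
apm <- predict(mod, newdata = at_mean, type = "response")
apm
```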
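An “Average Adjusted Prediction” (AAP) is the outcome of a two-step process:
1. Compute adjusted predictions for every row of a reference grid (by default, every observation in the original dataset, possibly with some regressors fixed at values of interest).
2. Take the average of those predicted values.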
We can obtain AAPs by applying the tidy() or summary() functions to an object produced by the predictions() function:
pred <- predictions(mod)
summary(pred)
#> Predicted Std. Error z value Pr(>|z|) CI low CI high
#> 1 0.4375 0.04288 10.2 < 2.22e-16 0.3535 0.5215
#>
#> Model type: glm
#> Prediction type: response
This is equivalent to:
pred %>% summarize(AAP = mean(predicted))
#>      AAP
#> 1 0.4375
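The two steps translate directly into base R. The sketch also illustrates why the AAP (0.4375) differs so much from the APM computed earlier (about 0.063): in a nonlinear model, the average of the predictions is not the prediction at the average. For a logit model with an intercept, the maximum-likelihood fit makes the fitted probabilities sum to the number of observed successes, so the AAP over the estimation sample equals the sample mean of vs.

```r
mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)

# Step 1: predicted probability for every observed row
# Step 2: average them
aap <- mean(predict(mod, type = "response"))
aap

# Sample mean of the outcome: 14 of the 32 cars have vs == 1
mean(mtcars$vs)
```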
We can also compute the AAP for multiple values of the regressors. For example, here we create a “counterfactual” data grid where each observation of the dataset is repeated twice, with different values of the am variable and all other variables held at their observed values. Then, we use the by argument or some dplyr magic:
p <- predictions(
  mod,
  by = "am",
  newdata = datagrid(am = 0:1, grid_type = "counterfactual"))
summary(p)
#> am Predicted Std. Error z value Pr(>|z|) CI low CI high
#> 1 0 0.5261 0.03303 15.93 < 2.22e-16 0.4614 0.5909
#> 2 1 0.3302 0.06462 5.11 3.2272e-07 0.2035 0.4568
#>
#> Model type: glm
#> Prediction type: response
p %>% group_by(am) %>%
  summarize(AAP = mean(predicted))
#> # A tibble: 2 × 2
#> am AAP
#> <int> <dbl>
#> 1 0 0.526
#> 2 1 0.330
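The grouped AAPs follow the same two steps on the stacked counterfactual data. This base-R sketch mirrors the group_by() and summarize() pipeline with rbind() and tapply(); it is an illustration, not the package's code path.

```r
mod <- glm(vs ~ hp + am, data = mtcars, family = binomial)

# Stack two copies of the data: am = 0 everywhere, then am = 1 everywhere
cf <- rbind(transform(mtcars, am = 0), transform(mtcars, am = 1))

# Step 1: predicted probability for every row of the counterfactual grid
cf$predicted <- predict(mod, newdata = cf, type = "response")

# Step 2: average within each counterfactual value of am
aap <- tapply(cf$predicted, cf$am, mean)
aap
```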
First, we download the ggplot2movies dataset from the RDatasets archive. Then, we create a variable called certified_fresh for movies with a rating of at least 8. Finally, we discard some outliers (movies 240 minutes or longer) and fit a logistic regression model:
library(tidyverse)
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2movies/movies.csv") %>%
  mutate(style = case_when(Action == 1 ~ "Action",
                           Comedy == 1 ~ "Comedy",
                           Drama == 1 ~ "Drama",
                           TRUE ~ "Other"),
         style = factor(style),
         certified_fresh = rating >= 8) %>%
  filter(length < 240)
mod <- glm(certified_fresh ~ length * style, data = dat, family = binomial)
We can plot adjusted predictions, conditional on the length variable, using the plot_cap function:
mod <- glm(certified_fresh ~ length, data = dat, family = binomial)
plot_cap(mod, condition = "length")
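A conditional prediction plot is essentially predictions over a sequence of values of the conditioning variable. The sketch below uses mtcars instead of the downloaded movies data so it runs offline. It assumes 25 equally spaced values over the observed range, which is consistent with the head(p) output further below (condition1 runs from 10.4 in steps of about 0.979). It also assumes intervals are built on the link scale and then mapped back with the inverse logit; this is one common approach, not necessarily plot_cap()'s exact method.

```r
mod <- glm(am ~ mpg, family = binomial, data = mtcars)

# 25 equally spaced values over the observed range of the conditioning variable
grid <- data.frame(mpg = seq(min(mtcars$mpg), max(mtcars$mpg), length.out = 25))

# Predict on the link scale, build a normal-theory interval there,
# then map everything to the probability scale with the inverse logit
pr <- predict(mod, newdata = grid, type = "link", se.fit = TRUE)
z <- qnorm(0.975)
grid$predicted <- plogis(pr$fit)
grid$conf.low  <- plogis(pr$fit - z * pr$se.fit)
grid$conf.high <- plogis(pr$fit + z * pr$se.fit)

plot(grid$mpg, grid$predicted, type = "l", ylim = c(0, 1),
     xlab = "mpg", ylab = "Predicted Pr(am = 1)")
lines(grid$mpg, grid$conf.low, lty = 2)
lines(grid$mpg, grid$conf.high, lty = 2)
```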
We can also introduce another condition, which will display a categorical variable like style in different colors. This can be useful in models with interactions:
mod <- glm(certified_fresh ~ length * style, data = dat, family = binomial)
plot_cap(mod, condition = c("length", "style"))
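With a second, categorical condition, the grid becomes every combination of the sequence and the categories. The base-R sketch below uses a linear model on mtcars (an assumption chosen so the example runs offline) to show why this matters with interactions: each group gets its own line, and the difference between the two slopes equals the interaction coefficient.

```r
mod <- lm(mpg ~ hp * factor(am), data = mtcars)

# Every combination of 25 hp values and the two am categories
grid <- expand.grid(
  hp = seq(min(mtcars$hp), max(mtcars$hp), length.out = 25),
  am = c(0, 1)
)
grid$predicted <- predict(mod, newdata = grid)

# Slope of the predicted line within each am group
slopes <- sapply(split(grid, grid$am), function(g) {
  unname(coef(lm(predicted ~ hp, data = g))[2])
})
slopes
```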
Since the output of plot_cap() is a ggplot2 object, it is easy to customize. For example, we can add points for the actual observations of our dataset like so:
library(ggplot2)
library(ggrepel)
mt <- mtcars
mt$label <- row.names(mt)
mod <- lm(mpg ~ hp, data = mt)
plot_cap(mod, condition = "hp") +
geom_point(aes(x = hp, y = mpg), data = mt) +
geom_rug(aes(x = hp, y = mpg), data = mt) +
geom_text_repel(aes(x = hp, y = mpg, label = label),
data = subset(mt, hp > 250),
nudge_y = 2) +
theme_classic()
We can also use plot_cap() in models with multinomial outcomes or grouped coefficients. For example, notice that when we call draw = FALSE, the result includes a group column:
library(MASS)
library(ggplot2)
mod <- nnet::multinom(factor(gear) ~ mpg, data = mtcars, trace = FALSE)
p <- plot_cap(
  mod,
  type = "probs",
  condition = "mpg",
  draw = FALSE)
head(p)
#> rowid type group predicted std.error condition1
#> 1 1 probs 3 0.9714990 0.03871641 10.40000
#> 2 2 probs 3 0.9583559 0.04985914 11.37917
#> 3 3 probs 3 0.9393514 0.06291986 12.35833
#> 4 4 probs 3 0.9122105 0.07727155 13.33750
#> 5 5 probs 3 0.8741884 0.09157738 14.31667
#> 6 6 probs 3 0.8224163 0.10383644 15.29583
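The group column arises because a multinomial model returns one predicted probability per outcome level. The sketch below calls predict() from the nnet package directly on a 25-point mpg grid (matching the condition1 values above) and reshapes the resulting matrix into a long data frame with a group column. It illustrates the structure of the output, not plot_cap()'s internals.

```r
library(nnet)

mod <- multinom(factor(gear) ~ mpg, data = mtcars, trace = FALSE)
grid <- data.frame(mpg = seq(min(mtcars$mpg), max(mtcars$mpg), length.out = 25))

# 25 x 3 matrix: one column of probabilities per level of gear
probs <- predict(mod, newdata = grid, type = "probs")

# Long format: as.vector() stacks the columns, so group repeats each level 25 times
long <- data.frame(
  mpg = rep(grid$mpg, times = ncol(probs)),
  group = rep(colnames(probs), each = nrow(probs)),
  predicted = as.vector(probs)
)
head(long)
```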
Now we use the group column:
plot_cap(
  mod,
  type = "probs",
  condition = "mpg") +
  facet_wrap(~group)
The predictions function computes model-adjusted means on the scale of the output of the predict(model) function. By default, predictions() uses type = "response" for most models, so the adjusted predictions should be interpreted on that scale. (Base R's predict() method for glm objects, by contrast, defaults to type = "link".) However, users can pass a string to the type argument, and predictions will compute predictions on a different scale.
Typical values include "response" and "link", but users should refer to the documentation of the predict method of the package they used to fit the model to learn which values are allowed.
mod <- glm(am ~ mpg, family = binomial, data = mtcars)
pred <- predictions(mod, type = "response")
head(pred)
#> rowid type predicted std.error conf.low conf.high am mpg
#> 1 1 response 0.4610951 0.11584004 0.2554723 0.6808686 1 21.0
#> 2 2 response 0.4610951 0.11584004 0.2554723 0.6808686 1 21.0
#> 3 3 response 0.5978984 0.13239819 0.3356711 0.8139794 1 22.8
#> 4 4 response 0.4917199 0.11961263 0.2746560 0.7119512 0 21.4
#> 5 5 response 0.2969009 0.10051954 0.1411369 0.5204086 0 18.7
#> 6 6 response 0.2599331 0.09782666 0.1147580 0.4876032 0 18.1
pred <- predictions(mod, type = "link")
head(pred)
#> rowid type predicted std.error statistic p.value conf.low conf.high am mpg
#> 1 1 link -0.15593472 0.4661826 -0.33449281 0.73800772 -1.0696358 0.75776637 1 21.0
#> 2 2 link -0.15593472 0.4661826 -0.33449281 0.73800772 -1.0696358 0.75776637 1 21.0
#> 3 3 link 0.39671602 0.5507048 0.72037875 0.47129183 -0.6826455 1.47607755 1 22.8
#> 4 4 link -0.03312345 0.4785818 -0.06921168 0.94482113 -0.9711265 0.90487956 0 21.4
#> 5 5 link -0.86209956 0.4815290 -1.79033775 0.07339963 -1.8058791 0.08167995 0 18.7
#> 6 6 link -1.04631647 0.5085395 -2.05749308 0.03963882 -2.0430356 -0.04959739 0 18.1
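The two outputs are linked by the inverse logit. The numbers in the first rows suggest that the response-scale interval is the link-scale interval mapped through plogis() (for example, plogis(-1.0696358) ≈ 0.2555), which is why it is asymmetric around the predicted probability. The base-R sketch below reproduces that relationship under the assumption of a normal 95% interval on the link scale.

```r
mod <- glm(am ~ mpg, family = binomial, data = mtcars)

# Link-scale predictions and standard errors for every observed row
lp <- predict(mod, type = "link", se.fit = TRUE)
z <- qnorm(0.975)

out <- data.frame(
  link      = lp$fit,
  response  = plogis(lp$fit),                  # inverse logit of the link
  conf.low  = plogis(lp$fit - z * lp$se.fit),  # interval built on the link scale,
  conf.high = plogis(lp$fit + z * lp$se.fit)   # then transformed to probabilities
)
head(out)
```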
We can also plot predictions on different outcome scales:
plot_cap(mod, condition = "mpg", type = "response")
plot_cap(mod, condition = "mpg", type = "link")