marginaleffects and reticulate
Note: The code in this vignette requires version 0.7.1 of marginaleffects or the development version from GitHub.
The Python programming language offers several powerful libraries for (Bayesian) statistical analysis, such as NumPyro and PyMC. This vignette shows how to use the full power of marginaleffects to analyze and interpret the results of models estimated by Markov Chain Monte Carlo using the NumPyro Python library.
NumPyro model
To begin, we load the reticulate package, which allows us to interact with the Python interpreter from an R session. Then, we write a NumPyro model and load it into memory using the source_python() function. The important functions to note in the Python code are:
load_df(): downloads data on pulmonary fibrosis.
model(): defines the NumPyro model.
fit_mcmc_model(): fits the model using Markov Chain Monte Carlo.
predict_mcmc(): accepts a data frame and returns a matrix of draws from the posterior distribution of adjusted predictions (fitted values).

library(reticulate)
library(marginaleffects)
model <- '
# Model code adapted from the NumPyro documentation under Apache License:
# https://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html
import pandas as pd
import numpy as np
import numpyro
from numpyro.infer import SVI, Predictive, MCMC, NUTS, autoguide, TraceMeanField_ELBO
import numpyro.distributions as dist
from numpyro.infer.initialization import init_to_median, init_to_uniform, init_to_sample
from jax import random
from sklearn.preprocessing import LabelEncoder
import pickle

def load_df():
    train = pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/modelarchive/main/data-raw/osic_pulmonary_fibrosis.csv")
    return train

def model(data, predict = False):
    FVC_obs = data["FVC"].values if predict == False else None
    patient_encoder = LabelEncoder()
    Age_obs = data["Age"].values
    patient_code = patient_encoder.fit_transform(data["Patient"].values)
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    age = numpyro.sample("age", dist.Normal(0.0, 500.0))
    n_patients = len(np.unique(patient_code))
    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[patient_code] + age * Age_obs
    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)

def fit_mcmc_model(train_df, samples = 1000):
    numpyro.set_host_device_count(4)
    rng_key = random.PRNGKey(0)
    mcmc = MCMC(
        NUTS(model),
        num_samples=samples,
        num_warmup=1000,
        progress_bar=True,
        num_chains = 4
    )
    mcmc.run(rng_key, train_df)
    posterior_draws = mcmc.get_samples()
    # save the posterior draws to disk so that predict_mcmc() can reload them
    with open("mcmc_posterior_draws.pickle", "wb") as handle:
        pickle.dump(posterior_draws, handle, protocol=pickle.HIGHEST_PROTOCOL)

def predict_mcmc(data):
    with open("mcmc_posterior_draws.pickle", "rb") as handle:
        posterior_draws = pickle.load(handle)
    predictive = Predictive(model = model, posterior_samples=posterior_draws)
    samples = predictive(random.PRNGKey(1), data, predict = True)
    y_pred = samples["obs"]
    # transpose so that each column is a draw and each row is an observation
    y_pred = np.transpose(np.array(y_pred))
    return y_pred
'
# save python script to temp file
tmp <- tempfile()
cat(model, file = tmp)
# load functions
source_python(tmp)
# download data
df <- load_df()
# fit model
fit_mcmc_model(df)
marginaleffects
Each of the functions in the marginaleffects package requires that users supply a model object on which the function will operate. When estimating models outside R, we do not have such a model object. We thus begin by creating a “fake” model object: an empty data frame which we define to be of class “custom”. Then, we set a global option to tell marginaleffects that this “custom” class is supported.
mod <- data.frame()
class(mod) <- "custom"
options("marginaleffects_model_classes" = "custom")
Next, we define a get_predict method for our new custom class. This method must accept three arguments: model, newdata, and the dots (...). The get_predict method must return a data frame with one row for each row in newdata, two columns (rowid and predicted), and an attribute called posterior_draws which holds a matrix of posterior draws with the same number of rows as newdata.
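To make the shape of this return value concrete, here is a minimal sketch in plain Python (standard library only) of how a draws matrix with one row per observation can be reduced to one summary row per observation. The function name summarize_draws, the simulated draws, and the conf_low and conf_high fields are assumptions made for illustration. They are not part of marginaleffects, and the interval calculation is only one plausible way to obtain bounds like those printed later in this vignette.

```python
import random
import statistics

def summarize_draws(draws):
    """Summarize a draws matrix: one row per observation, one column per draw."""
    rows = []
    for i, row in enumerate(draws, start=1):
        cuts = statistics.quantiles(row, n=40)  # cut points every 2.5%
        rows.append({
            "rowid": i,
            "predicted": statistics.median(row),  # posterior median
            "conf_low": cuts[0],                  # 2.5% quantile (illustrative)
            "conf_high": cuts[-1],                # 97.5% quantile (illustrative)
        })
    return rows

# Hypothetical posterior draws: 3 observations, 200 draws each
rng = random.Random(0)
draws = [[rng.gauss(mu, 50.0) for _ in range(200)] for mu in (2000.0, 2100.0, 2200.0)]

summary = summarize_draws(draws)
for row in summary:
    print(row)
```

The summary has exactly as many rows as the draws matrix, which mirrors the requirement that the data frame and the posterior_draws attribute both line up with the rows of newdata.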
The method below uses reticulate to call the predict_mcmc() function that we defined in the Python script above. The predict_mcmc() function accepts a data frame and returns a matrix with the same number of rows.
get_predict.custom <- function(model, newdata, ...) {
    # matrix of posterior draws: one row per row of newdata, one column per draw
    pred <- predict_mcmc(newdata)
    out <- data.frame(
        rowid = seq_len(nrow(newdata)),
        predicted = apply(pred, 1, stats::median)
    )
    attr(out, "posterior_draws") <- pred
    return(out)
}
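The orientation of the draws matrix is easy to get wrong, because the Python code transposes the sampler output so that each row is an observation and each column is a draw. The sketch below, in plain Python, illustrates that transposition and a simple row-count check. The helper names to_rows_per_observation and check_draws_shape and the numbers are hypothetical and are not part of NumPyro, reticulate, or marginaleffects.

```python
def to_rows_per_observation(samples):
    """Transpose a draws-by-observations table into observations-by-draws."""
    return [list(column) for column in zip(*samples)]

def check_draws_shape(draws, n_newdata):
    """Raise if the draws matrix does not have one row per row of newdata."""
    if len(draws) != n_newdata:
        raise ValueError(f"expected {n_newdata} rows, got {len(draws)}")
    if len({len(row) for row in draws}) > 1:
        raise ValueError("rows have different numbers of draws")
    return True

# Hypothetical sampler output: 4 draws (rows) for 2 observations (columns)
samples = [
    [2100.0, 1900.0],
    [2150.0, 1880.0],
    [2080.0, 1950.0],
    [2120.0, 1910.0],
]

draws = to_rows_per_observation(samples)
check_draws_shape(draws, n_newdata=2)
print(len(draws), "rows,", len(draws[0]), "draws per row")
```

A check of this kind could be run on the output of predict_mcmc() before it is attached as the posterior_draws attribute.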
Now we can use most of the marginaleffects package functions to analyze our results. Since we use a “fake” model object, marginaleffects cannot retrieve the original data from the model object, and we always need to supply a newdata argument:
# predictions on the original dataset
predictions(mod, newdata = df) |> head()
#> rowid type predicted conf.low conf.high Patient Weeks
#> 1 1 response 2120.532 1763.176 2490.323 ID00007637202177411956430 -4
#> 2 2 response 2112.942 1740.647 2480.077 ID00007637202177411956430 5
#> 3 3 response 2116.086 1731.962 2480.866 ID00007637202177411956430 7
#> 4 4 response 2117.830 1740.041 2492.416 ID00007637202177411956430 9
#> 5 5 response 2114.272 1738.427 2490.165 ID00007637202177411956430 11
#> 6 6 response 2120.984 1741.109 2500.270 ID00007637202177411956430 17
#> FVC Percent Age Sex SmokingStatus
#> 1 2315 58.25365 79 Male Ex-smoker
#> 2 2214 55.71213 79 Male Ex-smoker
#> 3 2061 51.86210 79 Male Ex-smoker
#> 4 2144 53.95068 79 Male Ex-smoker
#> 5 2069 52.06341 79 Male Ex-smoker
#> 6 2101 52.86865 79 Male Ex-smoker
# predictions for user-defined predictor values
predictions(mod, newdata = datagrid(newdata = df, Age = c(60, 70)))
#> rowid type predicted conf.low conf.high Patient
#> 1 1 response 1798.592 1340.386 2318.963 ID00099637202206203080121
#> 2 2 response 1969.797 1583.616 2379.580 ID00099637202206203080121
#> Weeks FVC Percent Sex SmokingStatus Age
#> 1 31.86185 2690.479 77.67265 Male Ex-smoker 60
#> 2 31.86185 2690.479 77.67265 Male Ex-smoker 70
predictions(mod, newdata = datagrid(newdata = df, Age = range))
#> rowid type predicted conf.low conf.high Patient
#> 1 1 response 1605.094 1029.323 2301.611 ID00099637202206203080121
#> 2 2 response 2262.574 1860.123 2655.886 ID00099637202206203080121
#> Weeks FVC Percent Sex SmokingStatus Age
#> 1 31.86185 2690.479 77.67265 Male Ex-smoker 49
#> 2 31.86185 2690.479 77.67265 Male Ex-smoker 88
# average predictions by group
predictions(mod, newdata = df, by = "Sex")
#> type Sex predicted conf.low conf.high
#> 1 response Female 1883.243 1854.946 1910.619
#> 2 response Male 2904.648 2890.626 2918.217
# contrasts (average)
comparisons(mod, variables = "Age", newdata = df) |>
summary()
#> Term Contrast Effect 2.5 % 97.5 %
#> 1 Age +1 18.01 -0.6857 33.59
#>
#> Model type: custom
#> Prediction type: response
comparisons(mod, variables = list("Age" = "sd"), newdata = df) |>
summary()
#> Term Contrast Effect 2.5 % 97.5 %
#> 1 Age (x + sd/2) - (x - sd/2) 127.1 -4.84 237.1
#>
#> Model type: custom
#> Prediction type: response
# slope (elasticity)
marginaleffects(mod, variables = "Age", slope = "eyex", newdata = df) |>
summary()
#> Term Contrast Effect 2.5 % 97.5 %
#> 1 Age eY/eX 0.4997 -0.01924 0.9326
#>
#> Model type: custom
#> Prediction type: response
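
The average contrasts reported above can be understood as a draw-by-draw calculation: predict at the observed values, predict again after shifting Age, take the difference, average over observations within each draw, and summarize the resulting draws. The plain Python sketch below illustrates that idea under simplifying assumptions. The predict_draws function is a hypothetical stand-in for predict_mcmc() with a single intercept and slope per draw, the numbers are simulated, and the code is not the internal implementation of comparisons().

```python
import random
import statistics

def predict_draws(ages, intercepts, slopes):
    """Hypothetical stand-in for predict_mcmc(): one row per observation, one column per draw."""
    return [[a + b * age for a, b in zip(intercepts, slopes)] for age in ages]

def average_contrast(ages, intercepts, slopes, step=1.0):
    """Average effect of raising age by `step`, computed separately for each draw."""
    low = predict_draws(ages, intercepts, slopes)
    high = predict_draws([age + step for age in ages], intercepts, slopes)
    per_draw = []
    for d in range(len(intercepts)):
        # difference for every observation within draw d, then average
        diffs = [high[i][d] - low[i][d] for i in range(len(ages))]
        per_draw.append(statistics.fmean(diffs))
    cuts = statistics.quantiles(per_draw, n=40)  # cut points every 2.5%
    return {
        "effect": statistics.median(per_draw),
        "conf_low": cuts[0],
        "conf_high": cuts[-1],
        "per_draw": per_draw,
    }

# Simulated (not real) posterior draws for a toy linear model
rng = random.Random(1)
intercepts = [rng.gauss(1000.0, 100.0) for _ in range(500)]
slopes = [rng.gauss(18.0, 9.0) for _ in range(500)]

result = average_contrast([60.0, 70.0, 79.0], intercepts, slopes)
print(result["effect"], result["conf_low"], result["conf_high"])
```

In this toy linear model the contrast within each draw equals that draw's slope, so the summary is simply a summary of the slope draws. With a nonlinear model the same steps would apply, but the differences would vary across observations.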