Python with marginaleffects and reticulate

Note: The code in this vignette requires version 0.7.1 of marginaleffects or the development version from GitHub.

The Python programming language offers several powerful libraries for (Bayesian) statistical analysis, such as NumPyro and PyMC. This vignette shows how to use the full power of marginaleffects to analyze and interpret the results of models estimated by Markov Chain Monte Carlo using the NumPyro Python library.

Fitting a NumPyro model

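To begin, we load the reticulate package, which allows us to interact with the Python interpreter from an R session. Then, we write a NumPyro model and load it into memory using the source_python() function. The important functions to note in the Python code are load_df(), which downloads the data; model(), which defines the NumPyro model; fit_mcmc_model(), which fits the model by Markov Chain Monte Carlo and saves the posterior draws to disk; and predict_mcmc(), which accepts a data frame and returns a matrix of posterior predictive draws.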

library(reticulate)
library(marginaleffects)

model <- '
# Model code adapted from the NumPyro documentation under Apache License:
# https://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html

import pandas as pd
import numpy as np
import numpyro
from numpyro.infer import SVI, Predictive, MCMC, NUTS, autoguide, TraceMeanField_ELBO
import numpyro.distributions as dist
from numpyro.infer.initialization import init_to_median, init_to_uniform, init_to_sample
from jax import random
from sklearn.preprocessing import LabelEncoder
import pickle

def load_df():
    train = pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/modelarchive/main/data-raw/osic_pulmonary_fibrosis.csv")
    return train


def model(data, predict = False):
    # When predicting, drop the outcome so that "obs" is sampled instead of observed
    FVC_obs = None if predict else data["FVC"].values
    patient_encoder = LabelEncoder()
    Age_obs = data["Age"].values
    patient_code = patient_encoder.fit_transform(data["Patient"].values)
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))

    age = numpyro.sample("age", dist.Normal(0.0, 500.0))

    n_patients = len(np.unique(patient_code))

    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))

    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[patient_code] + age * Age_obs

    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)


def fit_mcmc_model(train_df, samples = 1000):
    numpyro.set_host_device_count(4)
    rng_key = random.PRNGKey(0)
    mcmc = MCMC(
        NUTS(model),
        num_samples=samples,
        num_warmup=1000,
        progress_bar=True,
        num_chains = 4
        )
    
    mcmc.run(rng_key, train_df)

    posterior_draws = mcmc.get_samples()

    with open("mcmc_posterior_draws.pickle", "wb") as handle:
        pickle.dump(posterior_draws, handle, protocol=pickle.HIGHEST_PROTOCOL)

def predict_mcmc(data):

    with open("mcmc_posterior_draws.pickle", "rb") as handle:
        posterior_draws = pickle.load(handle)

    predictive = Predictive(model=model, posterior_samples=posterior_draws)
    samples = predictive(random.PRNGKey(1), data, predict=True)
    y_pred = samples["obs"]
    # transpose so that each column is a draw and each row is an observation
    y_pred = np.transpose(np.array(y_pred))

    return y_pred 
'

# save python script to temp file
tmp <- tempfile()
cat(model, file = tmp)

# load functions
source_python(tmp)

# download data
df <- load_df()

# fit model
fit_mcmc_model(df)
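
For reference, the model() function above can be read as the following hierarchical linear regression. This is a transcription of the sampling statements in the Python code, not an independent specification: the symbol \beta_{\text{age}} is our own notation for the "age" sample site, i[j] denotes the patient to which observation j belongs, and the second argument of each distribution is a scale (standard deviation), as in NumPyro.

```latex
\begin{aligned}
\mu_\alpha &\sim \mathrm{Normal}(0, 500) \\
\sigma_\alpha &\sim \mathrm{HalfNormal}(100) \\
\alpha_i &\sim \mathrm{Normal}(\mu_\alpha, \sigma_\alpha), \quad i = 1, \dots, N_{\text{patients}} \\
\beta_{\text{age}} &\sim \mathrm{Normal}(0, 500) \\
\sigma &\sim \mathrm{HalfNormal}(100) \\
\mathrm{FVC}_{j} &\sim \mathrm{Normal}\left(\alpha_{i[j]} + \beta_{\text{age}} \cdot \mathrm{Age}_{j},\ \sigma\right)
\end{aligned}
```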

Analyzing the results in marginaleffects

Each of the functions in the marginaleffects package requires that users supply a model object on which the function will operate. When estimating models outside R, we do not have such a model object. We thus begin by creating a “fake” model object: an empty data frame which we define to be of class “custom”. Then, we set a global option to tell marginaleffects that this “custom” class is supported.

mod <- data.frame()
class(mod) <- "custom"

options("marginaleffects_model_classes" = "custom")

Next, we define a get_predict method for our new custom class. This method must accept three arguments: model, newdata, and .... The get_predict method must return a data frame with one row for each of the rows in newdata, two columns (rowid and predicted), and an attribute called posterior_draws which hosts a matrix of posterior draws with the same number of rows as newdata.

The method below uses reticulate to call the predict_mcmc() function that we defined in the Python script above. The predict_mcmc() function accepts a data frame and returns a matrix with the same number of rows.

get_predict.custom <- function(model, newdata, ...) {
    pred <- predict_mcmc(newdata)
    out <- data.frame(
        rowid = seq_len(nrow(newdata)),
        predicted = apply(pred, 1, stats::median)
    )
    attr(out, "posterior_draws") <- pred
    return(out)
}
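
To make the shape contract concrete, here is a minimal Python sketch that uses only the standard library. The toy draws list is a stand-in for the samples["obs"] array in predict_mcmc() (one row per posterior draw, one column per observation); its values are made up for illustration and do not come from the fitted model. The sketch transposes the draws so that each row is an observation, then takes row-wise medians, which mirrors what apply(pred, 1, stats::median) does on the R side.

```python
from statistics import median

# Hypothetical posterior predictive draws: 4 draws (rows) x 3 observations (columns).
# These numbers are invented for illustration only.
draws = [
    [2100.0, 1800.0, 2500.0],
    [2150.0, 1750.0, 2450.0],
    [2050.0, 1900.0, 2550.0],
    [2200.0, 1850.0, 2600.0],
]

# Transpose so that each row is an observation and each column is a draw.
by_observation = [list(column) for column in zip(*draws)]

# One point estimate per observation: the median across draws.
predicted = [median(row) for row in by_observation]

n_obs = len(by_observation)
n_draws = len(by_observation[0])
print(n_obs, n_draws)
print(predicted)
```

The number of rows after transposing equals the number of observations, which is the property the posterior_draws attribute needs to satisfy.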

Now we can use most of the marginaleffects package functions to analyze our results. Since we use a “fake” model object, marginaleffects cannot retrieve the original data from the model object, and we always need to supply a newdata argument:

# predictions on the original dataset
predictions(mod, newdata = df) |> head()
#>   rowid     type predicted conf.low conf.high                   Patient Weeks
#> 1     1 response  2120.532 1763.176  2490.323 ID00007637202177411956430    -4
#> 2     2 response  2112.942 1740.647  2480.077 ID00007637202177411956430     5
#> 3     3 response  2116.086 1731.962  2480.866 ID00007637202177411956430     7
#> 4     4 response  2117.830 1740.041  2492.416 ID00007637202177411956430     9
#> 5     5 response  2114.272 1738.427  2490.165 ID00007637202177411956430    11
#> 6     6 response  2120.984 1741.109  2500.270 ID00007637202177411956430    17
#>    FVC  Percent Age  Sex SmokingStatus
#> 1 2315 58.25365  79 Male     Ex-smoker
#> 2 2214 55.71213  79 Male     Ex-smoker
#> 3 2061 51.86210  79 Male     Ex-smoker
#> 4 2144 53.95068  79 Male     Ex-smoker
#> 5 2069 52.06341  79 Male     Ex-smoker
#> 6 2101 52.86865  79 Male     Ex-smoker

# predictions for user-defined predictor values
predictions(mod, newdata = datagrid(newdata = df, Age = c(60, 70)))
#>   rowid     type predicted conf.low conf.high                   Patient
#> 1     1 response  1798.592 1340.386  2318.963 ID00099637202206203080121
#> 2     2 response  1969.797 1583.616  2379.580 ID00099637202206203080121
#>      Weeks      FVC  Percent  Sex SmokingStatus Age
#> 1 31.86185 2690.479 77.67265 Male     Ex-smoker  60
#> 2 31.86185 2690.479 77.67265 Male     Ex-smoker  70

predictions(mod, newdata = datagrid(newdata = df, Age = range))
#>   rowid     type predicted conf.low conf.high                   Patient
#> 1     1 response  1605.094 1029.323  2301.611 ID00099637202206203080121
#> 2     2 response  2262.574 1860.123  2655.886 ID00099637202206203080121
#>      Weeks      FVC  Percent  Sex SmokingStatus Age
#> 1 31.86185 2690.479 77.67265 Male     Ex-smoker  49
#> 2 31.86185 2690.479 77.67265 Male     Ex-smoker  88

# average predictions by group
predictions(mod, newdata = df, by = "Sex")
#>       type    Sex predicted conf.low conf.high
#> 1 response Female  1883.243 1854.946  1910.619
#> 2 response   Male  2904.648 2890.626  2918.217

# contrasts (average)
comparisons(mod, variables = "Age", newdata = df)  |>
    summary()
#>   Term Contrast Effect   2.5 % 97.5 %
#> 1  Age       +1  18.01 -0.6857  33.59
#> 
#> Model type:  custom 
#> Prediction type:  response

comparisons(mod, variables = list("Age" = "sd"), newdata = df)  |>
    summary()
#>   Term                Contrast Effect 2.5 % 97.5 %
#> 1  Age (x + sd/2) - (x - sd/2)  127.1 -4.84  237.1
#> 
#> Model type:  custom 
#> Prediction type:  response

# slope (elasticity)
marginaleffects(mod, variables = "Age", slope = "eyex", newdata = df) |>
    summary()
#>   Term Contrast Effect    2.5 % 97.5 %
#> 1  Age    eY/eX 0.4997 -0.01924 0.9326
#> 
#> Model type:  custom 
#> Prediction type:  response
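
As a rough intuition for the "+1" contrast reported above, the sketch below computes an average contrast from posterior draws in plain Python. It is a conceptual illustration, not the marginaleffects implementation: the parameter draws, the ages, and the predict() helper are all hypothetical, and we assume that the contrast is averaged over observations within each draw before the draws are summarized by their median.

```python
from statistics import mean, median

# Hypothetical posterior draws of an intercept and an age slope (3 draws).
# These values are invented and unrelated to the fitted NumPyro model.
alpha_draws = [1000.0, 1100.0, 900.0]
age_draws = [15.0, 20.0, 10.0]

# Hypothetical observed ages.
ages = [60.0, 70.0, 80.0]


def predict(alpha, slope, age):
    """Linear predictor for one posterior draw and one observation."""
    return alpha + slope * age


# For each draw, average prediction(Age + 1) - prediction(Age) over observations.
contrast_draws = [
    mean(predict(a, b, x + 1) - predict(a, b, x) for x in ages)
    for a, b in zip(alpha_draws, age_draws)
]

# Summarize the per-draw average contrasts by their median.
estimate = median(contrast_draws)
print(contrast_draws)
print(estimate)
```

Because this toy predictor is linear in age, each per-draw contrast equals that draw's slope. With a real model, the same recipe would call the prediction function twice on modified copies of the data.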