Hierarchical Variational Inference
Hierarchical VI Example¶
In this example we will fit a hierarchical model to a simulated dataset using both MCMC and VI, and perform a simple comparison of the results.
Thanks to Guillaume Pagnier, PhD, for the initial version of this example.
import matplotlib
import numpy as np
import pandas as pd
# Select the non-interactive Agg backend before pyplot is imported below.
matplotlib.use("Agg")
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pymc as pm
import pytensor
import hssm
# Silence all warnings for the remainder of the session.
warnings.filterwarnings("ignore")
# Make PyTensor use double precision floats.
pytensor.config.floatX = "float64"
# IPython magic (notebook only): NOTE(review) this presumably replaces the Agg
# backend chosen above with the inline backend -- confirm both are needed.
%matplotlib inline
Utilities¶
def process_idata_for_plotting(
    idata: az.InferenceData,
    parameter_matrix: pd.DataFrame,
    model: str,
    data: "pd.DataFrame | None" = None,
) -> pd.DataFrame:
    """Combine posterior summaries with the true parameters for plotting.

    For every parameter listed under ``list_params`` in
    ``hssm.defaults.default_model_config[model]`` the posterior mean and the
    95% HDI are computed per observation, attached to the trial-level data,
    collapsed to the unique rows, and joined with the true parameter values.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data whose ``posterior`` group holds one variable per model
        parameter with ``chain``, ``draw`` and ``__obs__`` dimensions.
    parameter_matrix : pd.DataFrame
        True parameter values with a ``participant_id`` column.
    model : str
        Name of the model whose parameter list is read from the HSSM defaults.
    data : pd.DataFrame, optional
        Trial-level dataset containing ``rt``, ``response`` and
        ``participant_id`` columns, with rows in the same order as the
        ``__obs__`` dimension of ``idata``. When omitted, the module-level
        ``simDataDDM`` is used (the original behavior).

    Returns
    -------
    pd.DataFrame
        One row per unique combination of participant and posterior summary,
        holding ``<param>_mean``, ``<param>_lower_hdi``, ``<param>_higher_hdi``
        and the columns of ``parameter_matrix``.
    """
    if data is None:
        # Backward-compatible fallback: earlier versions always read the
        # module-level simulated dataset.
        data = simDataDDM

    # Look the parameter list up once instead of once per use.
    param_names = hssm.defaults.default_model_config[model]["list_params"]

    # Posterior means, averaged over chains and draws.
    params_df_mean = pd.DataFrame(
        {
            f"{param}_mean": idata.posterior[param].mean(dim=["chain", "draw"]).values
            for param in param_names
        }
    )

    # 95% HDIs, reshaped so each (parameter, bound) pair becomes a column.
    params_df_hdi = (
        az.hdi(idata.posterior[param_names], hdi_prob=0.95)
        .to_dataframe()
        .reset_index()
        .pivot(index="__obs__", columns="hdi", values=param_names)
    )
    # Flatten the (parameter, bound) MultiIndex into "<param>_<bound>_hdi".
    params_df_hdi.columns = ["_".join(col + ("hdi",)) for col in params_df_hdi.columns]
    params_df_hdi = params_df_hdi.reset_index(drop=True)

    # Attach the summaries to the trials; all three frames are aligned by
    # position, so the data index is normalized first.
    data_processed = pd.concat(
        [data.reset_index(drop=True), params_df_mean, params_df_hdi], axis=1
    )

    # Trial-specific columns are dropped so that rows sharing the same
    # participant and summaries collapse into one.
    per_participant = (
        data_processed.drop(columns=["rt", "response"])
        .drop_duplicates()
        .reset_index(drop=True)
    )

    # Join the true values by participant id rather than by row position, so
    # a differing row order cannot pair a participant with the wrong truth.
    plot_df = per_participant.merge(parameter_matrix, on="participant_id", how="left")
    return plot_df
Simulate Dataset¶
We simulate a dataset of 20 participants, each with 120 trials.
# Group-level settings: mean (``*_mu``) and standard deviation (``*_sigma``)
# of each DDM parameter across participants.
v_mu = 0.2
v_sigma = 0.3
a_mu = 1
a_sigma = 0.2
z_mu = 0.5
z_sigma = 0.01
t_mu = 0.4
t_sigma = 0.00

# True-parameter table, one row per participant. Each parameter column is
# drawn from its group-level normal, sorted ascending and rounded to one
# decimal. The draw order (v, a, z, t) is kept so the global NumPy RNG is
# consumed in the same sequence.
true_columns = {
    "participant_id": [f"subj_{str(i).zfill(2)}" for i in range(1, 21)],
}
for column_name, loc, scale in (
    ("v_true", v_mu, v_sigma),
    ("a_true", a_mu, a_sigma),
    ("z_true", z_mu, z_sigma),
    ("t_true", t_mu, t_sigma),
):
    draws = np.random.normal(loc=loc, scale=scale, size=20)
    true_columns[column_name] = np.sort(draws).round(1)
true_columns["nTrials"] = [120] * 20
parameter_matrix = pd.DataFrame(true_columns)

# Simulate every participant's trials from their own parameters, tag the
# rows with the participant id, and stack everything into one dataset.
dfs = []
for _, row in parameter_matrix.iterrows():
    theta = {
        "v": row["v_true"],
        "a": row["a_true"],
        "z": row["z_true"],
        "t": row["t_true"],
    }
    df = hssm.simulate_data(model="ddm", theta=theta, size=row["nTrials"])
    df["participant_id"] = row["participant_id"]
    dfs.append(df)
simDataDDM = pd.concat(dfs, ignore_index=True)
Hierarchical Model: Centered Parameterization¶
We first instantiate the generative model using HSSM syntax. We will then fit this model to the synthetic data using two approaches: (1) variational inference and (2) MCMC.
# Generative model (centered parameterization).
# Each DDM parameter has no common intercept (`0 + ...`) and a
# participant-level term `(1|participant_id)` with an explicit hierarchical
# prior: a Normal over participants whose mean and standard deviation have
# their own hyperpriors.
#
# The prior keys must name the group-specific term exactly as it appears in
# the formula, i.e. "1|participant_id". A key such as "1|id" matches no term
# in these formulas, so the priors would not be attached to the intended term.
mSimCentered = hssm.HSSM(
    data=simDataDDM,
    p_outlier=0.01,
    prior_settings="safe",
    noncentered=False,
    model="ddm",
    loglik_kind="approx_differentiable",
    include=[
        {
            "name": "v",
            "formula": "v ~ 0 +(1|participant_id)",
            "prior": {
                "1|participant_id": {
                    "name": "Normal",
                    # NOTE(review): a hyperprior mean of 1.3 is far from the
                    # generative v_mu = 0.2 used to simulate the data --
                    # confirm this value is intended.
                    "mu": {
                        "name": "Normal",
                        "mu": 1.3,
                        "sigma": 0.3,
                    },
                    "sigma": {"name": "HalfNormal", "sigma": 0.2},
                },
            },
        },
        {
            "name": "a",
            "formula": "a ~ 0 + (1|participant_id)",
            "prior": {
                "1|participant_id": {
                    "name": "Normal",
                    "mu": {"name": "Gamma", "mu": 1, "sigma": 0.2},
                    "sigma": {"name": "HalfNormal", "sigma": 0.2},
                },
            },
        },
        {
            "name": "z",
            "formula": "z ~ 0 + (1|participant_id)",
            "prior": {
                "1|participant_id": {
                    "name": "Normal",
                    "mu": {"name": "Beta", "alpha": 10, "beta": 10},
                    "sigma": {"name": "HalfNormal", "sigma": 0.01},
                },
            },
        },
        {
            "name": "t",
            "formula": "t ~ 0 + (1|participant_id)",
            "prior": {
                "1|participant_id": {
                    "name": "Normal",
                    "mu": {
                        "name": "Normal",
                        "mu": 0.4,
                        "sigma": 0.1,
                    },
                    "sigma": {"name": "HalfNormal", "sigma": 0.01},
                },
            },
        },
    ],
)
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future. No common intercept. Bounds for parameter a is not applied due to a current limitation of Bambi. This will change in the future. No common intercept. Bounds for parameter z is not applied due to a current limitation of Bambi. This will change in the future. No common intercept. Bounds for parameter t is not applied due to a current limitation of Bambi. This will change in the future. Model initialized successfully.
# Display the model graph as the cell output (useful for checking that the
# hierarchical structure was built as specified).
mSimCentered.graph()
Fit VI¶
Approach 1: Using Variational Inference (VI) to estimate posteriors. VI is deterministic: it treats inference as an optimization problem. While typically faster than MCMC, it can underestimate the true posterior variance and may not correctly capture the tradeoffs (correlations) that may exist between parameters. Note that the cell below uses method "advi"; to account for correlations between parameters, be sure to use method "fullrank_advi".
# VI: fit the centered model with ADVI for 20,000 iterations, optimizing with
# Adamax at a learning rate of 0.01.
# Alternative optimizer setting kept for reference:
# obj_optimizer=pm.adamax(learning_rate=0.1)
vi_idata = mSimCentered.vi(
    niter=20000,
    method="advi",  # "fullrank_advi" is the other method named in the text
    obj_optimizer=pm.adamax(learning_rate=0.01),
)
# Draw 1000 samples from the fitted approximation. In the code shown, these
# draws (not `vi_idata`) are what the summaries and plots below use.
mSimCenteredVIObject = mSimCentered.vi_approx.sample(draws=1000)
Using MCMC starting point defaults.
Output()
Finished [100%]: Average Loss = 3,382.7
Before looking at the posteriors we must ensure that the model successfully "converged".
# Convergence check: plot the loss recorded at each VI iteration.
loss_history = mSimCentered.vi_approx.hist
plt.plot(loss_history)
plt.title("VI iteration loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
Once we are satisfied the loss is acceptable, we can take a look at the mean of the group parameters. These should match our generative parameters above.
# Summarize only the group-level variables of the VI draws: names ending in
# "_mu" or "_sigma", selected by regex and listed alphabetically.
group_level_patterns = [r".*_mu", r".*_sigma"]
summary_vi = az.summary(
    mSimCenteredVIObject.posterior,
    filter_vars="regex",
    var_names=group_level_patterns,
)
summary_vi = summary_vi.sort_index()
summary_vi
arviz - WARNING - Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
a_1|participant_id_mu | 0.932 | 0.049 | 0.850 | 1.028 | 0.002 | 0.001 | 913.0 | 1023.0 | NaN |
a_1|participant_id_sigma | 0.236 | 0.037 | 0.168 | 0.306 | 0.001 | 0.001 | 988.0 | 816.0 | NaN |
t_1|participant_id_mu | 0.415 | 0.004 | 0.408 | 0.423 | 0.000 | 0.000 | 784.0 | 625.0 | NaN |
t_1|participant_id_sigma | 0.018 | 0.003 | 0.014 | 0.025 | 0.000 | 0.000 | 1096.0 | 936.0 | NaN |
v_1|participant_id_mu | 0.279 | 0.058 | 0.161 | 0.376 | 0.002 | 0.001 | 866.0 | 1026.0 | NaN |
v_1|participant_id_sigma | 0.263 | 0.040 | 0.189 | 0.330 | 0.001 | 0.001 | 807.0 | 905.0 | NaN |
z_1|participant_id_mu | 0.489 | 0.004 | 0.481 | 0.496 | 0.000 | 0.000 | 1018.0 | 942.0 | NaN |
z_1|participant_id_sigma | 0.021 | 0.003 | 0.014 | 0.027 | 0.000 | 0.000 | 802.0 | 972.0 | NaN |
Fit MCMC¶
Approach 2: Using MCMC to estimate posteriors. MCMC is typically slower than VI, but can more accurately capture the true posterior variance in some cases.
# MCMC: sample the centered model with the "nuts_numpyro" sampler, using
# 4 chains with 750 tuning steps and 250 retained draws each.
mSimCenteredSampled = mSimCentered.sample(
    sampler="nuts_numpyro", cores=4, chains=4, draws=250, tune=750, mp_ctx="forkserver"
)
# Generate posterior predictive samples for the MCMC draws. The return value
# is discarded here; presumably the samples are stored on the passed idata
# -- TODO confirm.
mSimCentered.sample_posterior_predictive(idata=mSimCenteredSampled)
Using default initvals.
sample: 100%|██████████| 1000/1000 [04:05<00:00, 4.08it/s, 63 steps of size 5.60e-02. acc. prob=0.97] sample: 100%|██████████| 1000/1000 [04:42<00:00, 3.55it/s, 31 steps of size 1.23e-01. acc. prob=0.85] sample: 100%|██████████| 1000/1000 [02:28<00:00, 6.74it/s, 31 steps of size 1.47e-01. acc. prob=0.88] sample: 100%|██████████| 1000/1000 [03:17<00:00, 5.07it/s, 31 steps of size 1.42e-01. acc. prob=0.88] There were 9 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details 100%|██████████| 1000/1000 [00:02<00:00, 496.57it/s]
Let's make sure MCMC converged and can successfully recapitulate the raw data
At an individual level, our model can capture the real data relatively well. How do the summary statistics of the posteriors compare to when we fit using VI?
# Summarize only the group-level variables of the MCMC draws: names ending in
# "_mu" or "_sigma", selected by regex and listed alphabetically.
group_level_patterns = [r".*_mu", r".*_sigma"]
summary_mcmc = az.summary(
    mSimCenteredSampled,
    filter_vars="regex",
    var_names=group_level_patterns,
)
summary_mcmc = summary_mcmc.sort_index()
summary_mcmc
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
a_1|participant_id_mu | 0.919 | 0.055 | 0.820 | 1.023 | 0.002 | 0.001 | 1081.0 | 836.0 | 1.01 |
a_1|participant_id_sigma | 0.240 | 0.048 | 0.163 | 0.330 | 0.002 | 0.001 | 826.0 | 560.0 | 1.01 |
t_1|participant_id_mu | 0.414 | 0.006 | 0.403 | 0.425 | 0.000 | 0.000 | 400.0 | 569.0 | 1.01 |
t_1|participant_id_sigma | 0.012 | 0.007 | 0.002 | 0.025 | 0.001 | 0.001 | 48.0 | 204.0 | 1.09 |
v_1|participant_id_mu | 0.280 | 0.068 | 0.142 | 0.398 | 0.002 | 0.002 | 927.0 | 722.0 | 1.00 |
v_1|participant_id_sigma | 0.272 | 0.056 | 0.179 | 0.379 | 0.002 | 0.001 | 884.0 | 778.0 | 1.00 |
z_1|participant_id_mu | 0.489 | 0.009 | 0.472 | 0.505 | 0.001 | 0.000 | 289.0 | 329.0 | 1.01 |
z_1|participant_id_sigma | 0.017 | 0.008 | 0.004 | 0.032 | 0.001 | 0.001 | 92.0 | 219.0 | 1.01 |
Process results for plotting¶
# Add trialwise parameters to the MCMC idata so that each model parameter is
# available per observation for the summaries computed below.
mSimCenteredSampled = mSimCentered.add_likelihood_parameters_to_idata(
    mSimCenteredSampled
)
# Per-participant posterior means / HDIs joined with the true values (MCMC).
plot_df_mcmc = process_idata_for_plotting(
    idata=mSimCenteredSampled, parameter_matrix=parameter_matrix, model="ddm"
)
# Same summary for the VI draws.
# NOTE(review): the VI draws are not passed through
# add_likelihood_parameters_to_idata; presumably they already contain the
# trialwise parameters -- confirm.
plot_df_vi = process_idata_for_plotting(
    idata=mSimCenteredVIObject, parameter_matrix=parameter_matrix, model="ddm"
)
Plotting¶
# Posterior predictive plot for the centered model, one panel per participant,
# wrapped at 5 panels per row.
hssm.plotting.plot_posterior_predictive(mSimCentered, col="participant_id", col_wrap=5)
<seaborn.axisgrid.FacetGrid at 0x2f9ef7d90>
# Parameter-recovery plot (centered model): for every DDM parameter, show each
# participant's posterior mean and HDI from VI (blue) and MCMC (green) next to
# the true generating value (red cross).
params = hssm.defaults.default_model_config["ddm"]["list_params"]
fig, axes = plt.subplots(nrows=1, ncols=len(params), figsize=(14, 6), sharey=True)
# plt.subplots returns a bare Axes for a single column; normalize to an array
# so the zip below works for any number of parameters.
axes = np.atleast_1d(axes)

# Participant ordering and y positions do not depend on the parameter, so
# they are computed once instead of on every loop iteration.
df_sorted_vi = plot_df_vi.sort_values("participant_id")
df_sorted_mcmc = plot_df_mcmc.sort_values("participant_id")
yvals = np.arange(len(df_sorted_vi))
# Offset the two methods slightly so their intervals do not overlap.
yvals_vi = yvals + 0.1
yvals_mcmc = yvals - 0.1

for ax, par in zip(axes, params):
    # Columns holding this parameter's summaries and true value
    mean_col = f"{par}_mean"
    lower_col = f"{par}_lower_hdi"
    upper_col = f"{par}_higher_hdi"
    true_col = f"{par}_true"

    # HDIs as horizontal segments from the lower to the upper bound
    ax.hlines(
        yvals_vi,
        df_sorted_vi[lower_col],
        df_sorted_vi[upper_col],
        color="blue",
        alpha=0.5,
    )
    ax.hlines(
        yvals_mcmc,
        df_sorted_mcmc[lower_col],
        df_sorted_mcmc[upper_col],
        color="green",
        alpha=0.5,
    )

    # Posterior means as tick marks on their own interval
    ax.plot(
        df_sorted_vi[mean_col],
        yvals_vi,
        "|",
        color="blue",
        label="Mean_vi" if par == "v" else None,
    )
    ax.plot(
        df_sorted_mcmc[mean_col],
        yvals_mcmc,
        "|",
        color="green",
        label="Mean_mcmc" if par == "v" else None,
    )

    # True values (when present) as red crosses on the unshifted row
    if true_col in df_sorted_vi.columns:
        ax.plot(
            df_sorted_vi[true_col],
            yvals,
            "x",
            color="red",
            label="True_vi" if par == "v" else None,
        )

    ax.set_title(par)
    ax.set_yticks(yvals)
    ax.set_yticklabels(df_sorted_vi["participant_id"])
    if par == "v":
        ax.legend()

# The y-axis is shared (sharey=True), so inverting one panel inverts all of
# them. Inverting inside the loop would toggle the orientation once per panel
# and cancel out for an even number of panels; invert exactly once instead so
# the first participant ends up at the top.
axes[0].invert_yaxis()
plt.tight_layout()
plt.show()
Here, results match up quite nicely between the two approaches. Overall the posteriors are very similar, and we don't see a strong tendency for VI posteriors to be more peaked than MCMC posteriors. However this is just an example, and we can't deduce a general rule from this observation.
Hierarchical Model: Non-Centered Parameterization¶
To highlight that once we have the trial wise parameters, we can easily generate subject (and/or condition) wise plots of the posteriors, we also show an example using the non-centered parameterization.
The non-centered parameterization is slightly more complex, so manually recomposing parameters can sometimes be a little bit more confusing.
We can let HSSM handle this via Bambi under the hood, and not worry about this complexity.
# Generative model (non-centered parameterization).
# Every DDM parameter uses the same regression: a common intercept plus a
# participant-level term. No explicit priors are given here, so the "safe"
# prior settings apply.
nc_param_names = ("v", "a", "z", "t")
mSimNonCentered = hssm.HSSM(
    data=simDataDDM,
    model="ddm",
    loglik_kind="approx_differentiable",
    noncentered=True,
    prior_settings="safe",
    p_outlier=0.01,
    include=[
        {"name": name, "formula": f"{name} ~ 1 + (1|participant_id)"}
        for name in nc_param_names
    ],
)
Model initialized successfully.
Fit VI¶
# VI: fit the non-centered model with ADVI for 30,000 iterations, optimizing
# with Adamax at a learning rate of 0.01.
# Alternative optimizer setting kept for reference:
# obj_optimizer=pm.adamax(learning_rate=0.1)
vi_idata_NC = mSimNonCentered.vi(
    niter=30000,
    method="advi",  # "fullrank_advi" is the other method named in the text
    obj_optimizer=pm.adamax(learning_rate=0.01),
)
# Draw 1000 samples from the fitted approximation. In the code shown, these
# draws (not `vi_idata_NC`) are what the plots below use.
mSimNonCenteredVIObject = mSimNonCentered.vi_approx.sample(draws=1000)
Using MCMC starting point defaults.
Output()
Finished [100%]: Average Loss = 3,374.1
# Convergence check: plot the loss recorded at each VI iteration.
loss_history_nc = mSimNonCentered.vi_approx.hist
plt.plot(loss_history_nc)
plt.title("VI iteration loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
Fit MCMC¶
# MCMC: sample the non-centered model with the "nuts_numpyro" sampler, using
# 4 chains with 750 tuning steps and 250 retained draws each.
mSimNonCenteredSampled = mSimNonCentered.sample(
    sampler="nuts_numpyro",
    cores=4,
    chains=4,
    draws=250,
    tune=750,
    mp_ctx="forkserver",
    # Presumably forwarded to the NUTS kernel to cap the tree depth at 5
    # -- TODO confirm it takes effect with this sampler.
    nuts_kwargs={"max_tree_depth": 5},
)
# Generate posterior predictive samples for the MCMC draws. The return value
# is discarded here; presumably the samples are stored on the passed idata
# -- TODO confirm.
mSimNonCentered.sample_posterior_predictive(idata=mSimNonCenteredSampled)
Using default initvals.
sample: 100%|██████████| 1000/1000 [03:15<00:00, 5.12it/s, 63 steps of size 9.30e-02. acc. prob=0.93] sample: 100%|██████████| 1000/1000 [03:17<00:00, 5.07it/s, 31 steps of size 9.89e-02. acc. prob=0.94] sample: 100%|██████████| 1000/1000 [03:12<00:00, 5.19it/s, 31 steps of size 9.46e-02. acc. prob=0.95] sample: 100%|██████████| 1000/1000 [03:30<00:00, 4.76it/s, 63 steps of size 8.93e-02. acc. prob=0.94] The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details 100%|██████████| 1000/1000 [00:02<00:00, 455.30it/s]
Post-process VI / MCMC results¶
# Add trialwise parameters to the MCMC idata so that each model parameter is
# available per observation for the summaries computed below.
mSimNonCenteredSampled = mSimNonCentered.add_likelihood_parameters_to_idata(
    mSimNonCenteredSampled
)
# Per-participant posterior means / HDIs joined with the true values (VI).
# NOTE(review): the VI draws are not passed through
# add_likelihood_parameters_to_idata; presumably they already contain the
# trialwise parameters -- confirm.
plot_df_vi_nc = process_idata_for_plotting(
    idata=mSimNonCenteredVIObject, parameter_matrix=parameter_matrix, model="ddm"
)
# Same summary for the MCMC draws.
plot_df_mcmc_nc = process_idata_for_plotting(
    idata=mSimNonCenteredSampled, parameter_matrix=parameter_matrix, model="ddm"
)
Plotting¶
# Posterior predictive plot for the non-centered model, one panel per
# participant, wrapped at 5 panels per row.
hssm.plotting.plot_posterior_predictive(
    mSimNonCentered, col="participant_id", col_wrap=5
)
<seaborn.axisgrid.FacetGrid at 0x2dc4f3f10>
# Parameter-recovery plot (non-centered model): for every DDM parameter, show
# each participant's posterior mean and HDI from VI (blue) and MCMC (green)
# next to the true generating value (red cross).
params = hssm.defaults.default_model_config["ddm"]["list_params"]
fig, axes = plt.subplots(nrows=1, ncols=len(params), figsize=(14, 6), sharey=True)
# plt.subplots returns a bare Axes for a single column; normalize to an array
# so the zip below works for any number of parameters.
axes = np.atleast_1d(axes)

# Participant ordering and y positions do not depend on the parameter, so
# they are computed once instead of on every loop iteration.
df_sorted_vi_nc = plot_df_vi_nc.sort_values("participant_id")
df_sorted_mcmc_nc = plot_df_mcmc_nc.sort_values("participant_id")
yvals = np.arange(len(df_sorted_vi_nc))
# Offset the two methods slightly so their intervals do not overlap.
yvals_vi = yvals + 0.1
yvals_mcmc = yvals - 0.1

for ax, par in zip(axes, params):
    # Columns holding this parameter's summaries and true value
    mean_col = f"{par}_mean"
    lower_col = f"{par}_lower_hdi"
    upper_col = f"{par}_higher_hdi"
    true_col = f"{par}_true"

    # HDIs as horizontal segments from the lower to the upper bound
    ax.hlines(
        yvals_vi,
        df_sorted_vi_nc[lower_col],
        df_sorted_vi_nc[upper_col],
        color="blue",
        alpha=0.5,
    )
    ax.hlines(
        yvals_mcmc,
        df_sorted_mcmc_nc[lower_col],
        df_sorted_mcmc_nc[upper_col],
        color="green",
        alpha=0.5,
    )

    # Posterior means as tick marks. The VI means use the VI offset
    # (yvals_vi) so each marker sits on its own HDI segment, matching how
    # the MCMC means are drawn.
    ax.plot(
        df_sorted_vi_nc[mean_col],
        yvals_vi,
        "|",
        color="blue",
        label="Mean_vi" if par == "v" else None,
    )
    ax.plot(
        df_sorted_mcmc_nc[mean_col],
        yvals_mcmc,
        "|",
        color="green",
        label="Mean_mcmc" if par == "v" else None,
    )

    # True values (when present) as red crosses on the unshifted row
    if true_col in df_sorted_vi_nc.columns:
        ax.plot(
            df_sorted_vi_nc[true_col],
            yvals,
            "x",
            color="red",
            label="True_vi" if par == "v" else None,
        )

    ax.set_title(par)
    ax.set_yticks(yvals)
    ax.set_yticklabels(df_sorted_vi_nc["participant_id"])
    if par == "v":
        ax.legend()

# The y-axis is shared (sharey=True), so inverting one panel inverts all of
# them. Inverting inside the loop would toggle the orientation once per panel
# and cancel out for an even number of panels; invert exactly once instead so
# the first participant ends up at the top.
axes[0].invert_yaxis()
plt.tight_layout()
plt.show()
Overall, the posteriors again look very similar between the two approaches. However, we can see that the MCMC posteriors are more variable than the VI posteriors in this case. This illustrates a case where the VI posteriors are too confident, and Bayesian t-tests on parameter differences may yield different results when using the MCMC vs. the VI posteriors. From a practical perspective, it is sometimes simply infeasible to run MCMC, while VI is still computationally tractable. In such cases we have to take what we can get...