Tutorial: hierarchical Bayesian inference for Reinforcement Learning - Sequential Sampling Models¶
This is a (preview) tutorial for using the HSSM Python package to simultaneously estimate reinforcement learning parameters and decision parameters within a fully hierarchical Bayesian estimation framework. It walks through the steps for constructing HSSM-compatible likelihoods/distributions and sampling from the posterior, and it shows plots for assessing the recovery of model parameters.
This tutorial uses the reinforcement learning sequential sampling model (RLSSM), a reinforcement learning model that replaces the standard “softmax” choice function with a drift diffusion process with collapsing bounds (referred to as the 'angle' decision model hereafter). The softmax and the sequential sampling process are equivalent for capturing choice proportions, but the angle model also accounts for RT distributions. The RLSSM computes the trial-by-trial drift rate as a scaled difference in expected rewards (expected reward for the upper-bound alternative minus expected reward for the lower-bound alternative). Expected rewards are updated with a delta learning rule using a single learning rate. The model also includes the standard angle parameters, such as the rate of collapse of the decision boundaries. The broader RLSSM framework is described in detail in Pedersen, Frank & Biele (2017) and Fengler, Bera, Pedersen & Frank (2022).
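To make the generative process concrete, the following is a minimal sketch (plain NumPy, not HSSM internals; the names `q`, `rl_alpha`, and `scaler` are illustrative) of the trial-by-trial computation described above.
# Illustrative sketch of the RLSSM generative core (not HSSM code)
import numpy as np
def update_trial(q, choice, reward, rl_alpha, scaler):
    """One trial of the RLSSM value/drift computation (illustrative).
    q        : length-2 array of expected rewards (index 1 = upper-bound option)
    choice   : 0 (lower-bound option) or 1 (upper-bound option)
    reward   : feedback obtained on this trial
    rl_alpha : learning rate of the delta rule
    scaler   : maps the value difference onto the drift rate
    """
    # Drift rate: scaled difference in expected rewards (upper minus lower)
    v = scaler * (q[1] - q[0])
    # Delta rule: move the chosen option's value toward the obtained reward
    q[choice] += rl_alpha * (reward - q[choice])
    return v, q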
# Import necessary libraries
import numpy as np
import arviz as az
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from functools import partial
from scipy.stats import spearmanr
# Import HSSM and simulator package
import hssm
from hssm.utils import decorate_atomic_simulator
from hssm.likelihoods.rldm import make_rldm_logp_op
from hssm.distribution_utils.dist import make_hssm_rv
from ssms.basic_simulators.simulator import simulator
# Set the style for the plots
plt.style.use('seaborn-v0_8-dark-palette')
Load and prepare the demo dataset¶
This data file contains (synthetic) data from a simulated 2-armed bandit task. The dataset contains the typical columns expected from a canonical instrumental learning task: `participant_id` identifies the subject, `trial` indexes the sequence of trials within each subject's data, `response` and `rt` are the data columns recorded on each trial, `feedback` is the reward obtained on a given trial, and `correct` records whether the response was correct.
# Load synthetic RLSSM dataset containing both behavioral data and ground truth parameters
savefile = np.load("../../tests/fixtures/rldm_data.npy", allow_pickle=True).item()
dataset = savefile['data']
# Rename trial column to match HSSM conventions
dataset.rename(columns={'trial': 'trial_id'}, inplace=True)
# Examine the dataset structure
dataset.head()
|   | participant_id | trial_id | response | rt | feedback | correct |
|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0.0 | 0.935602 | 0.126686 | 0.0 |
| 1 | 0 | 1 | 0.0 | 1.114379 | 0.173100 | 0.0 |
| 2 | 0 | 2 | 0.0 | 0.564311 | 0.444935 | 0.0 |
| 3 | 0 | 3 | 0.0 | 2.885860 | 0.307207 | 0.0 |
| 4 | 0 | 4 | 0.0 | 0.532113 | 0.177911 | 0.0 |
# Validate data structure and extract dataset configuration
dataset, n_participants, n_trials = hssm.check_data_for_rl(dataset)
print(f"Number of participants: {n_participants}")
print(f"Number of trials: {n_trials}")
Number of participants: 20
Number of trials: 200
Construct an HSSM-compatible PyMC distribution from a simulator and a JAX likelihood callable¶
We now construct a custom model that is compatible with HSSM and PyMC. Note that HSSM internally constructs a PyMC model (which is used for sampling) from the user-specified HSSM model. In other words, we are peeling back the abstraction layers conveniently afforded by HSSM in order to use its core machinery directly. This advanced HSSM tutorial explains how to use HSSM when starting from the very basics of a model -- a simulator and a JAX likelihood callable.
The simulator function is used for generating samples from the model (for posterior predictives, etc.), while the likelihood callable is employed for sampling/inference. This preview tutorial showcases the key flexibility of HSSM for fitting RLSSM models; the remainder of the tutorial therefore focuses only on the sampling/inference aspect, and we create a dummy simulator function to bypass the need for defining the actual simulator.
Step 1: Define a PyTensor RandomVariable¶
# Define parameters for the RLSSM model (RL + decision model parameters)
list_params = ['rl.alpha', 'scaler', 'a', 'z', 't', 'theta']
# Create a dummy simulator for generating synthetic data (used for posterior predictives)
# This bypasses the need for a full RLSSM simulator implementation
def create_dummy_simulator():
"""Create a dummy simulator function for RLSSM model."""
def sim_wrapper(simulator_fun, theta, model, n_samples, random_state, **kwargs):
# Generate random RT and choice data as placeholders
sim_rt = np.random.uniform(0.2, 0.6, n_samples)
sim_ch = np.random.randint(0, 2, n_samples)
return np.column_stack([sim_rt, sim_ch])
# Wrap the simulator function with required metadata
wrapped_simulator = partial(sim_wrapper, simulator_fun=simulator, model="custom", n_samples=1)
# Decorate the simulator to make it compatible with HSSM
return decorate_atomic_simulator(model_name="custom", choices=[0, 1], obs_dim=2)(wrapped_simulator)
# Create the simulator and RandomVariable
decorated_simulator = create_dummy_simulator()
# Create a PyTensor RandomVariable using `make_hssm_rv` for use in the PyMC model
CustomRV = make_hssm_rv(
simulator_fun=decorated_simulator, list_params=list_params
)
Step 2: Define a likelihood function¶
# Create a PyTensor Op for the likelihood function.
# The `make_rldm_logp_op` utility wraps the base JAX likelihood function into an HSSM/PyMC-compatible callable.
logp_jax_op = make_rldm_logp_op(
n_participants=n_participants,
n_trials=n_trials,
n_params=6 # the number of parameters in the RLSSM model. Should match the length of `list_params`
)
# Test the likelihood function by evaluating it with dummy parameters and data extracted from the dataset.
def extract_data_columns(dataset):
"""Extract required data columns from dataset."""
return {
'participant_id': dataset["participant_id"].values,
'trial': dataset["trial_id"].values,
'response': dataset["response"].values,
'feedback': dataset["feedback"].values,
'rt': dataset["rt"].values
}
def create_test_parameters(n_trials):
"""Create dummy parameters for testing the likelihood function."""
return {
'rl_alpha': np.ones(n_trials) * 0.60,
'scaler': np.ones(n_trials) * 3.2,
'a': np.ones(n_trials) * 1.2,
'z': np.ones(n_trials) * 0.1,
't': np.ones(n_trials) * 0.1,
'theta': np.ones(n_trials) * 0.1
}
# Extract data and create test parameters
data_columns = extract_data_columns(dataset)
num_subj = len(np.unique(data_columns['participant_id']))
n_trials_total = num_subj * n_trials  # total number of rows; avoids hardcoding the trial count
test_params = create_test_parameters(n_trials_total)
# Evaluate the likelihood function
test_logp_out = logp_jax_op(
np.column_stack((data_columns['rt'], data_columns['response'])),
test_params['rl_alpha'],
test_params['scaler'],
test_params['a'],
test_params['z'],
test_params['t'],
test_params['theta'],
data_columns['participant_id'],
data_columns['trial'],
data_columns['feedback'],
)
LL = test_logp_out.eval()
print(f"Log likelihood: {np.sum(LL):.4f}")
Log likelihood: -6879.1526
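Because NUTS requires gradients, it can also be worth confirming that the likelihood Op is differentiable with respect to the model parameters. The following sketch assumes that the Op produced by `make_rldm_logp_op` defines gradients for its parameter inputs (expected, given the JAX backend); here we differentiate the summed log-likelihood with respect to the boundary-separation vector.
import pytensor
import pytensor.tensor as pt
# Treat `a` as a symbolic vector so the gradient can flow through the Op
a_var = pt.dvector("a")
ll = logp_jax_op(
    np.column_stack((data_columns['rt'], data_columns['response'])),
    test_params['rl_alpha'],
    test_params['scaler'],
    a_var,
    test_params['z'],
    test_params['t'],
    test_params['theta'],
    data_columns['participant_id'],
    data_columns['trial'],
    data_columns['feedback'],
)
grad_a = pytensor.grad(ll.sum(), wrt=a_var)
print(grad_a.eval({a_var: test_params['a']}).shape)  # one gradient entry per trial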
Step 3: Define a model config and HSSM model¶
# Configure the HSSM model
model_config = hssm.ModelConfig(
response=["rt", "response"], # Dependent variables (RT and choice)
list_params=['rl.alpha', 'scaler', 'a', 'z', 't', 'theta'], # List of model parameters
choices=[0, 1], # Possible choice options
default_priors={}, # Use custom priors (defined below)
bounds=dict( # Bounds for each model parameter
rl_alpha=(0.01, 1), # Learning rate bounds
scaler=(1, 4), # Scaler bounds
a=(0.3, 2.5), # Boundary separation bounds
z=(0.1, 0.9), # Bias bounds
t=(0.1, 2.0), # Non-decision time bounds
theta=(0.0, 1.2) # Collapse rate bounds
),
rv=CustomRV, # Custom RandomVariable that we created earlier
extra_fields=[ # Additional data columns to be passed to the likelihood function as extra_fields
"participant_id",
"trial_id",
"feedback"],
backend="jax" # Use JAX for computation
)
# Create a hierarchical HSSM model with custom likelihood function
hssm_model = hssm.HSSM(
data=dataset, # Input dataset
model_config=model_config, # Model configuration
p_outlier=0, # No outlier modeling
lapse=None, # No lapse rate modeling
loglik=logp_jax_op, # Custom RLDM likelihood function
loglik_kind="approx_differentiable", # Use approximate gradients
noncentered=True, # Use non-centered parameterization
process_initvals=False, # Skip initial value processing in HSSM
include=[
# Define hierarchical priors: group-level intercepts + subject-level random effects
hssm.Param("rl.alpha",
formula="rl_alpha ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.01, upper=1, mu=0.3)}),
hssm.Param("scaler",
formula="scaler ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=1, upper=4, mu=1.5)}),
hssm.Param("a",
formula="a ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.3, upper=2.5, mu=1.0)}),
hssm.Param("z",
formula="z ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.1, upper=0.9, mu=0.2)}),
hssm.Param("t",
formula="t ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.01, upper=2, mu=0.2, initval=0.1)}),
hssm.Param("theta",
formula="theta ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.00, upper=1.2, mu=0.3)}),
]
)
No common intercept. Bounds for parameter rl.alpha is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter scaler is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter a is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter z is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter t is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter theta is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Sample using NUTS MCMC¶
# Run MCMC sampling using NUTS sampler with JAX backend
# Note: using a small number of samples for demonstration (increase for real analyses)
idata_mcmc = hssm_model.sample(
sampler='nuts_numpyro', # JAX-based NUTS sampler for efficiency
chains=1, # Number of parallel chains
draws=1000, # Number of posterior samples
tune=1000, # Number of tuning/warmup samples
)
Using default initvals.
sample: 100%|██████████| 2000/2000 [08:03<00:00, 4.14it/s, 31 steps of size 1.50e-01. acc. prob=0.91]
Only one chain was sampled, this makes it impossible to run some convergence checks
FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
100%|██████████| 1000/1000 [00:02<00:00, 374.33it/s]
idata_mcmc
The returned `InferenceData` object (sampled with numpyro; sampling time roughly 487 s) contains four groups: `posterior` (1 chain × 1000 draws of the group-level intercepts, participant-level offsets, and group standard deviations for all six parameters), `log_likelihood` (pointwise log-likelihoods for the 4000 observations), `sample_stats` (acceptance rates, step sizes, tree depths, divergences, and energies), and `observed_data` (the rt/response pairs).
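Before turning to parameter recovery, a quick look at standard sampler diagnostics is worthwhile. A small ArviZ sketch (with a single chain, `r_hat` is uninformative, so we focus on effective sample sizes and divergences):
# Summarize group-level intercepts; filter_vars="like" matches variable names by substring
summ = az.summary(idata_mcmc, var_names=["Intercept"], filter_vars="like")
print(summ[["mean", "sd", "ess_bulk", "ess_tail"]])
# Count divergent transitions recorded by NUTS
print("Divergences:", int(idata_mcmc.sample_stats["diverging"].sum()))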
Assess the model fits¶
We examine the quality of fits by comparing the recovered parameters with the ground-truth data generating parameters of the simulated dataset. We examine the quality of fits both at group-level as well as subject-level.
Examining group-level posteriors¶
# Define parameter names for analysis
list_group_mean_params = [
"rl.alpha_Intercept",
"scaler_Intercept",
"a_Intercept",
"z_Intercept",
"t_Intercept",
"theta_Intercept",
]
list_group_sd_params = [
"rl.alpha_1|participant_id_sigma",
"scaler_1|participant_id_sigma",
"a_1|participant_id_sigma",
"z_1|participant_id_sigma",
"t_1|participant_id_sigma",
"theta_1|participant_id_sigma",
]
# Create mapping from HSSM model parameter names to ground truth values.
def create_ground_truth_mapping(savefile):
return {
"rl.alpha_Intercept": savefile['params_true_group']['rl_alpha_mean'],
"scaler_Intercept": savefile['params_true_group']['scaler_mean'],
"a_Intercept": savefile['params_true_group']['a_mean'],
"z_Intercept": savefile['params_true_group']['z_mean'],
"t_Intercept": savefile['params_true_group']['t_mean'],
"theta_Intercept": savefile['params_true_group']['theta_mean'],
}
ground_truth_params = create_ground_truth_mapping(savefile)
print("Ground truth group means:\n")
for param, value in ground_truth_params.items():
print(f"{param}: {value:.3f}")
Ground truth group means:

rl.alpha_Intercept: 0.660
scaler_Intercept: 2.785
a_Intercept: 1.481
z_Intercept: 0.267
t_Intercept: 0.443
theta_Intercept: 0.312
# Plot posterior distributions and MCMC traces for group-level parameters
# Vertical lines show ground truth values for parameter recovery assessment
az.plot_trace(idata_mcmc, var_names=list_group_mean_params,
lines=[(key_, {}, ground_truth_params[key_]) for key_ in ground_truth_params])
plt.tight_layout()
# plot the posterior pair plots of the group-level parameters
# this will show the joint distributions and correlations between the group-level parameters.
az.plot_pair(idata_mcmc, var_names=list_group_mean_params, kind='kde', point_estimate='mean')
plt.tight_layout()
Examining participant-level posteriors¶
# Extract ground truth subject-level parameters from the synthetic dataset
# Reshape from dictionary format to matrix (subjects x parameters)
sim_param_list = np.zeros((n_participants, len(list_params)))  # avoid hardcoding dimensions
for ind, p in enumerate(savefile['params_true_subj'].keys()):
sim_param_list[:, ind] = savefile['params_true_subj'][p]
# Function to extract subject-level parameters from inference data.
def extract_subject_parameters(idata, param_names):
n_subjects = idata.posterior[f'{param_names[0]}_1|participant_id'].shape[-1]
n_params = len(param_names)
subject_params = np.zeros((n_subjects, n_params))
for i, param in enumerate(param_names):
intercept = np.mean(idata.posterior[f'{param}_Intercept'].values[0])
random_effects = np.mean(idata.posterior[f'{param}_1|participant_id'].values[0], axis=0)
subject_params[:, i] = intercept + random_effects
return subject_params
# Extract recovered parameters
recov_param_list = extract_subject_parameters(idata_mcmc, model_config.list_params)
plot_param_ranges = [[0.5, 0.9], [2.2, 3.5], [1.2, 2], [0, 0.4], [0.3, 0.7], [0.2, 0.5]]
plot_param_names = ["rl.alpha", "scaler", "a", "z", "t", "theta"]
# Function to create parameter recovery plots comparing simulated vs recovered values
def plot_parameter_recovery(sim_params, recov_params, param_names, axes_limits,
additional_data=None, additional_label=None, show_correlation=False):
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
for i, ax in enumerate(axes):
if i >= len(param_names):
ax.set_visible(False)
continue
x, y = sim_params[:, i], recov_params[:, i]
# Scatter plot showing parameter recovery
ax.scatter(x, y, alpha=0.6, label='MCMC')
# Additional data if provided
if additional_data is not None:
z = additional_data[:, i]
ax.scatter(x, z, alpha=0.6, marker='x', label=additional_label or 'Additional')
ax.legend(loc='lower right')
# Calculate and display correlation between true and recovered parameters
if show_correlation:
spearman_r, _ = spearmanr(x, y)
ax.text(0.05, 0.93, f"R: {spearman_r:.2f}", transform=ax.transAxes,
fontsize=12, verticalalignment='bottom')
# Formatting subplot
ax.set_title(param_names[i], fontsize=16)
ax.set_xlim(axes_limits[i])
ax.set_ylim(axes_limits[i])
ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=5))
ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=5))
ax.grid(True, linestyle='--', alpha=0.6)
ax.axline((0, 0), linestyle='--', slope=1, c='k', alpha=0.8) # Perfect recovery line
# Add axis labels
fig.text(0.5, 0.02, 'Simulated', ha='center', fontsize=20)
fig.text(0.02, 0.5, 'Recovered', va='center', rotation='vertical', fontsize=20)
plt.tight_layout(rect=[0.05, 0.05, 1, 1])
return fig
# Plot parameter recovery
plot_parameter_recovery(sim_param_list, recov_param_list, plot_param_names, plot_param_ranges, show_correlation=True)
plt.show()
Estimating the posterior using Variational Inference (VI)¶
# Run variational inference (VI) as a faster alternative to MCMC
# VI approximates the posterior with a simpler distribution family
idata_vi = hssm_model.vi(
niter=30000, # Number of optimization iterations
draws=1000, # Number of samples from approximate posterior
method="advi" # Automatic Differentiation Variational Inference
)
Using MCMC starting point defaults.
Finished [100%]: Average Loss = 3,948
We now examine the VI loss over iterations. In general, this looks good, with the caveat that there are oscillations during the initial iterations. In principle, these could arise from the model geometry, the priors, or simply an aggressive learning rate. We recommend that users experiment with different settings to find what works best for their model; one such tweak is sketched after the loss plot below.
plt.plot(hssm_model.vi_approx.hist)
plt.xlabel("Iteration")
plt.ylabel("Loss")
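As one example of such tuning, the sketch below re-runs ADVI with an explicit Adam optimizer and a small learning rate to damp the early oscillations. This assumes that `HSSM.vi()` forwards extra keyword arguments to `pymc.fit`; `pm.adam` and its `learning_rate` argument are standard PyMC, but whether this helps is model-dependent.
import pymc as pm
# Hedged sketch: a gentler optimizer for ADVI (assumes extra kwargs reach `pymc.fit`)
idata_vi_tuned = hssm_model.vi(
    niter=30000,
    draws=1000,
    method="advi",
    obj_optimizer=pm.adam(learning_rate=1e-3),
)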
idata_vi
The VI `InferenceData` contains a `posterior` group (1 chain × 1000 draws from the fitted approximation, with the same variables as the MCMC posterior) and an `observed_data` group.
# Extract VI recovered parameters
recov_param_list_vi = extract_subject_parameters(idata_vi, model_config.list_params)
# Plot comparison between MCMC and VI recovery
plot_parameter_recovery(
sim_param_list,
recov_param_list,
plot_param_names,
plot_param_ranges,
additional_data=recov_param_list_vi,
additional_label='VI',
show_correlation=False
)
plt.show()
Compare MCMC and VI parameter recoveries¶
def plot_posterior_comparison(idata_list, labels, var_names, title_suffix=""):
"""Plot posterior density comparisons between different inference methods."""
axes = az.plot_density(
idata_list,
data_labels=labels,
var_names=var_names,
hdi_prob=1.0,
shade=0.2,
)
fig = axes.flatten()[0].get_figure()
if title_suffix:
fig.suptitle(f"Posterior Comparison - {title_suffix}", y=1.02)
plt.tight_layout()
return fig
# Compare group-level means
plot_posterior_comparison(
[idata_mcmc, idata_vi],
["NUTS", "VI"],
list_group_mean_params,
"Group Means"
)
plt.show()
# Compare group-level standard deviations
plot_posterior_comparison(
[idata_mcmc, idata_vi],
["NUTS", "VI"],
list_group_sd_params,
"Group Standard Deviations"
)
plt.show()
While the VI and MCMC posteriors largely agree with each other, there are notable discrepancies for some parameters, such as 't'. This is partly because we use a mean-field approximation for VI, which discards the covariance structure present in the model (e.g., we also see that the VI posteriors are more peaked). While this tutorial is meant to serve as a proof of concept for the RLSSM paradigm, we encourage users to leverage HSSM's flexibility by tweaking and fine-tuning various settings, including the choice of Bayesian inference method (say, VI vs. MCMC), for robust results.
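One concrete option along these lines is a full-rank approximation, which retains posterior correlations at additional computational cost. A hedged sketch, assuming the `method` string is forwarded to `pymc.fit` (where `"fullrank_advi"` is a standard option):
# Full-rank ADVI fits a full covariance matrix over the (transformed) parameters
# instead of the mean-field diagonal; slower, but it can mitigate the over-peaked
# posteriors seen above. A sketch, not a tuned configuration.
idata_vi_fullrank = hssm_model.vi(
    niter=30000,
    draws=1000,
    method="fullrank_advi",
)
# Compare against the NUTS posterior as before
plot_posterior_comparison(
    [idata_mcmc, idata_vi_fullrank],
    ["NUTS", "Full-rank VI"],
    list_group_mean_params,
    "Group Means (full-rank VI)"
)
plt.show()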