import matplotlib.pyplot as plt
import numpy as np
import hssm
import arviz as az
import pymc as pm
Load Data and Specify Model
cav_data = hssm.load_data("cavanagh_theta")
cav_model = hssm.HSSM(data=cav_data, model="angle")
Model initialized successfully.
Inference
We will run MCMC and VI here to contrast results.
Run MCMC
# Slight adjustment to initial values
initvals_tmp = cav_model.initvals
initvals_tmp["theta"] = 0.1

mcmc_idata = cav_model.sample(
    chains=2,
    tune=500,
    draws=500,
    sampler="nuts_numpyro",
    initvals=initvals_tmp,
)
Using MCMC starting point defaults.
UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially.
sample: 100%|██████████| 1000/1000 [01:07<00:00, 14.89it/s, 15 steps of size 2.43e-01. acc. prob=0.95]
sample: 100%|██████████| 1000/1000 [02:53<00:00, 5.76it/s, 31 steps of size 2.60e-01. acc. prob=0.95]
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.summary(mcmc_idata)
Run VI
vi_idata = cav_model.vi(niter=100000, method="fullrank_advi")
Finished [100%]: Average Loss = 6,035.9
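The method argument selects the variational approximation, and other options can be requested the same way. A minimal sketch, assuming the same .vi() signature (this call is not run in the tutorial):

# Hypothetical alternative call: a mean-field (diagonal-Gaussian) fit
# instead of the full-rank approximation used above.
vi_idata_meanfield = cav_model.vi(niter=100000, method="advi")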
Inspect Outputs
From our variational inference run, we extract two objects:

- An `az.InferenceData` object stored under `cav_model.vi_idata`. This stores a slightly cleaned-up posterior sample, constructed by sampling from the variational posterior.
- A `pm.Approximation` object stored under `cav_model.vi_approx`, which holds the variational posterior itself. This is a rich structure, and it is beyond the scope of this tutorial to illustrate all its details. Among other things, you can inspect the loss history and take samples such as those stored under `cav_model.vi_idata`.
.vi_idata

The approximate variational posterior `InferenceData`.
cav_model.vi_idata
-
<xarray.Dataset> Size: 48kB Dimensions: (chain: 1, draw: 1000) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: z (chain, draw) float64 8kB 0.5006 0.5039 0.4967 ... 0.5003 0.5091 theta (chain, draw) float64 8kB 0.2192 0.1891 0.2766 ... 0.2418 0.2402 a (chain, draw) float64 8kB 1.297 1.244 1.369 ... 1.332 1.329 1.298 t (chain, draw) float64 8kB 0.2789 0.3001 0.2639 ... 0.2763 0.2875 v (chain, draw) float64 8kB 0.3792 0.3556 0.3746 ... 0.3839 0.3445 Attributes: created_at: 2025-07-12T18:53:48.091598+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.21.1
-
<xarray.Dataset> Size: 96kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 64kB ... Attributes: created_at: 2025-07-12T18:53:48.113860+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.21.1
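Because `vi_idata` is a standard `InferenceData` object, the usual ArviZ functions apply to it directly; for example, a quick summary (usage sketch):

# Posterior summaries computed from draws off the variational posterior
az.summary(cav_model.vi_idata)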
.vi_approx

The approximate variational posterior `pm.Approximation` object.

We can take draws from the posterior with its `.sample()` method.
cav_model.vi_approx.sample(draws=1000)
-
<xarray.Dataset> Size: 32MB Dimensions: (chain: 1, draw: 1000, __obs__: 3988) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 32kB 0 1 2 3 4 5 ... 3982 3983 3984 3985 3986 3987 Data variables: z (chain, draw) float64 8kB 0.4993 0.5066 0.5059 ... 0.5043 0.5122 theta (chain, draw) float64 8kB 0.2276 0.2191 0.2268 ... 0.2458 0.225 a (chain, draw) float64 8kB 1.303 1.311 1.323 ... 1.289 1.327 1.302 t (chain, draw) float64 8kB 0.2861 0.2744 0.2773 ... 0.2742 0.2914 v (chain, draw) float64 8kB 0.4355 0.3922 0.3639 ... 0.3926 0.3537 v_mean (chain, draw, __obs__) float64 32MB 0.4355 0.4355 ... 0.3537 0.3537 Attributes: created_at: 2025-07-12T18:58:22.117059+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.21.1
-
<xarray.Dataset> Size: 96kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 64kB ... Attributes: created_at: 2025-07-12T18:58:22.128137+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.21.1
The `.hist` attribute stores the loss history. We can plot this to see how the loss function converged.
plt.plot(cav_model.vi_approx.hist)
plt.xlabel("Iteration")
plt.ylabel("Loss")
Text(0, 0.5, 'Loss')
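The early transient dominates this plot, so it can help to zoom in on the tail of the history when judging convergence. A small sketch (the window size is an arbitrary choice):

# Plot only the last part of the loss history
plt.figure()
plt.plot(cav_model.vi_approx.hist[-20000:])
plt.xlabel("Iteration (last 20,000)")
plt.ylabel("Loss")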
Contrast outputs between MCMC and VI
__, axes = plt.subplots(4, 4, figsize=(10, 5))
az.plot_pair(cav_model.traces, ax=axes, scatter_kwargs=dict(alpha=0.01, color="blue"))
az.plot_pair(cav_model.vi_idata, ax=axes, scatter_kwargs=dict(alpha=0.04, color="red"))
array([[<Axes: ylabel='theta'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='a'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='t'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: xlabel='z', ylabel='v'>, <Axes: xlabel='theta'>, <Axes: xlabel='a'>, <Axes: xlabel='t'>]], dtype=object)
cav_model.traces
-
<xarray.Dataset> Size: 44kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (chain, draw) float64 8kB 0.5132 0.5113 0.5033 ... 0.4974 0.5105 theta (chain, draw) float64 8kB 0.2644 0.2334 0.2371 ... 0.2234 0.2221 a (chain, draw) float64 8kB 1.364 1.324 1.328 ... 1.293 1.309 1.301 t (chain, draw) float64 8kB 0.2687 0.2739 0.2702 ... 0.278 0.2891 v (chain, draw) float64 8kB 0.3367 0.3555 0.3566 ... 0.3631 0.3506 Attributes: created_at: 2025-07-12T18:57:53.102939+00:00 arviz_version: 0.21.0 inference_library: numpyro inference_library_version: 0.17.0 sampling_time: 243.820413 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 32MB Dimensions: (chain: 2, draw: 500, __obs__: 3988) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 32kB 0 1 2 3 4 5 ... 3983 3984 3985 3986 3987 Data variables: rt,response (chain, draw, __obs__) float64 32MB -0.889 -1.258 ... -0.9309 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 53kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: acceptance_rate (chain, draw) float64 8kB 0.9639 0.9353 ... 0.8925 1.0 step_size (chain, draw) float64 8kB 0.2428 0.2428 ... 0.2598 0.2598 diverging (chain, draw) bool 1kB False False False ... False False energy (chain, draw) float64 8kB 6.031e+03 6.03e+03 ... 6.024e+03 n_steps (chain, draw) int64 8kB 7 15 15 31 31 15 ... 23 7 7 3 15 31 tree_depth (chain, draw) int64 8kB 3 4 4 5 5 4 3 3 ... 4 3 5 3 3 2 4 5 lp (chain, draw) float64 8kB 6.026e+03 6.023e+03 ... 6.022e+03 Attributes: created_at: 2025-07-12T18:57:53.114168+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 96kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 64kB ... Attributes: created_at: 2025-07-12T18:57:53.115798+00:00 arviz_version: 0.21.0 inference_library: numpyro inference_library_version: 0.17.0 sampling_time: 243.820413 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
Further Reading

We suggest checking out the documentation on the VI API in PyMC for the full, glorious details of the capabilities we have access to.
Working directly through PyMC

Here we illustrate how to use our attached `pymc_model` to make use of PyMC's object-oriented API for variational inference. This gives us a few extra affordances.
Let's define a few helper functions first.
import warnings
import xarray as xr
from matplotlib import gridspec
from pymc.blocking import DictToArrayBijection, RaveledVars
def tracker_to_idata(tracker, model):
    """Turn a tracker object into an InferenceData object."""
    tracker_groups = list(tracker.whatchdict.keys())
    # n_steps = len(tracker[tracker_groups[0]])

    stacked_results = {
        tracker_group: {
            key: np.stack([d[key] for d in tracker[tracker_group]])
            for key in tracker[tracker_group][0]
        }
        for tracker_group in tracker_groups
    }

    # coords = {"vi_step": np.arange(n_steps)} | {
    #     k: np.array(v) for k, v in model.coords.items()
    # }

    var_to_dims = {
        var.name: ("vi_step", *(model.named_vars_to_dims.get(var.name, ())))
        for var in model.continuous_value_vars
    }

    datasets = {
        key: xr.Dataset(
            {
                var: (var_to_dims[var], stacked_results[key][var])
                for var in stacked_results[key].keys()
            }
        )
        for key in tracker_groups
    }

    with warnings.catch_warnings(action="ignore"):
        return az.InferenceData(**datasets)
def untransform_params(idata, model):
    """Bring transformed parameters back to their original scale."""
    suffixes = ["_interval__", "_log__"]

    def remove_suffixes(word, suffixes):
        for suffix in suffixes:
            if word.endswith(suffix):
                return word[: -len(suffix)]
        return word

    free_rv_names = [rv_.name for rv_ in model.free_RVs]
    transformed_vars = list(idata.mean.data_vars.keys())

    collect_untransformed_vars = []
    collect_untransformed_xarray_datasets = []
    for var_ in transformed_vars:
        var_untrans = remove_suffixes(var_, suffixes=suffixes)
        if var_untrans in free_rv_names:
            rv = model.free_RVs[free_rv_names.index(var_untrans)]
            if model.rvs_to_transforms[rv] is not None:
                # Apply the backward transform to map values onto the original scale
                untransformed_var = (
                    model.rvs_to_transforms[rv]
                    .backward(idata.mean[var_].values, *rv.owner.inputs)
                    .eval()
                )
                collect_untransformed_vars.append(var_)
                collect_untransformed_xarray_datasets.append(
                    xr.Dataset(
                        data_vars={var_untrans: ("vi_step", untransformed_var)}
                    )
                )

    return xr.merge([idata.mean] + collect_untransformed_xarray_datasets).drop_vars(
        collect_untransformed_vars
    )
def plot_vi_traces(idata):
    """Plot parameter history of the optimization algorithm."""
    if not isinstance(idata, az.InferenceData):
        raise ValueError("idata must be an InferenceData object")
    if "loss" not in idata.groups():
        raise ValueError("InferenceData object must contain a 'loss' group")

    # Fall back to the transformed parameters if no untransformed group is present
    if "mean_untransformed" not in idata.groups():
        print(
            "Using transformed variables because 'mean_untransformed' group not found"
        )
        mean_group = "mean"
    else:
        mean_group = "mean_untransformed"
    data_vars = list(idata[mean_group].data_vars.keys())

    fig = plt.figure(figsize=(8, 1.5 * len(data_vars)))
    gs = gridspec.GridSpec(
        len(data_vars) // 2 + 2
        if (len(data_vars) % 2) == 0
        else (len(data_vars) // 2) + 3,
        2,
    )
    for i, var_ in enumerate(data_vars):
        ax_tmp = fig.add_subplot(gs[i // 2, i % 2])
        idata[mean_group][var_].plot(ax=ax_tmp)
        ax_tmp.set_title(var_)

    # The loss history occupies the bottom rows of the grid
    last_ax = fig.add_subplot(gs[-2:, :])
    idata["loss"].loss.plot(ax=last_ax)
    gs.tight_layout(fig)
    return fig
# Define the ADVI runner
with cav_model.pymc_model:
    advi = pm.ADVI()

# Set up starting point
start = cav_model.pymc_model.initial_point()
vars_dict = {var.name: var for var in cav_model.pymc_model.continuous_value_vars}
x0 = DictToArrayBijection.map(
    {var_name: value for var_name, value in start.items() if var_name in vars_dict}
)

# Define quantities to track
tracker = pm.callbacks.Tracker(
    mean=lambda: DictToArrayBijection.rmap(
        RaveledVars(advi.approx.mean.eval(), x0.point_map_info), start
    ),  # callable that returns mean
    std=lambda: DictToArrayBijection.rmap(
        RaveledVars(advi.approx.std.eval(), x0.point_map_info), start
    ),  # callable that returns std
)

# Run VI fit
approx = advi.fit(n=30000, callbacks=[tracker])
vi_posterior_samples = approx.sample(1000)
vi_posterior_samples.posterior = vi_posterior_samples.posterior.drop_vars("v_mean")
Finished [100%]: Average Loss = 6,037.4
from copy import deepcopy
# Convert tracked quantities to idata
result = tracker_to_idata(tracker, cav_model.pymc_model)
# Add untransformed parameters
result.add_groups(
    {"mean_untransformed": untransform_params(deepcopy(result), cav_model.pymc_model)}
)

# Add loss group
result.add_groups(
    {"loss": xr.Dataset(data_vars={"loss": ("vi_step", np.array(approx.hist))})}
)
A quick look at our `result` InferenceData object, to understand what happened here.

We now have two additional groups:

- `mean_untransformed`, which holds parameter values in the original space (instead of the parameters over which the optimization operates, which always live in an unconstrained space)
- `loss`, which holds our training history
result
-
<xarray.Dataset> Size: 1MB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: z_interval__ (vi_step) float64 240kB 0.001 0.001925 ... 0.02459 0.02467 theta_interval__ (vi_step) float64 240kB -0.001 -0.001948 ... -1.191 -1.191 a_interval__ (vi_step) float64 240kB -0.001 -0.001267 ... -0.5232 t_interval__ (vi_step) float64 240kB -0.001 -0.001748 ... -1.8 -1.8 v_interval__ (vi_step) float64 240kB -0.001 -0.0008689 ... 0.2399 0.24
-
<xarray.Dataset> Size: 1MB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: z_interval__ (vi_step) float64 240kB 0.6926 0.6922 ... 0.0242 0.0242 theta_interval__ (vi_step) float64 240kB 0.6926 0.6931 ... 0.02724 0.02724 a_interval__ (vi_step) float64 240kB 0.6936 0.6941 ... 0.01392 0.01392 t_interval__ (vi_step) float64 240kB 0.6926 0.6931 ... 0.02402 0.02402 v_interval__ (vi_step) float64 240kB 0.6926 0.6927 ... 0.01163 0.01163
-
<xarray.Dataset> Size: 1MB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: z (vi_step) float64 240kB 0.5002 0.5004 0.5004 ... 0.5049 0.5049 theta (vi_step) float64 240kB 0.5997 0.5993 0.5993 ... 0.2263 0.2263 a (vi_step) float64 240kB 1.649 1.649 1.649 ... 1.305 1.305 1.305 t (vi_step) float64 240kB 1.0 0.9996 0.9994 ... 0.2846 0.2846 0.2846 v (vi_step) float64 240kB -0.0015 -0.001303 ... 0.3582 0.3583
-
<xarray.Dataset> Size: 240kB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: loss (vi_step) float64 240kB 1.543e+04 1.035e+04 ... 6.038e+03 6.036e+03
Plot Results

We can plot the parameter trajectories (histories) over optimization steps with our little helper function `plot_vi_traces()`.

NOTE: This is just a single run, and we did not thoroughly check whether the number of steps we allowed the optimizer was indeed enough to converge.
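One way to guard against an under-run optimizer is to attach a convergence callback to the fit. A minimal sketch using PyMC's CheckParametersConvergence callback (the tolerance and diff settings here are illustrative choices, not values used above):

with cav_model.pymc_model:
    advi_checked = pm.ADVI()
    # Stop the fit early once successive parameter means change by less than the tolerance
    approx_checked = advi_checked.fit(
        n=30000,
        callbacks=[
            pm.callbacks.CheckParametersConvergence(diff="absolute", tolerance=1e-3)
        ],
    )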
fig = plot_vi_traces(result)
__, axes = plt.subplots(4, 4, figsize=(10, 5))
# Plot MCMC [nuts]
az.plot_pair(cav_model.traces, ax=axes, scatter_kwargs=dict(alpha=0.01, color="blue"))
# Plot VI via .vi() [fullrank_advi]
az.plot_pair(cav_model.vi_idata, ax=axes, scatter_kwargs=dict(alpha=0.04, color="red"))
# Plot VI via pymc interface [advi]
# (We need to make sure the variables are in correct order)
az.plot_pair(
    vi_posterior_samples.posterior[list(cav_model.traces.posterior.data_vars)],
    ax=axes,
    scatter_kwargs=dict(alpha=0.04, color="green"),
)
array([[<Axes: ylabel='theta'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='a'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='t'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: xlabel='z', ylabel='v'>, <Axes: xlabel='theta'>, <Axes: xlabel='a'>, <Axes: xlabel='t'>]], dtype=object)
NOTE: It is expected that the posterior from our last run looks a little worse. We chose to run the `advi` algorithm, which implies a mean-field (diagonal-covariance) Gaussian approximation to the posterior, so we expect to miss the posterior covariances that we pick up via `fullrank_advi` as well as via MCMC.
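Mean-field approximations also tend to underestimate marginal posterior variances, which is easy to check numerically. A small sketch using the objects created above:

# Marginal posterior SDs: MCMC vs. full-rank ADVI vs. mean-field ADVI
for p in ["v", "a", "z", "t", "theta"]:
    print(
        f"{p:>6}",
        f"MCMC: {float(cav_model.traces.posterior[p].std()):.3f}",
        f"fullrank_advi: {float(cav_model.vi_idata.posterior[p].std()):.3f}",
        f"advi: {float(vi_posterior_samples.posterior[p].std()):.3f}",
    )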
Caveats

Variational inference is powerful; however, it comes with its own set of sharp edges.

- You will not always be in a position to compare VI against MCMC runs (after all, if you can run full MCMC, there isn't much benefit to using VI at all), and it can be hard to estimate a priori how many steps you need to run the algorithm for.
- The posteriors will be approximate: if the true posterior includes complex parameter trade-offs, VI might result in inaccurate posterior estimates.
- We recommend VI for `loglik_kind="approx_differentiable"`, since the gradients of the `analytical` log-likelihoods still prove somewhat brittle at this point in time (see the sketch below).
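For reference, the likelihood kind is chosen at model construction time. A minimal sketch, assuming the `angle` model provides an approximate-differentiable likelihood:

# Explicitly requesting the approximate-differentiable likelihood
cav_model_approx = hssm.HSSM(
    data=cav_data,
    model="angle",
    loglik_kind="approx_differentiable",
)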
Read

To learn a bit more about the VI API in PyMC, we recommend reading the excellent short tutorial in the main documentation.