Plotting in HSSM¶
This tutorial demonstrates the plotting functionalities in HSSM.
While the ArviZ package provides many plotting utilities, HSSM aims to complement ArviZ with additional plot types specific to hierarchical sequential sampling models. In addition, HSSM exposes some plotting APIs directly on the HSSM model object, with a few tweaks for convenience.
Most of the plotting and summary functionality lives in the hssm.plotting module, and additional convenience functions are exposed through the top-level HSSM model object as well.
# If running this on Colab, please uncomment and run the next line:
# !pip install hssm
from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import hssm
import hssm.plotting
%matplotlib inline
%config InlineBackend.figure_format='retina'
Model setup¶
fixtures_dir = Path("../../tests/fixtures")
cav_data_test = pd.read_csv(fixtures_dir / "cavanagh_theta_test.csv")
cav_data_traces = az.from_netcdf(fixtures_dir / "cavanagh_idata.nc")
# For demonstration purposes,
# `cav_data_test` is a subset of the `cavanagh_theta_nn` dataset
# with 5 subjects and 100 observations each
cav_data_test
| | participant_id | stim | rt | response | theta | dbs | conf |
|---|---|---|---|---|---|---|---|
| 0 | 0 | WL | 0.928 | -1.0 | -0.521933 | 0 | LC |
| 1 | 0 | WL | 0.661 | 1.0 | -0.219645 | 1 | LC |
| 2 | 0 | WW | 2.350 | -1.0 | -0.168728 | 1 | HC |
| 3 | 0 | LL | 1.250 | -1.0 | -0.104636 | 1 | HC |
| 4 | 0 | LL | 1.170 | -1.0 | 1.122720 | 1 | HC |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 495 | 4 | WL | 0.606 | -1.0 | -0.635942 | 0 | LC |
| 496 | 4 | WL | 0.745 | -1.0 | -0.166833 | 0 | LC |
| 497 | 4 | WW | 1.320 | 1.0 | -0.283396 | 1 | HC |
| 498 | 4 | LL | 1.640 | 1.0 | 0.462584 | 1 | HC |
| 499 | 4 | WL | 0.822 | 1.0 | -0.019645 | 0 | LC |
500 rows × 7 columns
# Model parameter specification
cav_model = hssm.HSSM(
model="ddm",
data=cav_data_test,
include=[
{
"name": "v",
"prior": {
"Intercept": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
"theta": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
},
"formula": "v ~ theta + (1|participant_id)",
"link": "identity",
},
],
)
# Perform sampling
# cav_model.sample()
Model initialized successfully.
For demonstration purposes, we inject an existing trace, with posterior predictive sampling already performed, into the model. A posterior_predictive group is added to the traces object, and an rt,response_mean variable is added to the posterior group during predictive sampling.
# In practice, you would obtain this object by sampling from the model.
cav_model._inference_obj = cav_data_traces
# Sample prior predictive
# cav_model.sample_prior_predictive(draws=1000)
# Print
cav_model.traces
The traces object is an ArviZ InferenceData with four groups:

- posterior (2 chains × 500 draws): v_Intercept, v_theta, a, z, t, v_1|participant_id_sigma, v_1|participant_id (5 participants), and the per-observation deterministic v (500 observations)
- posterior_predictive (2 chains × 500 draws × 500 observations): rt,response
- sample_stats (2 chains × 500 draws): 17 NUTS sampler diagnostics (energy, step size, tree depth, lp, ...)
- observed_data (500 observations): rt,response
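You can quickly confirm which groups the InferenceData object contains using ArviZ's standard InferenceData.groups() method:
# List the groups stored in the InferenceData object
list(cav_model.traces.groups())
# e.g. ['posterior', 'posterior_predictive', 'sample_stats', 'observed_data']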
Convenience functions from the top-level hssm.HSSM model object¶
The ArviZ package provides the az.summary() and az.plot_trace() functions, which are used very frequently. We have added these functions to the top-level hssm.HSSM model object as a convenience for HSSM users, with some nicer defaults. For example, when a parameter is the target of a regression, ArviZ will also summarize and plot the computed values of that parameter for every observation, which is slow and clutters the output. The convenience functions exclude these computed values by default.
1. model.summary()¶
The [model.summary()][hssm.HSSM.summary] convenience function automatically filters out these per-observation computed values.
set(cav_model.traces.posterior.data_vars.keys())
{'a',
't',
'v',
'v_1|participant_id',
'v_1|participant_id_sigma',
'v_Intercept',
'v_theta',
'z'}
cav_model.summary()
| mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
|---|---|---|---|---|---|---|---|---|---|
| v_theta | 0.062 | 0.048 | -0.027 | 0.150 | 0.002 | 0.002 | 620.0 | 471.0 | 1.01 |
| z | 0.504 | 0.020 | 0.467 | 0.538 | 0.001 | 0.001 | 766.0 | 643.0 | 1.00 |
| v_Intercept | 0.425 | 0.268 | -0.086 | 0.936 | 0.017 | 0.012 | 262.0 | 359.0 | 1.00 |
| t | 0.337 | 0.013 | 0.312 | 0.363 | 0.001 | 0.000 | 646.0 | 639.0 | 1.01 |
| a | 1.044 | 0.024 | 0.995 | 1.085 | 0.001 | 0.001 | 831.0 | 737.0 | 1.00 |
| v_1|participant_id_sigma | 0.621 | 0.281 | 0.237 | 1.200 | 0.022 | 0.019 | 183.0 | 221.0 | 1.01 |
Compare this with the output of ArviZ az.summary():
# Because `v` is the target of a regression, the traces contain a computed `v` value for
# every observation (posterior predictive sampling also adds an `rt,response_mean` field).
# ArviZ includes these computed values by default.
# This is equivalent to calling
# cav_model.summary(include_computed_values=True)
az.summary(cav_model.traces)
| mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
|---|---|---|---|---|---|---|---|---|---|
| v_Intercept | 0.425 | 0.268 | -0.086 | 0.936 | 0.017 | 0.012 | 262.0 | 359.0 | 1.00 |
| v_theta | 0.062 | 0.048 | -0.027 | 0.150 | 0.002 | 0.002 | 620.0 | 471.0 | 1.01 |
| a | 1.044 | 0.024 | 0.995 | 1.085 | 0.001 | 0.001 | 831.0 | 737.0 | 1.00 |
| z | 0.504 | 0.020 | 0.467 | 0.538 | 0.001 | 0.001 | 766.0 | 643.0 | 1.00 |
| t | 0.337 | 0.013 | 0.312 | 0.363 | 0.001 | 0.000 | 646.0 | 639.0 | 1.01 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| v[495] | 1.054 | 0.172 | 0.754 | 1.394 | 0.006 | 0.004 | 747.0 | 742.0 | 1.00 |
| v[496] | 1.083 | 0.169 | 0.740 | 1.369 | 0.006 | 0.004 | 790.0 | 735.0 | 1.00 |
| v[497] | 1.076 | 0.170 | 0.727 | 1.363 | 0.006 | 0.004 | 780.0 | 760.0 | 1.00 |
| v[498] | 1.122 | 0.170 | 0.805 | 1.432 | 0.006 | 0.004 | 861.0 | 714.0 | 1.00 |
| v[499] | 1.092 | 0.169 | 0.752 | 1.382 | 0.006 | 0.004 | 805.0 | 760.0 | 1.00 |
511 rows × 9 columns
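If you want a similar effect directly from ArviZ, you can exclude the per-observation deterministic v yourself; the tilde-negation below is standard ArviZ var_names syntax. This is only a sketch, not exactly what model.summary() does internally (for instance, the group-level offsets are still included here):
# Exclude the 500 per-observation `v` values from the ArviZ summary
az.summary(cav_model.traces, var_names=["~v"])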
2. model.plot_trace()¶
Likewise, [model.plot_trace()][hssm.HSSM.plot_trace] is equivalent to calling az.plot_trace() on the traces with the computed values removed, followed by plt.tight_layout():
cav_model.plot_trace()
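For reference, a rough ArviZ equivalent of the call above might look like this (a sketch; the exact set of variables that model.plot_trace() excludes may differ across HSSM versions):
# Drop the per-observation deterministic `v` and tighten the layout
az.plot_trace(cav_model.traces, var_names=["~v"])
plt.tight_layout()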
Predictive plots¶
The hssm.plotting.plot_predictive() function plots the distribution of the predictive samples:
hssm.plotting.plot_predictive(cav_model,
                              predictive_group="posterior_predictive");
The predictive_group argument lets you choose between the posterior_predictive and the prior_predictive groups.
hssm.plotting.plot_predictive(cav_model,
                              predictive_group="prior_predictive",
                              x_range=(-10, 10));
No prior_predictive samples found. Generating prior_predictive samples using the provided InferenceData object and the original data. This will modify the provided InferenceData object, or if not provided, the traces object stored inside the model.
Sampling: [a, rt,response, t, v_1|participant_id_offset, v_1|participant_id_sigma, v_Intercept, v_theta, z]
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.11/site-packages/bambi/models.py:851: FutureWarning: 'mean' has been replaced by 'response_params' and is not going to work in the future
  warnings.warn(
This API is designed to be a light wrapper around Seaborn's sns.histplot() and accepts most of the arguments that sns.histplot() accepts. When it produces a single plot, it returns a matplotlib Axes object, so you can manipulate it further.
ax = hssm.plotting.plot_predictive(cav_model)
sns.despine()
ax.set_ylabel("")
plt.title("Posterior Predictive Plot")
Text(0.5, 1.0, 'Posterior Predictive Plot')
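Since most sns.histplot() arguments are forwarded, you can also adjust the histogram itself. For example (assuming the keyword passthrough described above; stat and binwidth are standard sns.histplot() arguments):
# Plot densities with a fixed bin width instead of raw counts
ax = hssm.plotting.plot_predictive(cav_model, stat="density", binwidth=0.1)
ax.set_title("Posterior Predictive Plot (density)")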
You can also plot subsets of data in subplots:
hssm.plotting.plot_predictive(
cav_model,
col="participant_id",
col_wrap=3, # limits to 3 columns per row
)
<seaborn.axisgrid.FacetGrid at 0x349ce7e10>
hssm.plotting.plot_predictive(
cav_model,
col="participant_id",
col_wrap=3, # limits to 3 columns per row
predictive_group="prior_predictive",
)
<seaborn.axisgrid.FacetGrid at 0x349e78c10>
2-dimensional grids are also possible:
hssm.plotting.plot_predictive(
cav_model,
col="participant_id",
row="conf",
)
<seaborn.axisgrid.FacetGrid at 0x349bcd250>
When grids are used, this function returns a sns.FacetGrid object, which you can use to customize the plot further. For example, you can set the title of each individual subplot according to a template, or save the figure to disk with g.savefig().
g = hssm.plotting.plot_predictive(
cav_model,
col="participant_id",
row="conf",
)
g.set_titles(template="{row_name} | Participant {col_name}")
<seaborn.axisgrid.FacetGrid at 0x34e45b990>
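To save the customized grid to disk, use the FacetGrid's savefig() method (the filename below is just illustrative):
# Write the faceted predictive plot to disk
g.savefig("predictive_by_participant.png", dpi=300)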
Quantile probability plots¶
hssm.plotting.plot_quantile_probability() works similarly to hssm.plotting.plot_predictive(): when it produces a single plot (no grid), it returns an Axes object, and when it produces multiple plots, it returns a FacetGrid object.
# Single plot, returns an axis object, which can be worked on further
ax = hssm.plotting.plot_quantile_probability(cav_model, cond="stim")
ax.set_ylim(0, 3);
# Multiple plots, returns a FacetGrid
g = hssm.plotting.plot_quantile_probability(
cav_model,
cond="stim",
col="participant_id",
col_wrap=3,
grid_kwargs=dict(
ylim=(0, 3)
), # additional kwargs to the grid can be passed through `grid_kwargs`
)
The predictive_group argument works here too:
# Multiple plots, returns a FacetGrid
g = hssm.plotting.plot_quantile_probability(
cav_model,
cond="stim",
col="participant_id",
predictive_group="prior_predictive",
col_wrap=3,
grid_kwargs=dict(
ylim=(0, 3)
), # additional kwargs to the grid can be passed through `grid_kwargs`
)
Model Cartoon Plots¶
This part showcases the plot_model_cartoon() function, which is part of the HSSM plotting submodule.
These plots provide a pictorial representation of the underlying process recovered by the model fit.
You can explore various options for displaying posterior uncertainty graphically in these plots.
ax_1 = hssm.plotting.plot_model_cartoon(
cav_model,
n_samples=10,
bins=20,
col="stim",
row="participant_id",
groups=["dbs"],
plot_predictive_mean=True,
plot_predictive_samples=True,
n_trajectories=2, # extra arguments for the underlying plot_model_cartoon() function
);
No posterior_predictive samples found. Generating posterior_predictive samples using the provided InferenceData object and the original data. This will modify the provided InferenceData object, or if not provided, the traces object stored inside the model.
Prior Predictive¶
Note that the model cartoon plots can look particularly unwieldy when using the prior predictive. Take a look at the prior specifications and consider why they translate into such an unruly plot.
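One simple way to inspect those prior specifications is to print the model, which lists each parameter together with its formula and prior (the exact output format depends on your HSSM version):
# Print the model specification, including parameter formulas and priors
print(cav_model)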
ax_2 = hssm.plotting.plot_model_cartoon(
cav_model,
n_samples=10,
bins=20,
col="stim",
row="participant_id",
groups=["dbs"],
plot_predictive_mean=True,
plot_predictive_samples=True,
predictive_group="prior_predictive",
n_trajectories=2, # extra arguments for the underlying plot_model_cartoon() function
);
No split by group | Include posterior uncertainty¶
ax = hssm.plotting.plot_model_cartoon(
cav_model,
n_samples=100,
bins=10,
col="stim",
row="participant_id",
plot_predictive_mean=True,
plot_predictive_samples=True,
alpha_mean=1.0,
alpha_predictive=0.01,
alpha_trajectories=0.5,
n_trajectories=2, # extra arguments for the underlying plot_model_cartoon() function
);
No posterior_predictive samples found. Generating posterior_predictive samples using the provided InferenceData object and the original data. This will modify the provided InferenceData object, or if not provided, the traces object stored inside the model.
ax = hssm.plotting.plot_model_cartoon(
cav_model,
n_samples=10,
bins=20,
col="stim",
row="participant_id",
groups=["dbs"],
plot_predictive_mean=True,
plot_predictive_samples=True,
predictive_group="prior_predictive",
alpha_mean=1.0,
alpha_predictive=0.01,
alpha_trajectories=0.5,
n_trajectories=2, # extra arguments for the underlying plot_model_cartoon() function
);
N Choices¶
Data simulation¶
stim_v = [0.0, 0.75, 1.0]
stim_names = ["low", "medium", "high"]
datasets = []
a_vec = np.random.normal(loc=1.25, scale=0.3, size=5)
for v_tmp in stim_v:
for participant_id in range(5):
a_tmp = a_vec[participant_id]
data_tmp = hssm.simulate_data(
model="race_no_bias_angle_4",
theta=dict(
a=a_tmp,
v0=v_tmp,
v1=v_tmp + 0.25,
v2=v_tmp + 0.5,
v3=v_tmp + 0.75,
z=0.5,
t=0.2,
theta=0.1,
),
size=200,
)
data_tmp["stim"] = stim_names[stim_v.index(v_tmp)]
data_tmp["participant_id"] = str(participant_id)
datasets.append(data_tmp)
dataset = pd.concat(datasets).reset_index(drop=True)
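As a quick sanity check, each of the 3 stimulus levels × 5 participants should contribute 200 simulated trials, i.e. 3,000 rows in total:
# Confirm the expected 3 x 5 x 200 = 3000 rows and the per-cell trial counts
print(dataset.shape)
print(dataset.groupby(["stim", "participant_id"]).size())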
param_dict = {
"v": {
"low": np.repeat(stim_v[0], 5),
"medium": np.repeat(stim_v[1], 5),
"high": np.repeat(stim_v[2], 5),
},
"a": {"participant_id"},
}
HSSM Model¶
race_model = hssm.HSSM(
model="race_no_bias_angle_4",
data=dataset,
include=[
{
"name": "v0",
"prior": {
"Intercept": {"name": "Normal", "mu": 0.0, "sigma": 1.5},
},
"formula": "v0 ~ 1 + stim",
"link": "identity",
},
{
"name": "a",
"prior": {
"Intercept": {"name": "Normal", "mu": 1.5, "sigma": 0.5},
},
"formula": "a ~ 1 + (1|participant_id)",
"link": "identity",
},
],
p_outlier=0.00,
)
You have specified the `lapse` argument to include a lapse distribution, but `p_outlier` is set to either 0 or None. Your lapse distribution will be ignored.
Model initialized successfully.
idata_race = race_model.sample(
sampler="nuts_numpyro",
chains=2,
cores=2,
chain_method="vectorized",
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
)
Using default initvals.
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.11/site-packages/pymc/sampling/jax.py:475: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  pmap_numpyro = MCMC(
sample: 100%|██████████| 1000/1000 [01:14<00:00, 13.34it/s, 18 steps of size 3.26e-02. acc. prob=0.84]
sample: 100%|██████████| 1000/1000 [02:47<00:00, 5.95it/s, 57 steps of size 6.43e-03. acc. prob=0.91]
There were 1000 divergences after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Plot¶
ax = hssm.plotting.plot_model_cartoon(
race_model,
n_samples=10,
col="stim",
row="participant_id",
plot_pp_mean=True,
plot_pp_samples=False,
n_trajectories=1,
ylims=(0, 5),
alpha_pp=0.2,
xlims=(0.0, 2),
);
No posterior_predictive samples found. Generating posterior_predictive samples using the provided InferenceData object and the original data. This will modify the provided InferenceData object, or if not provided, the traces object stored inside the model.