Plotting in HSSM
This tutorial demonstrates the plotting functionality in HSSM. While the ArviZ package provides many plotting utilities, HSSM aims to complement it with additional plot types specific to hierarchical sequential sampling models. HSSM also exposes some plotting APIs directly on the HSSM model object, with a few tweaks for convenience.
Most of the plotting and summary functionality lives in the hssm.plotting module; additional convenience functions are available through the top-level HSSM model object.
# If running this on Colab, please uncomment the next line
# !pip install hssm
from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import hssm
import hssm.plotting
%matplotlib inline
%config InlineBackend.figure_format='retina'
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Model setup
fixtures_dir = Path("../../tests/fixtures")
cav_data_test = pd.read_csv(fixtures_dir / "cavanagh_theta_test.csv")
cav_data_traces = az.from_netcdf(fixtures_dir / "cavanagh_idata.nc")
# For demonstration purposes,
# `cav_data_test` is a subset of the `cavanagh_theta_nn` dataset
# with 5 subjects and 100 observations each
cav_data_test
 | participant_id | stim | rt | response | theta | dbs | conf |
---|---|---|---|---|---|---|---|
0 | 0 | WL | 0.928 | -1.0 | -0.521933 | 0 | LC |
1 | 0 | WL | 0.661 | 1.0 | -0.219645 | 1 | LC |
2 | 0 | WW | 2.350 | -1.0 | -0.168728 | 1 | HC |
3 | 0 | LL | 1.250 | -1.0 | -0.104636 | 1 | HC |
4 | 0 | LL | 1.170 | -1.0 | 1.122720 | 1 | HC |
... | ... | ... | ... | ... | ... | ... | ... |
495 | 4 | WL | 0.606 | -1.0 | -0.635942 | 0 | LC |
496 | 4 | WL | 0.745 | -1.0 | -0.166833 | 0 | LC |
497 | 4 | WW | 1.320 | 1.0 | -0.283396 | 1 | HC |
498 | 4 | LL | 1.640 | 1.0 | 0.462584 | 1 | HC |
499 | 4 | WL | 0.822 | 1.0 | -0.019645 | 0 | LC |
500 rows × 7 columns
# Model parameter specification
cav_model = hssm.HSSM(
    model="ddm",
    data=cav_data_test,
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
                "theta": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
            },
            "formula": "v ~ theta + (1|participant_id)",
            "link": "identity",
        },
    ],
)
# Perform sampling
# cav_model.sample()
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
For demonstration purposes, we inject into the model an existing trace for which posterior predictive sampling has already been performed. Posterior predictive sampling adds a posterior_predictive attribute to the traces object and a rt,response_mean variable to the posterior attribute.
# In practice, you would obtain this object by sampling from the model.
cav_model._inference_obj = cav_data_traces
cav_model.traces
-
<xarray.Dataset> Size: 4MB
Dimensions:  (chain: 2, draw: 500, v_1|participant_id__factor_dim: 5, __obs__: 500)
Coordinates:
  * chain                           (chain) int64 16B 0 1
  * draw                            (draw) int64 4kB 0 1 2 3 ... 496 497 498 499
  * v_1|participant_id__factor_dim  (v_1|participant_id__factor_dim) <U1 20B ...
  * __obs__                         (__obs__) int64 4kB 0 1 2 3 ... 497 498 499
Data variables:
    v_Intercept               (chain, draw) float64 8kB ...
    v_theta                   (chain, draw) float32 4kB ...
    a                         (chain, draw) float32 4kB ...
    z                         (chain, draw) float32 4kB ...
    t                         (chain, draw) float32 4kB ...
    v_1|participant_id_sigma  (chain, draw) float32 4kB ...
    v_1|participant_id        (chain, draw, v_1|participant_id__factor_dim) float32 20kB ...
    v                         (chain, draw, __obs__) float64 4MB ...
Attributes:
    created_at:                  2023-11-14T18:35:04.027433
    arviz_version:               0.14.0
    inference_library:           pymc
    inference_library_version:   5.9.1
    sampling_time:               22.505457878112793
    tuning_steps:                1000
    modeling_interface:          bambi
    modeling_interface_version:  0.12.0
-
<xarray.Dataset> Size: 4MB
Dimensions:  (chain: 2, draw: 500, __obs__: 500, rt,response_dim: 2)
Coordinates:
  * chain            (chain) int64 16B 0 1
  * draw             (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
  * __obs__          (__obs__) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
  * rt,response_dim  (rt,response_dim) int64 16B 0 1
Data variables:
    rt,response      (chain, draw, __obs__, rt,response_dim) float32 4MB ...
Attributes:
    modeling_interface:          bambi
    modeling_interface_version:  0.12.0
-
<xarray.Dataset> Size: 126kB
Dimensions:  (chain: 2, draw: 500)
Coordinates:
  * chain  (chain) int64 16B 0 1
  * draw   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
Data variables: (12/17)
    energy_error           (chain, draw) float64 8kB ...
    perf_counter_diff      (chain, draw) float64 8kB ...
    energy                 (chain, draw) float64 8kB ...
    max_energy_error       (chain, draw) float64 8kB ...
    smallest_eigval        (chain, draw) float64 8kB ...
    step_size_bar          (chain, draw) float64 8kB ...
    ...                    ...
    perf_counter_start     (chain, draw) float64 8kB ...
    process_time_diff      (chain, draw) float64 8kB ...
    index_in_trajectory    (chain, draw) int64 8kB ...
    reached_max_treedepth  (chain, draw) bool 1kB ...
    lp                     (chain, draw) float64 8kB ...
    largest_eigval         (chain, draw) float64 8kB ...
Attributes:
    arviz_version:               0.14.0
    created_at:                  2023-11-14T18:35:04.032437
    inference_library:           pymc
    inference_library_version:   5.9.1
    modeling_interface:          bambi
    modeling_interface_version:  0.12.0
    sampling_time:               22.505457878112793
    tuning_steps:                1000
-
<xarray.Dataset> Size: 8kB
Dimensions:  (__obs__: 500, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__                  (__obs__) int64 4kB 0 1 2 3 4 ... 496 497 498 499
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response              (__obs__, rt,response_extra_dim_0) float32 4kB ...
Attributes:
    arviz_version:               0.14.0
    created_at:                  2023-11-14T18:35:04.034261
    inference_library:           pymc
    inference_library_version:   5.9.1
    modeling_interface:          bambi
    modeling_interface_version:  0.12.0
Convenience functions from the top-level hssm.HSSM model object
The ArviZ package provides the frequently used az.summary() and az.plot_trace() functions. For convenience, HSSM exposes both through the top-level hssm.HSSM model object, with more suitable defaults. For example, when a parameter is the target of a regression, ArviZ also summarizes and plots its computed value for every observation, which is slow and clutters the output. The convenience functions exclude these computed values by default.
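As a rough illustration of what "excluding computed values" means, the sketch below picks out the variables that do not carry the per-observation dimension. The variable names and dimensions are copied from the posterior group printed earlier. The helper `non_computed_vars` and the rule it applies (drop anything indexed by `__obs__`) are assumptions made for illustration, not HSSM's actual implementation, which may select variables differently.

```python
# Dimensions of each posterior variable, copied from the trace printed earlier.
POSTERIOR_DIMS = {
    "v_Intercept": ("chain", "draw"),
    "v_theta": ("chain", "draw"),
    "a": ("chain", "draw"),
    "z": ("chain", "draw"),
    "t": ("chain", "draw"),
    "v_1|participant_id_sigma": ("chain", "draw"),
    "v_1|participant_id": ("chain", "draw", "v_1|participant_id__factor_dim"),
    "v": ("chain", "draw", "__obs__"),
}


def non_computed_vars(var_dims, obs_dim="__obs__"):
    """Hypothetical helper: keep variables with no per-observation dimension."""
    return [name for name, dims in var_dims.items() if obs_dim not in dims]


kept = non_computed_vars(POSTERIOR_DIMS)
print(kept)  # every variable except the per-observation `v`
```

A list like `kept` could then be handed to ArviZ through its `var_names` argument, for example `az.summary(cav_model.traces, var_names=kept)`.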
1. model.summary()
The [model.summary()][hssm.HSSM.summary] convenience function automatically filters out undesirable outputs, such as the per-observation computed values.
dir(cav_model.traces.posterior.data_vars)
['__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '_abc_impl', '_dataset', '_ipython_key_completions_', 'dtypes', 'get', 'items', 'keys', 'values', 'variables']
set(cav_model.traces.posterior.data_vars)
{'a', 't', 'v', 'v_1|participant_id', 'v_1|participant_id_sigma', 'v_Intercept', 'v_theta', 'z'}
cav_model.summary()
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v_1|participant_id_sigma | 0.621 | 0.281 | 0.237 | 1.200 | 0.022 | 0.016 | 183.0 | 221.0 | 1.01 |
t | 0.337 | 0.013 | 0.312 | 0.363 | 0.001 | 0.000 | 646.0 | 639.0 | 1.01 |
a | 1.044 | 0.024 | 0.995 | 1.085 | 0.001 | 0.001 | 831.0 | 737.0 | 1.00 |
v_theta | 0.062 | 0.048 | -0.027 | 0.150 | 0.002 | 0.001 | 620.0 | 471.0 | 1.01 |
v_Intercept | 0.425 | 0.268 | -0.086 | 0.936 | 0.017 | 0.012 | 262.0 | 359.0 | 1.00 |
z | 0.504 | 0.020 | 0.467 | 0.538 | 0.001 | 0.000 | 766.0 | 643.0 | 1.00 |
Compare this with the output of ArviZ's az.summary():
# Because of posterior predictive sampling, an `rt,response_mean` field was added to
# the traces object. ArviZ includes such computed values by default.
# The call below is equivalent to
# cav_model.summary(include_computed_values=True)
az.summary(cav_model.traces)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v_Intercept | 0.425 | 0.268 | -0.086 | 0.936 | 0.017 | 0.012 | 262.0 | 359.0 | 1.00 |
v_theta | 0.062 | 0.048 | -0.027 | 0.150 | 0.002 | 0.001 | 620.0 | 471.0 | 1.01 |
a | 1.044 | 0.024 | 0.995 | 1.085 | 0.001 | 0.001 | 831.0 | 737.0 | 1.00 |
z | 0.504 | 0.020 | 0.467 | 0.538 | 0.001 | 0.000 | 766.0 | 643.0 | 1.00 |
t | 0.337 | 0.013 | 0.312 | 0.363 | 0.001 | 0.000 | 646.0 | 639.0 | 1.01 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
v[495] | 1.054 | 0.172 | 0.754 | 1.394 | 0.006 | 0.004 | 747.0 | 742.0 | 1.00 |
v[496] | 1.083 | 0.169 | 0.740 | 1.369 | 0.006 | 0.004 | 790.0 | 735.0 | 1.00 |
v[497] | 1.076 | 0.170 | 0.727 | 1.363 | 0.006 | 0.004 | 780.0 | 760.0 | 1.00 |
v[498] | 1.122 | 0.170 | 0.805 | 1.432 | 0.006 | 0.004 | 861.0 | 714.0 | 1.00 |
v[499] | 1.092 | 0.169 | 0.752 | 1.382 | 0.006 | 0.004 | 805.0 | 760.0 | 1.00 |
511 rows × 9 columns
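The 511 rows follow directly from the variable shapes: `az.summary()` reports one row per scalar element of each variable, ignoring the `chain` and `draw` dimensions. The arithmetic can be checked with the shapes copied from the posterior group printed earlier. The helper name `summary_rows` is made up for this sketch.

```python
from math import prod

# Sizes of the non-sample dimensions (everything except chain and draw).
EXTRA_DIMS = {
    "v_Intercept": (),
    "v_theta": (),
    "a": (),
    "z": (),
    "t": (),
    "v_1|participant_id_sigma": (),
    "v_1|participant_id": (5,),  # one offset per participant
    "v": (500,),                 # one computed drift rate per observation
}


def summary_rows(extra_dims):
    """Hypothetical helper: one summary row per scalar element of each variable."""
    return {name: prod(shape) for name, shape in extra_dims.items()}


rows = summary_rows(EXTRA_DIMS)
total = sum(rows.values())
print(total)  # 6 scalars + 5 participant offsets + 500 values of v = 511
```

Almost all of the rows therefore come from the single computed variable `v`, which is why dropping it makes the summary so much shorter.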
2. model.plot_trace()
Likewise, [model.plot_trace()][hssm.HSSM.plot_trace] is equivalent to calling az.plot_trace() on the model's traces with computed values removed, followed by plt.tight_layout():
cav_model.plot_trace()
HSSM-specific plots
HSSM also offers various types of plots specific to hierarchical sequential sampling models.
1. Posterior predictive plots
[hssm.plotting.plot_posterior_predictive()][hssm.plotting.plot_posterior_predictive] plots the distribution of posterior predictive samples against the observed data.
# You can also call
# cav_model.plot_posterior_predictive()
hssm.plotting.plot_posterior_predictive(cav_model);
This API is designed as a light wrapper around Seaborn's sns.histplot() and accepts most of the arguments that sns.histplot() accepts. It returns a matplotlib Axes object, so you can customize the plot further.
ax = hssm.plotting.plot_posterior_predictive(cav_model)
sns.despine()
ax.set_ylabel("")
plt.title("Posterior Predictive Plot")
Text(0.5, 1.0, 'Posterior Predictive Plot')
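What such a plot compares can be sketched without any plotting library. The example below bins toy "observed" and toy "predicted" response times on a shared grid so that the two histograms are directly comparable. Everything here is an assumption made for illustration: the simulated data, the helper names `signed_rt`, `simulate` and `density_histogram`, and the convention of multiplying `rt` by `response` (coded -1/1, as in `cav_data_test`) so that lower-boundary responses fall on the negative axis. HSSM's own implementation may differ.

```python
import random
from bisect import bisect_right


def signed_rt(rt, response):
    """Fold the response into the sign of the RT (assumed convention)."""
    return rt * response


def simulate(n, rng):
    """Toy generator of signed RTs; stands in for real data or model draws."""
    return [signed_rt(0.3 + rng.expovariate(2.0), rng.choice([-1, 1])) for _ in range(n)]


def density_histogram(values, edges):
    """Per-bin densities over `edges`; the bar areas sum to 1."""
    counts = [0] * (len(edges) - 1)
    for value in values:
        if value < edges[0] or value > edges[-1]:
            continue  # ignore values outside the binning range
        # Index of the bin containing `value`; the last bin includes its upper edge.
        index = min(bisect_right(edges, value) - 1, len(counts) - 1)
        counts[index] += 1
    total = sum(counts)
    return [c / (total * (edges[i + 1] - edges[i])) for i, c in enumerate(counts)]


rng = random.Random(0)
observed = simulate(200, rng)     # toy "observed" trials
predicted = simulate(2000, rng)   # toy "posterior predictive" draws

edges = [-4.0 + 0.5 * i for i in range(17)]  # shared bins from -4 to 4 seconds
obs_density = density_histogram(observed, edges)
pred_density = density_histogram(predicted, edges)

for lo, o, p in zip(edges, obs_density, pred_density):
    print(f"{lo:+.1f}  observed={o:.3f}  predicted={p:.3f}")
```

Using the same bin edges for both sets of values is the important design choice: it makes the bar heights comparable bin by bin, which is what overlaying the two histograms relies on.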
You can also plot subsets of data in subplots:
hssm.plotting.plot_posterior_predictive(
    cav_model,
    col="participant_id",
    col_wrap=3,  # limits to 3 columns per row
)
<seaborn.axisgrid.FacetGrid at 0x7ff1440e0d90>
2-dimensional grids are also possible:
hssm.plotting.plot_posterior_predictive(
    cav_model,
    col="participant_id",
    row="conf",
)
<seaborn.axisgrid.FacetGrid at 0x7ff13eef51d0>
When grids are used, this function returns a sns.FacetGrid object, which you can use to customize the plot further. For example, you can set the title of each subplot from a template, or save the figure to disk with g.savefig().
g = hssm.plotting.plot_posterior_predictive(
    cav_model,
    col="participant_id",
    row="conf",
)
g.set_titles(template="{row_name} | Participant {col_name}")
<seaborn.axisgrid.FacetGrid at 0x7ff144360650>
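The `template` string uses `str.format`-style fields that seaborn fills in for each facet: `{row_name}` and `{col_name}` are the facet's row and column values (`{row_var}` and `{col_var}`, the variable names, are also available). The sketch below only mimics that substitution in plain Python to show which titles the template yields. The `conf` levels come from the data preview, the participant ids are assumed to run from 0 to 4, and the loop is an illustration rather than seaborn's code.

```python
from itertools import product

template = "{row_name} | Participant {col_name}"

conf_levels = ["LC", "HC"]       # values of the `conf` column
participants = [0, 1, 2, 3, 4]   # assumed values of the `participant_id` column

# One title per (row, column) facet of the grid.
titles = [
    template.format(row_name=conf, col_name=pid)
    for conf, pid in product(conf_levels, participants)
]
for title in titles:
    print(title)
```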
hssm.plotting.plot_quantile_probability() works similarly to hssm.plotting.plot_posterior_predictive(): when it produces a single plot (no grid), it returns an Axes object, and when it produces multiple plots, it returns a FacetGrid object.
# Single plot, returns an axis object, which can be worked on further
ax = hssm.plotting.plot_quantile_probability(cav_model, cond="stim")
ax.set_ylim(0, 3);
# Multiple plots, returns a FacetGrid
g = hssm.plotting.plot_quantile_probability(
    cav_model,
    cond="stim",
    col="participant_id",
    col_wrap=3,
    # Additional kwargs for the grid can be passed through `grid_kwargs`
    grid_kwargs=dict(ylim=(0, 3)),
)
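For readers unfamiliar with quantile probability plots: for each condition and each response, they place the proportion of that response on the x-axis and a set of RT quantiles on the y-axis. The sketch below computes those points for toy data using only the standard library. The quantile levels (0.1, 0.3, 0.5, 0.7, 0.9), the helper name `quantile_probability_points`, and the toy trials are assumptions for illustration; HSSM's function may use different defaults and a different computation.

```python
from collections import defaultdict
from statistics import quantiles

LEVELS = (0.1, 0.3, 0.5, 0.7, 0.9)


def quantile_probability_points(trials, levels=LEVELS):
    """Map (condition, response) -> (response proportion, RT quantiles).

    `trials` is an iterable of (condition, response, rt) tuples.
    Only levels that are multiples of 0.1 are supported in this sketch.
    """
    rts = defaultdict(list)
    per_condition = defaultdict(int)
    for cond, response, rt in trials:
        rts[(cond, response)].append(rt)
        per_condition[cond] += 1

    points = {}
    for (cond, response), values in rts.items():
        # Nine decile cut points; index k - 1 holds the k/10 quantile.
        deciles = quantiles(values, n=10, method="inclusive")
        picked = [deciles[round(level * 10) - 1] for level in levels]
        points[(cond, response)] = (len(values) / per_condition[cond], picked)
    return points


# Toy trials: two responses in condition "WL", a single response in "WW".
trials = (
    [("WL", 1, 0.5 + 0.1 * i) for i in range(11)]
    + [("WL", -1, 0.8 + 0.1 * i) for i in range(11)]
    + [("WW", 1, 0.4 + 0.05 * i) for i in range(11)]
)
points = quantile_probability_points(trials)
for key, (proportion, qs) in points.items():
    print(key, round(proportion, 2), [round(q, 2) for q in qs])
```

Each printed entry corresponds to one vertical stack of markers in a quantile probability plot: the proportion fixes the x position and the five quantiles give the y positions.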