This tutorial provides a comprehensive introduction to the HSSM package for Hierarchical Bayesian Estimation of Sequential Sampling Models.
To make the most of the tutorial, let us first cover the key supporting packages that we will use along the way.
Colab Instructions¶
If you would like to run this tutorial on Google Colab, please click this link.
Once you are in the colab, follow the installation instructions below and then restart your runtime.
Just uncomment the code in the next code cell and run it!
NOTE:
You may want to switch your runtime to have a GPU or TPU. To do so, go to Runtime > Change runtime type and select the desired hardware accelerator.
Note that if you switch your runtime you have to follow the installation instructions again.
# !pip install numpy==1.23.4
# !pip install git+https://github.com/lnccbrown/hssm@main
# !pip install git+https://github.com/brown-ccv/hddm-wfpt@main
# !pip install numpyro
SSMS for Data Simulation¶
We will rely on the ssms package for data simulation repeatedly. Let's look at a basic isolated use case below.
As an example, let's use ssms to simulate from the basic Drift Diffusion Model (a running example in this tutorial).
If you are not familiar with the DDM, for now just consider that it has four parameters:
- v: the drift rate
- a: the boundary separation
- t: the non-decision time
- z: the a priori decision bias (starting point)
from ssms.basic_simulators import simulator
import numpy as np
import pandas as pd
# Specify parameters
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.2]
# Simulate data
sim_out = simulator(
theta=[v_true, a_true, z_true, t_true], # parameter list
model="ddm", # specify model (many are included in ssms)
n_samples=500, # number of samples for each set of parameters
)
# Turn data into a pandas dataframe
dataset = pd.DataFrame(
np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
columns=["rt", "response"],
)
dataset
rt | response | |
---|---|---|
0 | 2.563011 | 1.0 |
1 | 0.602998 | 1.0 |
2 | 1.565008 | -1.0 |
3 | 1.617010 | 1.0 |
4 | 2.158036 | 1.0 |
... | ... | ... |
495 | 0.915994 | -1.0 |
496 | 1.642011 | 1.0 |
497 | 0.727997 | 1.0 |
498 | 1.295995 | 1.0 |
499 | 2.807993 | 1.0 |
500 rows × 2 columns
We can instead supply a matrix (or array) of parameters, and ssms will know how to handle that too.
This usage makes sense if you e.g. care about trial-wise parameterizations. Just supply a matrix of parameters and set n_samples=1.
# a changes trial wise
a_true_trialwise = np.random.normal(loc=2, scale=0.3, size=1000)
theta_mat = np.zeros((1000, 4))
theta_mat[:, 0] = v_true
theta_mat[:, 1] = a_true_trialwise
theta_mat[:, 2] = z_true
theta_mat[:, 3] = t_true
# simulate data
sim_out_trialwise = simulator(
theta=theta_mat, # parameter_matrix
model="ddm", # specify model (many are included in ssms)
n_samples=1, # number of samples for each set of parameters
)
# Turn into nice dataset
dataset_trialwise = pd.DataFrame(
np.column_stack(
[sim_out_trialwise["rts"][:, 0], sim_out_trialwise["choices"][:, 0]]
),
columns=["rt", "response"],
)
dataset_trialwise
rt | response | |
---|---|---|
0 | 2.607008 | 1.0 |
1 | 4.322884 | 1.0 |
2 | 2.611008 | 1.0 |
3 | 11.509940 | 1.0 |
4 | 1.749016 | 1.0 |
... | ... | ... |
995 | 1.715015 | 1.0 |
996 | 9.783242 | -1.0 |
997 | 5.137825 | 1.0 |
998 | 1.725015 | 1.0 |
999 | 1.994028 | 1.0 |
1000 rows × 2 columns
We will use ssms throughout the tutorial to generate data, relying on both trial-wise parameter matrices and simple parameter lists, as natural for the respective example.
ArviZ for Plotting¶
We use the ArviZ package for most of our plotting needs. ArviZ is a useful aid when doing anything Bayesian.
It works with HSSM out of the box, by virtue of HSSM's reliance on PyMC for model construction and sampling.
Checking out the ArviZ Documentation is a good idea: it will give you communication superpowers not only for your HSSM results, but also for other libraries in the Bayesian toolkit such as NumPyro or Stan.
We will see ArviZ plots throughout the notebook.
Main Tutorial¶
# Basics
import os
import sys
import time
from matplotlib import pyplot as plt
import arviz as az # Visualization
import pytensor # Graph-based tensor library
import hssm
# import ssms.basic_simulators # Model simulators
import hddm_wfpt
import bambi as bmb
# Setting float precision in pytensor
pytensor.config.floatX = "float32"
from jax.config import config
config.update("jax_enable_x64", False)
Initial Dataset¶
Using our knowledge of ssms, we can proceed to simulate a simple dataset for our first example.
# Specify parameter values
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.2]
# Simulate data
sim_out = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=500)
# Turn data into a pandas dataframe
dataset = pd.DataFrame(
np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
columns=["rt", "response"],
)
dataset
rt | response | |
---|---|---|
0 | 8.003617 | 1.0 |
1 | 1.141991 | 1.0 |
2 | 6.076756 | 1.0 |
3 | 0.527999 | 1.0 |
4 | 1.060992 | 1.0 |
... | ... | ... |
495 | 1.001993 | 1.0 |
496 | 1.358998 | 1.0 |
497 | 1.423001 | 1.0 |
498 | 1.150991 | -1.0 |
499 | 1.638011 | 1.0 |
500 rows × 2 columns
First HSSM Model¶
In this example we will use the analytical likelihood function computed as suggested in this paper.
Instantiate the model¶
To instantiate our HSSM class, in the simplest version, we only need to provide an appropriate dataset.
The dataset is expected to be a pandas.DataFrame with at least two columns, respectively called rt (for reaction time) and response.
Our data simulated above is already in the correct format, so let us try to construct the class.
NOTE:
If you are a user of the HDDM python package, this workflow should seem very familiar.
simple_ddm_model = hssm.HSSM(data=dataset)
print(simple_ddm_model)
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: analytical
Observations: 500
Parameters:

v:
    Prior: Normal(mu: 0.0, sigma: 2.0)
    Explicit bounds: (-inf, inf)
a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
t:
    Prior: HalfNormal(sigma: 2.0, initval: 0.10000000149011612)
    Explicit bounds: (0.0, inf)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
The print() function gives us some basic information about our model, including the number of observations, the parameters in the model, and their respective prior settings. We can also create a nice little graphical representation of our model...
Model Graph¶
Since HSSM creates a PyMC Model, we can use the .graph() function to get a graphical representation of the model we created.
simple_ddm_model.graph()
This is the simplest model we can build. We have our basic parameters (unobserved, white nodes), and our observed reaction times and choices (SSMRandomVariable, grey node).
The graph() function becomes a lot more interesting for more complicated models!
Sample from the Model¶
We can now call the .sample() function to get posterior samples. The main arguments you may want to change are listed in the function call below.
Importantly, multiple backends are possible. We choose the nuts_numpyro backend below, which in turn compiles the model to a JAX function.
infer_data_simple_ddm_model = simple_ddm_model.sample(
sampler="nuts_numpyro", # type of sampler to choose, 'nuts_numpyro', 'nuts_blackjax' of default pymc nuts sampler
cores=1, # how many cores to use
chains=2, # how many chains to run
draws=500, # number of draws from the markov chain
tune=500, # number of burn-in samples
idata_kwargs=dict(log_likelihood=True), # return log likelihood
) # mp_ctx="forkserver")
Compiling... Compilation time = 0:00:01.690728 Sampling...
/Users/yxu150/HSSM/.venv/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py:796: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. return getattr(self.aval, name).fun(self, *args, **kwargs) /Users/yxu150/HSSM/.venv/lib/python3.9/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:177: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. return jnp.array(x, dtype=op.dtype)
Sampling time = 0:00:06.244123 Transforming variables... Transformation time = 0:00:00.040493 Computing Log Likelihood...
Log Likelihood time = 0:00:01.409715
type(infer_data_simple_ddm_model)
arviz.data.inference_data.InferenceData
Errr... let's look at this object in a bit more detail!
Inference Data / What gets returned from the sampler?¶
The sampler returns an ArviZ InferenceData object.
To understand all the logic behind these objects and how they mesh with the Bayesian Workflow, we refer you to the ArviZ Documentation.
But let's take a quick high-level look to understand roughly what we are dealing with here!
infer_data_simple_ddm_model
posterior:
<xarray.Dataset> Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 Data variables: v (chain, draw) float32 0.6451 0.6595 0.6821 ... 0.7065 0.7184 0.6548 z (chain, draw) float32 0.4865 0.4897 0.4857 ... 0.4857 0.4847 0.4752 t (chain, draw) float32 0.23 0.1983 0.2448 ... 0.2455 0.2138 0.1803 a (chain, draw) float32 1.591 1.529 1.525 1.48 ... 1.531 1.59 1.615 Attributes: created_at: 2023-09-05T18:26:18.945748 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
log_likelihood:
<xarray.Dataset> Dimensions: (chain: 2, draw: 500, rt,response_obs: 500) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * rt,response_obs (rt,response_obs) int64 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: rt,response (chain, draw, rt,response_obs) float32 -4.733 ... -1.183 Attributes: created_at: 2023-09-05T18:26:18.949085 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
sample_stats:
<xarray.Dataset> Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: acceptance_rate (chain, draw) float32 0.9642 0.8481 0.98 ... 0.8647 0.9917 step_size (chain, draw) float32 0.415 0.415 0.415 ... 0.3348 0.3348 diverging (chain, draw) bool False False False ... False False False energy (chain, draw) float32 978.6 975.4 975.4 ... 976.8 975.0 n_steps (chain, draw) int32 7 3 11 15 7 7 3 7 ... 11 7 15 7 7 7 7 tree_depth (chain, draw) int64 3 2 4 4 3 3 2 3 3 ... 4 4 4 3 4 3 3 3 3 lp (chain, draw) float32 973.4 974.7 972.7 ... 973.8 974.1 Attributes: created_at: 2023-09-05T18:26:18.947797 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
observed_data:
<xarray.Dataset> Dimensions: (rt,response_obs: 500, rt,response_extra_dim_0: 2) Coordinates: * rt,response_obs (rt,response_obs) int64 0 1 2 3 ... 496 497 498 499 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 0 1 Data variables: rt,response (rt,response_obs, rt,response_extra_dim_0) float32 ... Attributes: created_at: 2023-09-05T18:26:18.949383 arviz_version: 0.14.0 inference_library: numpyro inference_library_version: 0.12.1 sampling_time: 6.244123 modeling_interface: bambi modeling_interface_version: 0.12.0
We see that in our case, infer_data_simple_ddm_model contains four basic types of data (note: this is extensible!):
- posterior
- log_likelihood
- sample_stats
- observed_data
The posterior object contains our traces for each of the parameters in the model. The log_likelihood field contains the trial-wise log-likelihoods for each sample from the posterior. The sample_stats field contains information about the sampler run; this can be important for chain diagnostics, but we will not dwell on it here. Finally, we retrieve our observed_data.
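Each of these groups is an xarray Dataset that can be accessed as an attribute of the InferenceData object. For example (using the group and variable names listed above):

v_draws = infer_data_simple_ddm_model.posterior["v"]  # posterior draws of v
print(v_draws.shape)  # (chains, draws) -> here (2, 500)
print(float(v_draws.mean()))  # posterior mean of v

# Trial-wise log-likelihoods, one value per (chain, draw, observation)
ll = infer_data_simple_ddm_model.log_likelihood["rt,response"]

A compact numerical overview is provided by az.summary():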
az.summary(infer_data_simple_ddm_model)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
v | 0.654 | 0.050 | 0.565 | 0.753 | 0.002 | 0.002 | 478.0 | 635.0 | 1.0 |
z | 0.485 | 0.019 | 0.454 | 0.525 | 0.001 | 0.001 | 497.0 | 538.0 | 1.0 |
t | 0.230 | 0.034 | 0.168 | 0.292 | 0.002 | 0.001 | 449.0 | 373.0 | 1.0 |
a | 1.547 | 0.046 | 1.467 | 1.635 | 0.002 | 0.001 | 597.0 | 663.0 | 1.0 |
This table reports the parameter-wise posterior means along with a few extra statistics (HDIs, MCSE, effective sample sizes, and R-hat).
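az.summary() accepts the usual ArviZ keyword arguments, so we can, for instance, restrict the output to a single parameter and ask for $95\%$ instead of $94\%$ HDIs:

az.summary(infer_data_simple_ddm_model, var_names=["v"], hdi_prob=0.95)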
Next, we can plot our traces directly.
Trace plot¶
az.plot_trace(
infer_data_simple_ddm_model,
var_names="~log_likelihood", # we exclude the log_likelihood traces here
)
plt.tight_layout()
/Users/yxu150/HSSM/.venv/lib/python3.9/site-packages/arviz/utils.py:134: UserWarning: Items starting with ~: ['log_likelihood'] have not been found and will be ignored warnings.warn(
The .sample() function also sets a traces attribute on our HSSM class, so instead, we could call the plot like so:
az.plot_trace(simple_ddm_model.traces);
In this tutorial we are most often going to use the latter way of accessing the traces, but there is no preferred option.
Let's look at a few more plots.
Forest Plot¶
The forest plot is commonly used for a quick visual check of the marginal posteriors. It is very effective for intuitive communication of results.
az.plot_forest(simple_ddm_model.traces)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
Basic Marginal Posterior Plot¶
Another way to view the marginal posteriors is provided by the plot_posterior() function. It shows the mean and, by default, the $94\%$ HDIs.
az.plot_posterior(simple_ddm_model.traces)
array([<Axes: title={'center': 'v'}>, <Axes: title={'center': 'z'}>, <Axes: title={'center': 't'}>, <Axes: title={'center': 'a'}>], dtype=object)
Posterior Pair Plot¶
The posterior pair plot is useful to check for simple parameter tradeoffs that may emerge.
az.plot_pair(simple_ddm_model.traces, kind="kde")
array([[<Axes: ylabel='z'>, <Axes: >, <Axes: >], [<Axes: ylabel='t'>, <Axes: >, <Axes: >], [<Axes: xlabel='v', ylabel='a'>, <Axes: xlabel='z'>, <Axes: xlabel='t'>]], dtype=object)
This is just the beginning: ArviZ has a much broader spectrum of graphs and other convenience functions available. Just check the documentation.
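For instance, two more diagnostics that work on our traces out of the box:

az.plot_energy(simple_ddm_model.traces)  # energy plot, a quick health check for NUTS
az.plot_rank(simple_ddm_model.traces)  # rank plots, an alternative convergence check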
HSSM Model based on LAN likelihood¶
With HSSM you can switch between pre-supplied models with a simple change of argument. The type of likelihood that will be accessed might change in the background for you.
Here we see an example in which the underlying likelihood is now a LAN.
We will talk more about different types of likelihood functions and backends later in the tutorial. For now just keep the following in mind:
There are three types of likelihoods:
- analytical
- approx_differentiable
- blackbox

To check which type is used in your HSSM model, simply type:
simple_ddm_model.loglik_kind
'analytical'
Ah... we were using an analytical likelihood with the DDM model in the last section.
Now let's see something different!
Simulating Angle Data¶
Again, let us simulate a simple dataset. This time we will use the angle model (passed via the model argument to the simulator() function).
This model is distinguished from the basic ddm model by an additional theta parameter, which specifies the angle with which the decision boundaries collapse over time.
DDMs with collapsing bounds have been of significant interest in the theoretical literature, but applications were rare due to a lack of analytical likelihoods. HSSM facilitates inference with such models via our approx_differentiable likelihoods. HSSM ships with a few predefined models based on LANs, but really we don't want to overemphasize those: they reflect the research interests of our and adjacent labs to a great extent.
Instead, we encourage the community to contribute to this model reservoir (more on this later).
# Simulate angle data
v_true, a_true, z_true, t_true, theta_true = [0.5, 1.5, 0.5, 0.5, 0.2]
obs_angle = simulator(
[v_true, a_true, z_true, t_true, theta_true], model="angle", n_samples=1000
)
dataset_angle = pd.DataFrame(
np.column_stack([obs_angle["rts"][:, 0], obs_angle["choices"][:, 0]]),
columns=["rt", "response"],
)
We pass a single additional argument to our HSSM class and set model='angle'.
model_angle = hssm.HSSM(data=dataset_angle, model="angle")
model_angle
Hierarchical Sequential Sampling Model
Model: angle
Response variable: rt,response
Likelihood: approx_differentiable
Observations: 1000
Parameters:

v:
    Prior: Uniform(lower: -3.0, upper: 3.0)
    Explicit bounds: (-3.0, 3.0)
a:
    Prior: Uniform(lower: 0.30000001192092896, upper: 3.0)
    Explicit bounds: (0.3, 3.0)
z:
    Prior: Uniform(lower: 0.10000000149011612, upper: 0.8999999761581421)
    Explicit bounds: (0.1, 0.9)
t:
    Prior: Uniform(lower: 0.0010000000474974513, upper: 2.0)
    Explicit bounds: (0.001, 2.0)
theta:
    Prior: Uniform(lower: -0.10000000149011612, upper: 1.2999999523162842)
    Explicit bounds: (-0.1, 1.3)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
Let's check the type of likelihood that is used under the hood ...
model_angle.loglik_kind
'approx_differentiable'
Ok, so here we rely on a likelihood of the approx_differentiable kind.
As discussed, with the initial set of pre-supplied likelihoods, this implies that we are using a LAN in the background.
from jax.config import config
config.update("jax_enable_x64", False)
infer_data_angle = model_angle.sample(
sampler="nuts_numpyro",
chains=2,
cores=2,
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
)
Compiling... Compilation time = 0:00:00.656823 Sampling...
Sampling time = 0:00:16.655899 Transforming variables... Transformation time = 0:00:00.074300
az.plot_trace(model_angle.traces)
plt.tight_layout()
Choosing Priors¶
HSSM allows you to specify priors quite freely. If you used HDDM previously, you may feel relieved to read that your hands are now untied!
With HSSM we have multiple routes to priors. But let's first consider a special case:
Fixing a parameter to a given value¶
Assume that, instead of fitting all parameters of the DDM, we want to fit only the v (drift) parameter, setting all other parameters to fixed scalar values.
HSSM makes this extremely easy!
ddm_model_only_v = hssm.HSSM(data=dataset, model="ddm", a=1.5, t=0.2, z=0.5)
ddm_model_only_v.graph()
ddm_model_only_v.sample(
sampler="nuts_numpyro",
chains=2,
cores=2,
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
)
Compiling... Compilation time = 0:00:04.302885 Sampling...
Sampling time = 0:00:01.143003 Transforming variables... Transformation time = 0:00:00.003312
posterior:
<xarray.Dataset> Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 Data variables: v (chain, draw) float32 0.5998 0.5594 0.5635 ... 0.5752 0.56 0.5864 Attributes: created_at: 2023-09-05T18:26:45.837697 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
sample_stats:
<xarray.Dataset> Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: acceptance_rate (chain, draw) float32 0.9902 0.8726 1.0 ... 0.9331 0.9917 step_size (chain, draw) float32 0.8014 0.8014 ... 0.8032 0.8032 diverging (chain, draw) bool False False False ... False False False energy (chain, draw) float32 972.3 973.3 972.6 ... 972.6 972.6 n_steps (chain, draw) int32 3 3 1 7 3 1 1 3 1 ... 3 1 3 7 1 7 1 1 3 tree_depth (chain, draw) int64 2 2 1 3 2 1 1 2 1 ... 2 1 2 3 1 3 1 1 2 lp (chain, draw) float32 972.1 972.6 972.5 ... 972.6 972.1 Attributes: created_at: 2023-09-05T18:26:45.839185 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
observed_data:
<xarray.Dataset> Dimensions: (rt,response_obs: 500, rt,response_extra_dim_0: 2) Coordinates: * rt,response_obs (rt,response_obs) int64 0 1 2 3 ... 496 497 498 499 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 0 1 Data variables: rt,response (rt,response_obs, rt,response_extra_dim_0) float32 ... Attributes: created_at: 2023-09-05T18:26:45.840376 arviz_version: 0.14.0 inference_library: numpyro inference_library_version: 0.12.1 sampling_time: 1.143003 modeling_interface: bambi modeling_interface_version: 0.12.0
az.plot_trace(ddm_model_only_v.traces)
array([[<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>]], dtype=object)
Named priors¶
We can choose any PyMC Distribution to specify a prior for a given parameter.
Even better, if natural parameter bounds are provided, HSSM automatically truncates the prior distribution so that it respects these bounds.
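For example (a sketch we don't run here), a Normal prior placed on the t parameter, whose natural bounds are $(0, \infty)$, would be truncated at $0$ for us automatically:

model_t_normal = hssm.HSSM(
    data=dataset,
    include=[
        {
            "name": "t",
            # a Normal prior with mass below 0; HSSM truncates it at the lower bound
            "prior": {"name": "Normal", "mu": 0.3, "sigma": 0.2},
        }
    ],
)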
Below is an example in which we specify a Normal prior on the v parameter of the DDM.
We choose a ridiculously low $\sigma$ value to illustrate its regularizing effect on the parameter (just so we see a difference and you are convinced that something changed).
model_normal = hssm.HSSM(
data=dataset,
include=[
{
"name": "v",
"prior": {"name": "Normal", "mu": 0, "sigma": 0.01},
}
],
)
model_normal
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: analytical
Observations: 500
Parameters:

v:
    Prior: Normal(mu: 0.0, sigma: 0.009999999776482582)
    Explicit bounds: (-inf, inf)
a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
t:
    Prior: HalfNormal(sigma: 2.0, initval: 0.10000000149011612)
    Explicit bounds: (0.0, inf)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
infer_data_normal = model_normal.sample(
sampler="nuts_numpyro",
chains=2,
cores=2,
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
)
Compiling... Compilation time = 0:00:01.595330 Sampling...
Sampling time = 0:00:06.877557 Transforming variables... Transformation time = 0:00:00.009919
az.plot_trace(model_normal.traces)
array([[<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>], [<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>], [<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>]], dtype=object)
Observe how we reused our previous dataset with underlying parameters:
- v = 0.5
- a = 1.5
- z = 0.5
- t = 0.2

In contrast to our previous sampling round, in which we used the (much less informative) default priors, here the v estimate is shrunk severely toward $0$, and the t and z parameter estimates are heavily biased to make up for this distortion.
HSSM Model with Regression¶
Crucial to the scope of HSSM is the ability to link parameters with trial-by-trial covariates via (hierarchical, but more on this later) general linear models.
In this section we explore how HSSM deals with these models. No big surprise here... it's simple!
Case 1: One parameter is a Regression Target¶
Simulating Data¶
Let's first simulate some data where the trial-by-trial values of the v parameter in our model are driven by a simple linear regression.
The regression is driven by two (random) covariates x and y, with coefficients of $0.8$ and $0.3$ respectively, which are also simulated below.
We set the intercept to $0.3$.
The rest of the parameters are fixed to single values as before.
# Set up trial by trial parameters
intercept = 0.3
x = np.random.uniform(-1, 1, size=1000)
y = np.random.uniform(-1, 1, size=1000)
v = intercept + (0.8 * x) + (0.3 * y)
true_values = np.column_stack(
[v, np.repeat([[1.5, 0.5, 0.5, 0.0]], axis=0, repeats=1000)]
)
# Get model simulations
obs_ddm_reg_v = simulator(true_values, model="ddm", n_samples=1)
dataset_reg_v = pd.DataFrame(
{
"rt": obs_ddm_reg_v["rts"].flatten(),
"response": obs_ddm_reg_v["choices"].flatten(),
"x": x,
"y": y,
}
)
dataset_reg_v
rt | response | x | y | |
---|---|---|---|---|
0 | 2.018015 | -1 | -0.929824 | -0.647178 |
1 | 1.049996 | -1 | -0.916070 | 0.730554 |
2 | 3.524963 | -1 | -0.621662 | -0.147269 |
3 | 2.976003 | 1 | -0.782370 | 0.524242 |
4 | 2.474036 | 1 | -0.351246 | 0.874034 |
... | ... | ... | ... | ... |
995 | 1.104996 | 1 | 0.877501 | -0.815849 |
996 | 1.363992 | -1 | -0.338860 | -0.166251 |
997 | 2.305028 | -1 | -0.887428 | 0.450602 |
998 | 1.064996 | -1 | -0.833160 | 0.710574 |
999 | 2.746020 | 1 | 0.398379 | -0.818633 |
1000 rows × 4 columns
We now create the HSSM model.
Notice how we set the include argument. The include argument expects a list of dictionaries, one dictionary for each parameter to be specified via a regression model.
Four keys are expected to be set:
- The name of the parameter,
- potentially a prior for each of the regression-level parameters ($\beta$'s),
- the regression formula,
- a link function.
The regression formula follows the syntax of the formulae python package (as used by the Bambi package for building Bayesian hierarchical regression models).
Bambi forms the main model-construction backend of HSSM.
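For orientation, here are a few common formula patterns that formulae/Bambi understand (the covariate and grouping names are hypothetical placeholders that would need to exist in your dataframe):

"v ~ 1 + x + y"  # intercept plus two slopes
"v ~ 1 + x * y"  # main effects of x and y plus their interaction
"v ~ 0 + C(condition)"  # one coefficient per condition level, no global intercept
"v ~ 1 + x + (1|subject)"  # subject-wise random intercepts (more on this later)

Now, our concrete regression model: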
model_reg_v = hssm.HSSM(
data=dataset_reg_v,
include=[
{
"name": "v",
"prior": {
"Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
"x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
"y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
},
"formula": "v ~ 1 + x + y",
"link": "identity",
}
],
)
model_reg_v
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: analytical
Observations: 1000
Parameters:

v:
    Formula: v ~ 1 + x + y
    Priors:
        v_Intercept ~ Uniform(lower: -3.0, upper: 3.0)
        v_x ~ Uniform(lower: -1.0, upper: 1.0)
        v_y ~ Uniform(lower: -1.0, upper: 1.0)
    Link: identity
    Explicit bounds: (-inf, inf)
a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
t:
    Prior: HalfNormal(sigma: 2.0, initval: 0.10000000149011612)
    Explicit bounds: (0.0, inf)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
Notice how v is now set as a regression.
model_reg_v.graph()
infer_data_reg_v = model_reg_v.sample(
sampler="nuts_numpyro", chains=1, cores=1, draws=500, tune=500
)
Compiling... Compilation time = 0:00:01.281230 Sampling...
sample: 100%|█████████████████████████| 1000/1000 [00:12<00:00, 80.68it/s, 15 steps of size 2.91e-01. acc. prob=0.95]
Sampling time = 0:00:14.317063 Transforming variables... Transformation time = 0:00:00.103803
az.plot_trace(model_reg_v.traces)
plt.tight_layout()
# Looks like parameter recovery was successful
az.summary(model_reg_v.traces)
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
z | 0.505 | 0.012 | 0.481 | 0.524 | 0.001 | 0.000 | 348.0 | 286.0 | NaN |
t | 0.509 | 0.020 | 0.472 | 0.546 | 0.001 | 0.001 | 320.0 | 187.0 | NaN |
a | 1.490 | 0.030 | 1.436 | 1.551 | 0.002 | 0.001 | 231.0 | 236.0 | NaN |
v_Intercept | 0.355 | 0.034 | 0.289 | 0.409 | 0.002 | 0.001 | 275.0 | 424.0 | NaN |
v_x | 0.912 | 0.049 | 0.833 | 1.000 | 0.004 | 0.003 | 129.0 | 73.0 | NaN |
v_y | 0.334 | 0.047 | 0.250 | 0.420 | 0.003 | 0.002 | 362.0 | 272.0 | NaN |
Case 2: One parameter is a Regression (LAN)¶
We can do the same thing with the angle model.
Note:
Our dataset was generated from the basic DDM here; since the DDM assumes stable bounds, we expect the theta (angle of linear collapse) parameter to be recovered close to $0$.
model_reg_v_angle = hssm.HSSM(
data=dataset_reg_v,
model="angle",
include=[
{
"name": "v",
"prior": {
"Intercept": {
"name": "Uniform",
"lower": -3.0,
"upper": 3.0,
# "initval": 0 # optional --> set the initial value of the parameter (to e.g. avoid boundary violations at the intial sampling step)
},
"x": {
"name": "Uniform",
"lower": -1.0,
"upper": 1.0,
},
"y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
},
"formula": "v ~ 1 + x + y",
"link": "identity",
}
],
)
model_reg_v_angle.graph()
trace_reg_v_angle = model_reg_v_angle.sample(
sampler="nuts_numpyro", chains=1, cores=1, draws=500, tune=500
)
Compiling... Compilation time = 0:00:00.633356 Sampling...
sample: 100%|██████████████████████████| 1000/1000 [00:16<00:00, 60.77it/s, 7 steps of size 3.89e-01. acc. prob=0.85]
Sampling time = 0:00:17.704611 Transforming variables... Transformation time = 0:00:00.064387
az.plot_trace(model_reg_v_angle.traces)
plt.tight_layout()
Great! theta is recovered correctly; on top of that, we have good recovery for all other parameters!
Case 3: Multiple Parameters are Regression Targets (LAN)¶
Let's get a bit more ambitious. We may, for example, want to try a regression on a few of our basic model parameters at once. Below we show an example where we model both the a and the v parameters with a regression.
NOTE:
In our dataset of this section, only v is actually driven by a trial-by-trial regression, so we expect the regression coefficients for a to hover around $0$ in our posterior.
# Instantiate our hssm model
hssm_reg_v_a_angle = hssm.HSSM(
data=dataset_reg_v,
model="angle",
include=[
{
"name": "v",
"prior": {
"Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
"x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
"y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
},
"formula": "v ~ 1 + x + y",
},
{
"name": "a",
"prior": {
"Intercept": {"name": "Uniform", "lower": 0.5, "upper": 3.0},
"x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
"y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
},
"formula": "a ~ 1 + x + y",
},
],
)
hssm_reg_v_a_angle
Hierarchical Sequential Sampling Model
Model: angle
Response variable: rt,response
Likelihood: approx_differentiable
Observations: 1000
Parameters:

v:
    Formula: v ~ 1 + x + y
    Priors:
        v_Intercept ~ Uniform(lower: -3.0, upper: 3.0)
        v_x ~ Uniform(lower: -1.0, upper: 1.0)
        v_y ~ Uniform(lower: -1.0, upper: 1.0)
    Link: identity
    Explicit bounds: (-3.0, 3.0)
a:
    Formula: a ~ 1 + x + y
    Priors:
        a_Intercept ~ Uniform(lower: 0.5, upper: 3.0)
        a_x ~ Uniform(lower: -1.0, upper: 1.0)
        a_y ~ Uniform(lower: -1.0, upper: 1.0)
    Link: identity
    Explicit bounds: (0.3, 3.0)
z:
    Prior: Uniform(lower: 0.10000000149011612, upper: 0.8999999761581421)
    Explicit bounds: (0.1, 0.9)
t:
    Prior: Uniform(lower: 0.0010000000474974513, upper: 2.0)
    Explicit bounds: (0.001, 2.0)
theta:
    Prior: Uniform(lower: -0.10000000149011612, upper: 1.2999999523162842)
    Explicit bounds: (-0.1, 1.3)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
hssm_reg_v_a_angle.graph()
infer_data_reg_v_a = hssm_reg_v_a_angle.sample(
sampler="nuts_numpyro", chains=2, cores=1, draws=1000, tune=1000
)
Compiling... Compilation time = 0:00:01.846319 Sampling...
Sampling time = 0:00:33.882526 Transforming variables... Transformation time = 0:00:00.282471
az.summary(infer_data_reg_v_a, var_names=["~rt,response_a"])
/Users/yxu150/HSSM/.venv/lib/python3.9/site-packages/arviz/utils.py:134: UserWarning: Items starting with ~: ['rt,response_a'] have not been found and will be ignored warnings.warn(
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
z | 0.514 | 0.013 | 0.490 | 0.538 | 0.000 | 0.000 | 1611.0 | 1316.0 | 1.0 |
t | 0.524 | 0.023 | 0.483 | 0.571 | 0.001 | 0.001 | 1069.0 | 1106.0 | 1.0 |
theta | -0.022 | 0.018 | -0.057 | 0.010 | 0.001 | 0.000 | 1128.0 | 972.0 | 1.0 |
v_Intercept | 0.355 | 0.033 | 0.291 | 0.413 | 0.001 | 0.001 | 1714.0 | 1479.0 | 1.0 |
v_x | 0.930 | 0.042 | 0.859 | 1.000 | 0.001 | 0.001 | 1040.0 | 668.0 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
a[995] | 1.452 | 0.065 | 1.331 | 1.572 | 0.002 | 0.001 | 1249.0 | 1133.0 | 1.0 |
a[996] | 1.411 | 0.043 | 1.336 | 1.499 | 0.001 | 0.001 | 949.0 | 993.0 | 1.0 |
a[997] | 1.415 | 0.055 | 1.311 | 1.518 | 0.002 | 0.001 | 1266.0 | 1075.0 | 1.0 |
a[998] | 1.437 | 0.058 | 1.319 | 1.544 | 0.002 | 0.001 | 1223.0 | 1042.0 | 1.0 |
a[999] | 1.418 | 0.054 | 1.319 | 1.519 | 0.002 | 0.001 | 1214.0 | 1232.0 | 1.0 |
1009 rows × 9 columns
We successfully recover our regression betas for a (hovering around $0$, as expected)! Moreover, there are no warning signs concerning our chains.
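The trial-wise deterministic a values crowd the table above; to inspect just the regression coefficients on a, we can subset the summary by variable name:

az.summary(infer_data_reg_v_a, var_names=["a_Intercept", "a_x", "a_y"])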
Hierarchical Inference¶
Let's try to fit a hierarchical model now. We will simulate a dataset with $15$ participants and $200$ observations / trials per participant.
We define a group mean mean_v and a group standard deviation sd_v for the intercept parameter of the regression on v; each subject's intercept is sampled from the corresponding normal distribution.
Simulate Data¶
# Make some hierarchical data
n_subjects = 15 # number of subjects
n_trials = 200 # number of trials per subject
sd_v = 0.5 # sd for v-intercept
mean_v = 0.5 # mean for v-intercept
data_list = []
for i in range(n_subjects):
# Make parameters for subject i
intercept = np.random.normal(mean_v, sd_v, size=1)
x = np.random.uniform(-1, 1, size=n_trials)
y = np.random.uniform(-1, 1, size=n_trials)
v = intercept + (0.8 * x) + (0.3 * y)
true_values = np.column_stack(
[v, np.repeat([[1.5, 0.5, 0.5, 0.0]], axis=0, repeats=n_trials)]
)
# Simulate data
obs_ddm_reg_v = simulator(true_values, model="ddm", n_samples=1)
# Append simulated data to list
data_list.append(
pd.DataFrame(
{
"rt": obs_ddm_reg_v["rts"].flatten(),
"response": obs_ddm_reg_v["choices"].flatten(),
"x": x,
"y": y,
"subject": i,
}
)
)
# Make single dataframe out of subject-wise datasets
dataset_reg_v_hier = pd.concat(data_list)
dataset_reg_v_hier
rt | response | x | y | subject | |
---|---|---|---|---|---|
0 | 3.245983 | -1 | -0.259151 | 0.463476 | 0 |
1 | 3.027999 | -1 | -0.424408 | -0.949052 | 0 |
2 | 1.856007 | -1 | -0.981422 | -0.816294 | 0 |
3 | 1.260994 | 1 | -0.009248 | -0.291613 | 0 |
4 | 1.498991 | 1 | 0.569295 | -0.034869 | 0 |
... | ... | ... | ... | ... | ... |
195 | 1.300993 | 1 | -0.285688 | 0.307017 | 14 |
196 | 1.330993 | 1 | -0.213189 | -0.790106 | 14 |
197 | 0.946998 | 1 | 0.546433 | -0.395383 | 14 |
198 | 1.682999 | 1 | 0.120770 | 0.394799 | 14 |
199 | 4.434897 | 1 | -0.022957 | 0.264478 | 14 |
3000 rows × 5 columns
We can now define our HSSM model.
We specify the regression as v ~ 1 + (1|subject) + x + y.
(1|subject) tells the model to create a subject-wise offset for the intercept parameter. The rest of the regression $\beta$'s are fit globally.
As an R user you may recognize this syntax from the lmer package.
Our Bambi backend is essentially a Bayesian version of lmer, quite like the brms package in R, which operates on top of Stan.
As a previous HDDM user, you may recognize that proper mixed-effects models are now viable!
You should be able to handle between- and within-subject effects naturally now!
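As an aside, random slopes are expressed the same way. A hypothetical variant (not fit here) with subject-wise slopes on x would simply change the formula in the include dictionary:

{
    "name": "v",
    "formula": "v ~ 1 + x + y + (1 + x | subject)",  # random intercepts plus random slopes on x
    "link": "identity",
}

Back to our random-intercept model: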
model_reg_v_angle_hier = hssm.HSSM(
data=dataset_reg_v_hier,
model="angle",
include=[
{
"name": "v",
"prior": {
"Intercept": {
"name": "Uniform",
"lower": -3.0,
"upper": 3.0,
"initval": 0.0,
},
"x": {"name": "Uniform", "lower": -1.0, "upper": 1.0, "initval": 0.0},
"y": {"name": "Uniform", "lower": -1.0, "upper": 1.0, "initval": 0.0},
},
"formula": "v ~ 1 + (1|subject) + x + y",
"link": "identity",
}
],
)
model_reg_v_angle_hier.graph()
from jax.config import config
config.update("jax_enable_x64", False)
model_reg_v_angle_hier.sample(
sampler="nuts_numpyro", chains=2, cores=1, draws=1000, tune=1000
)
Compiling... Compilation time = 0:00:04.343406 Sampling...
Sampling time = 0:03:57.486710 Transforming variables... Transformation time = 0:00:00.086975
posterior:
<xarray.Dataset> Dimensions: (chain: 2, draw: 1000, v_1|subject__factor_dim: 15) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999 * v_1|subject__factor_dim (v_1|subject__factor_dim) <U2 '0' '1' ... '13' '14' Data variables: z (chain, draw) float32 0.4957 0.4993 ... 0.4967 a (chain, draw) float32 1.434 1.419 ... 1.428 1.412 t (chain, draw) float32 0.5577 0.5577 ... 0.5574 theta (chain, draw) float32 -0.006189 ... -0.01521 v_Intercept (chain, draw) float64 1.013 0.9594 ... 0.7313 v_x (chain, draw) float32 0.8168 0.816 ... 0.8637 v_y (chain, draw) float32 0.2879 0.2839 ... 0.2518 v_1|subject_sigma (chain, draw) float32 0.8029 0.7611 ... 0.836 0.833 v_1|subject (chain, draw, v_1|subject__factor_dim) float32 -... Attributes: created_at: 2023-09-05T18:32:12.130975 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
sample_stats:
<xarray.Dataset> Dimensions: (chain: 2, draw: 1000) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float32 0.9273 0.9963 ... 0.8765 0.9925 step_size (chain, draw) float32 0.06281 0.06281 ... 0.06693 0.06693 diverging (chain, draw) bool False False False ... False False False energy (chain, draw) float32 4.589e+03 4.583e+03 ... 4.581e+03 n_steps (chain, draw) int32 143 15 95 31 31 31 ... 31 31 31 15 63 tree_depth (chain, draw) int64 8 4 7 5 5 5 7 6 7 ... 6 6 6 6 5 5 5 4 6 lp (chain, draw) float32 4.576e+03 4.578e+03 ... 4.575e+03 Attributes: created_at: 2023-09-05T18:32:12.133855 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
observed_data:
<xarray.Dataset> Dimensions: (rt,response_obs: 3000, rt,response_extra_dim_0: 2) Coordinates: * rt,response_obs (rt,response_obs) int64 0 1 2 3 ... 2997 2998 2999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 0 1 Data variables: rt,response (rt,response_obs, rt,response_extra_dim_0) float32 ... Attributes: created_at: 2023-09-05T18:32:12.134925 arviz_version: 0.14.0 inference_library: numpyro inference_library_version: 0.12.1 sampling_time: 237.48671 modeling_interface: bambi modeling_interface_version: 0.12.0
Let's look at the posteriors!
az.plot_forest(model_reg_v_angle_hier.traces)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
Model Comparison¶
Fitting single models is all well and good. We are, however, often interested in comparing how well a few different models account for the same data. Again, through ArviZ, we have all we need for modern Bayesian model comparison. We will keep it simple here, just to illustrate the basic idea. The following scenario is explored:
First we generate data from a ddm model with fixed parameters; specifically, we set the a parameter to $1.5$.
We then define two HSSM models:
- a model which allows fitting all but the a parameter, which is fixed to $1.0$ (wrong),
- a model which allows fitting all but the a parameter, which is fixed to $1.5$ (correct).
We then use ArviZ's compare() function to perform model comparison via elpd_loo.
Data Simulation¶
# Specify parameter values
# Note 'a' is set to 1.5
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.2]
# Simulate data
sim_out = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=500)
# Turn data into a pandas dataframe
dataset_model_comp = pd.DataFrame(
np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
columns=["rt", "response"],
)
print(dataset_model_comp)
          rt  response
0    1.357998       1.0
1    4.972836       1.0
2    1.303996       1.0
3    6.726709       1.0
4    1.945026       1.0
..        ...       ...
495  1.899023       1.0
496  3.701929       1.0
497  2.558012       1.0
498  4.930840       1.0
499  2.275032       1.0

[500 rows x 2 columns]
Defining the Models¶
# 'wrong' model
model_model_comp_1 = hssm.HSSM(
data=dataset_model_comp,
model="angle",
a=1.0,
)
# 'correct' model
model_model_comp_2 = hssm.HSSM(
data=dataset_model_comp,
model="angle",
a=1.5,
)
infer_data_model_comp_1 = model_model_comp_1.sample(
sampler="nuts_numpyro",
cores=1,
chains=2,
draws=1000,
tune=1000,
idata_kwargs=dict(
log_likelihood=True
), # model comparison metrics usually need this!
)
Compiling... Compilation time = 0:00:01.148847 Sampling...
Sampling time = 0:00:10.293478 Transforming variables... Transformation time = 0:00:00.037679 Computing Log Likelihood... Log Likelihood time = 0:00:00.866599
infer_data_model_comp_2 = model_model_comp_2.sample(
sampler="nuts_numpyro",
cores=1,
chains=2,
draws=1000,
tune=1000,
idata_kwargs=dict(
log_likelihood=True
), # model comparison metrics usually need this!
)
Compiling... Compilation time = 0:00:00.479080 Sampling...
Sampling time = 0:00:11.062675 Transforming variables... Transformation time = 0:00:00.038244 Computing Log Likelihood... Log Likelihood time = 0:00:00.415868
Compare¶
compare_data = az.compare(
{
"a_fixed_1(wrong)": model_model_comp_1.traces,
"a_fixed_1.5(correct)": model_model_comp_2.traces,
}
)
compare_data
rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale | |
---|---|---|---|---|---|---|---|---|---|
a_fixed_1.5(correct) | 0 | -1008.046587 | 3.286643 | 0.000000 | 1.000000e+00 | 23.470703 | 0.00000 | False | log |
a_fixed_1(wrong) | 1 | -1089.589708 | 2.944953 | 81.543121 | 1.692939e-10 | 28.681347 | 9.86142 | False | log |
Notice how the posterior weight on the correct model is close to (or equal to) $1$ here.
In other words, model comparison points us to the correct model with a very high degree of certainty!
We can also use the .plot_compare() function to illustrate the model comparison visually.
az.plot_compare(compare_data)
<Axes: title={'center': 'Model comparison\nhigher is better'}, xlabel='elpd_loo (log)', ylabel='ranked models'>
But what 'is' a Model in HSSM really?¶
Ok, we have seen a few examples of HSSM models at this point. Add a model via a string, maybe toy a bit with the priors, set regression functions for a given parameter, turn it hierarchical... In this section we peel back the onion a bit more to understand better what is going on under the hood.
After all, we want to encourage you to contribute models to the package yourself.
Let's first take a closer look at the model_config dictionaries that define model properties for us. Again, let's start with the DDM.
hssm.config.default_model_config["ddm"].keys()
dict_keys(['list_params', 'description', 'likelihoods'])
The 'likelihoods' entry of this dictionary has two high-level keys:
- analytical
- approx_differentiable

These refer to two different types of likelihood that we have available for the ddm model.
The analytical likelihood goes back to a standard algorithm designed by Navarro & Fuss. This is the likelihood which was used in the HDDM python toolbox.
Let's expand the dictionary contents:
hssm.config.default_model_config["ddm"]["likelihoods"]["analytical"]
{'loglik': <function hssm.likelihoods.analytical.logp_ddm(data: 'np.ndarray', v: 'float', a: 'float', z: 'float', t: 'float', err: 'float' = 1e-15, k_terms: 'int' = 20, epsilon: 'float' = 1e-15) -> 'np.ndarray'>, 'backend': None, 'bounds': {'v': (-inf, inf), 'a': (0.0, inf), 'z': (0.0, 1.0), 't': (0.0, inf)}, 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0, 'initval': 0.1}}, 'extra_fields': None}
Three keys in this dictionary are of particular interest, of which two are essential:
- the loglik field, which points to the likelihood function,
- the bounds field, which specifies bounds on a subset of the model parameters,
- the default_priors field, which specifies parameter-wise priors.
If you provide bounds for a parameter, but no default_priors, a Uniform prior that respects the specified bounds will be applied.
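Since this is all just a nested Python dictionary, one route to customization is to copy the default configuration and modify what you need. A sketch (we assume here that the HSSM constructor accepts a model_config override; check the API reference for the exact expected shape before relying on this):

import copy

# Copy the default ddm config and widen the bounds on `v` (illustrative values)
my_ddm_config = copy.deepcopy(hssm.config.default_model_config["ddm"])
my_ddm_config["likelihoods"]["analytical"]["bounds"]["v"] = (-6.0, 6.0)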
Next, let's look at the approx_differentiable part.
The likelihood in this part is based on a LAN, which was available in HDDM through the LAN extension.
hssm.config.default_model_config["ddm"]["likelihoods"]["approx_differentiable"]
{'loglik': 'ddm.onnx', 'backend': 'jax', 'default_priors': {}, 'bounds': {'v': (-3.0, 3.0), 'a': (0.3, 2.5), 'z': (0.0, 1.0), 't': (0.0, 2.0)}, 'extra_fields': None}
We see that the loglik field is now a string that points to an .onnx file.
ONNX is a meta-framework for neural network specification that allows translation between deep learning frameworks. This is the preferred format for the neural networks we store in our model reservoir on HuggingFace.
Moreover, notice that we now have a backend field. We allow for two primary backends for approx_differentiable likelihoods:
- pytensor
- jax

The jax backend assumes that your likelihood is described as a JAX function, while the pytensor backend assumes that your likelihood is described as a pytensor function. Ok, not that surprising...
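To make the distinction concrete, here is a toy example of what a likelihood destined for the jax backend looks like in spirit: a pure, JIT-compilable function mapping data and parameters to trial-wise log-likelihoods. (The Normal density is just a stand-in for illustration; it is not a real SSM likelihood.)

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

@jax.jit
def toy_jax_loglik(data, mu, sigma):
    # data: (n, 2) array of [rt, response]; toy Normal density on rt only
    rt = jnp.asarray(data)[:, 0]
    return norm.logpdf(rt, loc=mu, scale=sigma)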
We won't dwell on this here; the key idea is to provide users with a large degree of flexibility in describing their likelihood functions, and moreover to allow targeted optimization toward the MCMC sampler types that PyMC gives us access to.
You can find a dedicated tutorial in the documentation, which describes the different likelihoods in much more detail.
Instead, let's take a quick look at how these newfound insights can be used for custom model definition.
hssm_alternative_model = hssm.HSSM(
data=dataset,
model="ddm",
loglik_kind="approx_differentiable",
)
hssm_alternative_model.loglik_kind
'approx_differentiable'
In this case we actually built the model class with an approx_differentiable LAN likelihood, instead of the default analytical likelihood we used in the beginning of the tutorial. The assumed generative model remains the ddm, however!
hssm_alternative_model.sample(
sampler="nuts_numpyro",
cores=1,
chains=2,
draws=1000,
tune=1000,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
)
Compiling... Compilation time = 0:00:00.495657 Sampling...
Sampling time = 0:00:14.842033 Transforming variables... Transformation time = 0:00:00.032069
posterior:
<xarray.Dataset> Dimensions: (chain: 2, draw: 1000) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999 Data variables: z (chain, draw) float32 0.5007 0.4872 0.4765 ... 0.4707 0.4896 0.4765 t (chain, draw) float32 0.2754 0.2388 0.1881 ... 0.2498 0.2388 0.2396 a (chain, draw) float32 1.48 1.497 1.525 1.462 ... 1.462 1.563 1.478 v (chain, draw) float32 0.6245 0.6462 0.6063 ... 0.6423 0.6408 0.6892 Attributes: created_at: 2023-09-05T18:33:51.457891 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
-
<xarray.Dataset> Dimensions: (chain: 2, draw: 1000) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float32 0.9945 0.9608 ... 0.9136 0.9959 step_size (chain, draw) float32 0.3576 0.3576 ... 0.3323 0.3323 diverging (chain, draw) bool False False False ... False False False energy (chain, draw) float32 973.4 972.3 976.0 ... 973.2 973.0 n_steps (chain, draw) int32 11 7 7 7 7 7 7 3 ... 15 15 7 3 15 7 7 tree_depth (chain, draw) int64 4 3 3 3 3 3 3 2 3 ... 4 4 4 4 3 2 4 3 3 lp (chain, draw) float32 971.3 970.8 973.0 ... 971.6 971.6 Attributes: created_at: 2023-09-05T18:33:51.459723 arviz_version: 0.14.0 modeling_interface: bambi modeling_interface_version: 0.12.0
-
<xarray.Dataset> Dimensions: (rt,response_obs: 500, rt,response_extra_dim_0: 2) Coordinates: * rt,response_obs (rt,response_obs) int64 0 1 2 3 ... 496 497 498 499 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 0 1 Data variables: rt,response (rt,response_obs, rt,response_extra_dim_0) float32 ... Attributes: created_at: 2023-09-05T18:33:51.460779 arviz_version: 0.14.0 inference_library: numpyro inference_library_version: 0.12.1 sampling_time: 14.842033 modeling_interface: bambi modeling_interface_version: 0.12.0
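As an aside: the float64 warnings above stem from JAX defaulting to 32-bit floats. Per the JAX documentation linked in the warning, a likely fix (hedged; the exact effect depends on your JAX version) is to enable 64-bit mode before any JAX work is done:

import jax

# Enable 64-bit floats in JAX; must run before any JAX computation
jax.config.update("jax_enable_x64", True)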
az.plot_forest(hssm_alternative_model.traces)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
We can take this further and specify a completely custom likelihood. See the dedicated tutorial for more examples!
We will see one specific example below to illustrate another type of likelihood function we have available for model building in HSSM, the Blackbox likelihood.
'Blackbox' Likelihoods¶
What is a Blackbox Likelihood Function?¶
A Blackbox Likelihood Function is essentially any Python `callable` (function) that provides trial-by-trial likelihoods for your model of interest. What kind of computation is performed inside this Python function is completely arbitrary.
E.g. you could build a function that performs forward simulations from your model, constructs a kernel-density estimate of the resulting likelihood function, and evaluates your datapoints on this ad-hoc generated approximate likelihood.
What I just described was once a state-of-the-art method of performing simulation-based inference on Sequential Sampling Models, a precursor to LANs if you will (a toy sketch of the idea follows below).
We will do something simpler to keep it short and sweet, but really... the possibilities are endless!
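To make the KDE idea concrete, here is a minimal, hedged sketch, a toy illustration rather than the method we use next. It assumes `scipy` is available and reuses the `simulator` function from earlier:

import numpy as np
from scipy.stats import gaussian_kde
from ssms.basic_simulators import simulator


def kde_loglik(data, v, a, z, t, n_sim=10_000):
    # Forward-simulate from the DDM at the proposed parameter values
    sim = simulator([v, a, z, t], model="ddm", n_samples=n_sim)
    # Collapse (rt, choice) pairs into signed RTs (negative for choice -1)
    signed_rt_sim = sim["rts"][:, 0] * sim["choices"][:, 0]
    # Fit a kernel-density estimate to the simulated signed RTs ...
    kde = gaussian_kde(signed_rt_sim)
    # ... and evaluate the observed signed RTs under it, flooring the
    # density to avoid log(0) far outside the simulated range
    signed_rt_obs = data[:, 0] * data[:, 1]
    return np.log(np.maximum(kde(signed_rt_obs), 1e-10))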
Simulating a simple dataset from the DDM¶
As always, let's begin by generating a simple dataset.
# Set parameters
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.5]
# Generate observations (rts, choices)
obs_ddm = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=1000)
dataset = pd.DataFrame(
np.column_stack([obs_ddm["rts"][:, 0], obs_ddm["choices"][:, 0]]),
columns=["rt", "response"],
)
dataset
|     | rt       | response |
|-----|----------|----------|
| 0   | 2.590031 | 1.0      |
| 1   | 2.400033 | 1.0      |
| 2   | 1.475991 | -1.0     |
| 3   | 2.188023 | 1.0      |
| 4   | 4.860867 | 1.0      |
| ... | ...      | ...      |
| 995 | 2.569032 | 1.0      |
| 996 | 1.379992 | 1.0      |
| 997 | 1.103996 | 1.0      |
| 998 | 3.428970 | 1.0      |
| 999 | 1.011997 | -1.0     |

1000 rows × 2 columns
Define the likelihood¶
Now the fun part... we simply define a Python function `my_blackbox_loglik` which takes in our `data` as well as a bunch of model parameters (in our case the familiar `v`, `a`, `z`, `t` from the DDM).
The function then performs some arbitrary computation inside (in our case, we pass the data and parameters to the DDM log-likelihood from our predecessor package HDDM).
import hddm_wfpt  # DDM log-likelihood from the hddm-wfpt package installed earlier


def my_blackbox_loglik(data, v, a, z, t, err=1e-8):
    # Collapse (rt, response) columns into signed RTs
    data = data[:, 0] * data[:, 1]
    # Our function expects inputs as float64, but they are not guaranteed to
    # come in as such --> we type convert.
    # The extra zeros are the inter-trial variability parameters
    # (sv, sz, st), which we fix at 0 here; the final 1 requests
    # log-likelihoods rather than likelihoods.
    return hddm_wfpt.wfpt.pdf_array(
        np.float64(data),
        np.float64(v),
        0,  # sv
        np.float64(2 * a),  # 2 * a converts between boundary conventions
        np.float64(z),
        0,  # sz
        np.float64(t),
        0,  # st
        err,
        1,  # return log-likelihoods
    )
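As a quick, hedged sanity check, we can call the function directly on our data at the true parameter values; it should return one log-likelihood per trial:

# Evaluate the blackbox likelihood once, outside of any sampler
trial_logliks = my_blackbox_loglik(dataset.values, v_true, a_true, z_true, t_true)
print(trial_logliks.shape)  # expected: (1000,)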
Define HSSM class with our Blackbox Likelihood¶
We can now define our HSSM model class as usual, except that we pass our `my_blackbox_loglik()` function to the `loglik` argument and set `loglik_kind="blackbox"`.
The rest of the model config is as usual. Here we can reuse our `ddm` model config and simply specify bounds on the parameters (e.g. your Blackbox Likelihood might be trustworthy only on a restricted parameter space).
model = hssm.HSSM(
data=dataset,
model="ddm",
loglik=my_blackbox_loglik,
loglik_kind="blackbox",
model_config={
"bounds": {
"v": (-10.0, 10.0),
"a": (0.5, 5.0),
"z": (0.0, 1.0),
}
},
t=bmb.Prior("Uniform", lower=0.0, upper=2.0, initval=0.1),
)
sample = model.sample()
Multiprocess sampling (4 chains in 4 jobs) CompoundStep >Slice: [z] >Slice: [t] >Slice: [a] >Slice: [v]
/var/folders/9x/cjrfyjd9443d4_0wt9qw8fhh0000gq/T/ipykernel_8618/1138958324.py:5: RuntimeWarning: divide by zero encountered in log return hddm_wfpt.wfpt.pdf_array( (warning repeated once per chain)
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.
NOTE:
Since Blackbox likelihood functions are assumed not to be differentiable, our default sampler for such likelihood functions is a `Slice` sampler. HSSM allows you to choose any other suitable sampler from the PyMC package instead; a number of gradient-free samplers are available (see the sketch below).
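For instance, here is a minimal hedged sketch swapping in PyMC's `DEMetropolisZ` step. This assumes the HSSM model object exposes its underlying PyMC model as `model.pymc_model`; attribute names may differ across HSSM versions:

import pymc as pm

# Sample the underlying PyMC model with a different gradient-free
# step method instead of HSSM's default Slice sampler
with model.pymc_model:
    step = pm.DEMetropolisZ()
    alt_sample = pm.sample(step=step, draws=1000, tune=1000)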
Results¶
az.summary(sample)
|   | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| z | 0.479 | 0.014 | 0.453  | 0.505   | 0.000     | 0.000   | 1358.0   | 1594.0   | 1.0   |
| t | 0.506 | 0.022 | 0.466  | 0.547   | 0.001     | 0.000   | 1025.0   | 1906.0   | 1.0   |
| a | 1.489 | 0.029 | 1.433  | 1.542   | 0.001     | 0.001   | 1210.0   | 1768.0   | 1.0   |
| v | 0.587 | 0.036 | 0.518  | 0.650   | 0.001     | 0.001   | 1397.0   | 2182.0   | 1.0   |
az.plot_trace(sample)
array([[<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>], [<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>], [<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>]], dtype=object)
HSSM Random Variables in PyMC¶
We covered a lot of ground in this tutorial so far. You are now a sophisticated HSSM user.
It is therefore time to reveal a secret. We can actually peel back one more layer...
Instead of letting HSSM build the entire model for you, we can use HSSM to construct valid PyMC distributions and then build a custom PyMC model ourselves...
We will illustrate the simplest example below. It establishes a pattern that can be exploited for much more complicated modeling exercises, which importantly go far beyond what our basic HSSM class can facilitate for you!
See the dedicated tutorial in the documentation if you are interested.
Let's start by importing a few convenience functions:
from hssm.distribution_utils import (
make_distribution, # A general function for making Distribution classes
make_distribution_from_onnx, # Makes Distribution classes from onnx files
make_distribution_from_blackbox, # Makes Distribution classes from callables
)
# A pm.Distribution class that represents the top-level distribution for
# DDM models (the Wiener First-Passage Time distribution)
from hssm.likelihoods import logp_ddm_sdv, DDM
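Before using the ready-made `DDM` class, note that the helpers above also let you build such distribution classes yourself. As a hedged sketch (the exact signature of `make_distribution_from_blackbox` may differ across HSSM versions), wrapping our earlier blackbox likelihood could look like this:

# Hypothetical usage sketch: turn the blackbox likelihood from the
# previous section into a PyMC-compatible distribution class
MyBlackboxDDM = make_distribution_from_blackbox(
    rv="my_blackbox_ddm",  # name for the resulting random variable
    loglik=my_blackbox_loglik,  # the callable defined earlier
    list_params=["v", "a", "z", "t"],  # parameter names, in order
)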
Simulate some data¶
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.2, 0.5]
obs_ddm_pymc = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=1000)
dataset_pymc = pd.DataFrame(
np.column_stack([obs_ddm_pymc["rts"][:, 0], obs_ddm_pymc["choices"][:, 0]]),
columns=["rt", "response"],
)
Build a custom PyMC Model¶
We can now use our custom random variable `DDM` directly in a PyMC model.
import pymc as pm
with pm.Model() as ddm_pymc:
    v = pm.Uniform("v", lower=-10.0, upper=10.0)
    a = pm.HalfNormal("a", sigma=2.0)
    z = pm.Uniform("z", lower=0.01, upper=0.99)
    t = pm.Uniform("t", lower=0.0, upper=0.6, initval=0.1)
    # Note: we condition on the dataset simulated just above (dataset_pymc)
    ddm = DDM("DDM", v=v, a=a, z=z, t=t, observed=dataset_pymc.values)
Let's check the model graph:
pm.model_to_graphviz(model=ddm_pymc)
Looks remarkably close to our HSSM version!
We can use PyMC directly to sample and finally return to ArviZ for some plotting!
with ddm_pymc:
ddm_pymc_trace = pm.sample()
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [v, a, z, t]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 17 seconds.
az.plot_trace(ddm_pymc_trace)
array([[<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>], [<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>], [<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>]], dtype=object)
az.plot_forest(ddm_pymc_trace)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
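One more hedged sketch before we close: because `DDM` is a full PyMC distribution, posterior predictive simulation should also work out of the box (assuming the `DDM` random variable implements forward sampling, plausibly via the ssms simulator):

# Draw posterior predictive (rt, response) pairs from the fitted model
with ddm_pymc:
    ppc = pm.sample_posterior_predictive(ddm_pymc_trace)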
All layers peeled back, the only limit in your modeling endeavors becomes the limit of the PyMC universe!