
This tutorial provides a comprehensive introduction to the HSSM package for Hierarchical Bayesian Estimation of Sequential Sampling Models.
To make the most of the tutorial, let us cover the functionality of the key supporting packages that we use along the way.
Colab Instructions¶
If you would like to run this tutorial on Google colab, please click this link.
Once you are in the colab, follow the installation instructions below and then restart your runtime.
Just uncomment the code in the next code cell and run it!
NOTE:
You may want to switch your runtime to have a GPU or TPU. To do so, go to Runtime > Change runtime type and select the desired hardware accelerator.
Note that if you switch your runtime you have to follow the installation instructions again.
# If running this on Colab, please uncomment the next line
# !pip install git+https://github.com/lnccbrown/HSSM@workshop_tutorial
Basic Imports¶
import warnings
warnings.filterwarnings("ignore")
# warnings.filterwarnings(action='once')
# Basics
import numpy as np
from matplotlib import pyplot as plt
random_seed_sim = 134
np.random.seed(random_seed_sim)
Data Simulation¶
We will rely on the ssms package for data simulation repeatedly. Let's look at a basic isolated use case below.
As an example, let's use ssms to simulate from the basic Drift Diffusion Model (a running example in this tutorial).

If you are not familiar with the DDM, for now just consider that it has four parameters:
- v: the drift rate
- a: the boundary separation
- t: the non-decision time
- z: the a priori decision bias (starting point)
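To make the role of the four parameters concrete, here is a minimal sketch of a single-trial DDM simulator using a simple Euler-Maruyama scheme. It is not the simulator used by HSSM or ssm-simulators. The function name `simulate_ddm_trial`, the step size `dt`, the `max_time` cutoff, and the convention that the bounds sit at -a and +a with a starting point of a * (2z - 1) are all assumptions made for illustration.

```python
import numpy as np


def simulate_ddm_trial(v, a, z, t, rng, dt=0.001, max_time=20.0):
    """Simulate one DDM trial (illustrative sketch, not the HSSM simulator).

    Assumed convention: bounds at -a and +a, start at a * (2 * z - 1),
    so z = 0.5 corresponds to no starting bias.
    """
    x = a * (2.0 * z - 1.0)  # starting point between the two bounds
    elapsed = 0.0
    while abs(x) < a and elapsed < max_time:
        # The drift v moves the particle; Gaussian noise is scaled by sqrt(dt).
        x += v * dt + np.sqrt(dt) * rng.normal()
        elapsed += dt
    # Upper bound -> response 1, lower bound -> response -1.
    # If max_time is reached first, the sign of x decides.
    response = 1.0 if x > 0 else -1.0
    return elapsed + t, response  # non-decision time t is added to the RT


rng = np.random.default_rng(0)
trials = [simulate_ddm_trial(0.5, 1.5, 0.5, 0.5, rng) for _ in range(20)]
rts = np.array([rt for rt, _ in trials])
responses = np.array([resp for _, resp in trials])
```

Every simulated reaction time is at least `t`, and a positive `v` makes the response `1.0` more likely than `-1.0`.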
Using simulate_data()¶
HSSM comes with a basic simulator function, simulate_data(). We can use this function to create synthetic datasets.
Below we show the most basic use case:
We wish to generate 500 datapoints (trials) from the standard Drift Diffusion Model with fixed parameters v = 0.5, a = 1.5, z = 0.5, t = 0.5.
Note:
In the course of the tutorial, we will see multiple strategies for synthetic dataset generation, this being the most straightforward one.
# Single dataset
import arviz as az # Visualization
import bambi as bmb # Model construction
import hddm_wfpt
import jax
import pytensor # Graph-based tensor library
import hssm
pytensor.config.floatX = "float32"
jax.config.update("jax_enable_x64", False)
v_true = 0.5
a_true = 1.5
z_true = 0.5
t_true = 0.5
# Call the simulator function
dataset = hssm.simulate_data(
model="ddm", theta=dict(v=v_true, a=a_true, z=z_true, t=t_true), size=500
)
dataset
rt | response | |
---|---|---|
0 | 3.892863 | 1.0 |
1 | 1.561327 | -1.0 |
2 | 3.267397 | 1.0 |
3 | 3.347503 | -1.0 |
4 | 3.498229 | 1.0 |
... | ... | ... |
495 | 1.946372 | 1.0 |
496 | 1.224806 | -1.0 |
497 | 2.954831 | 1.0 |
498 | 1.238713 | 1.0 |
499 | 0.836208 | 1.0 |
500 rows × 2 columns
If instead you wish to supply a parameter that varies by trial (a lot more on this later), you can simply supply a vector of parameter values in the theta dictionary when calling the simulator.
Note:
The size argument conceptually functions as the number of synthetic datasets. So if you supply a parameter as a (1000,) vector, the simulator assumes that one dataset consists of 1000 trials. Hence, if we set size = 1 as below, we expect a dataset with 1000 trials in return.
# a changes trial wise
a_trialwise = np.random.normal(loc=2, scale=0.3, size=1000)
dataset_a_trialwise = hssm.simulate_data(
model="ddm",
theta=dict(
v=v_true,
a=a_trialwise,
z=z_true,
t=t_true,
),
size=1,
)
dataset_a_trialwise
rt | response | |
---|---|---|
0 | 2.357774 | 1.0 |
1 | 2.166630 | 1.0 |
2 | 2.417960 | -1.0 |
3 | 12.665596 | 1.0 |
4 | 3.624797 | 1.0 |
... | ... | ... |
995 | 5.246006 | 1.0 |
996 | 1.636335 | 1.0 |
997 | 5.401908 | 1.0 |
998 | 7.671658 | 1.0 |
999 | 2.394963 | 1.0 |
1000 rows × 2 columns
If we wish to simulate from another model, we can do so by changing the model
string.
The number of models we can simulate differs from the number of models for which we have likelihoods available (both will increase over time). To get the models for which likelihood functions are supplied out of the box, we should inspect hssm.HSSM.supported_models
.
hssm.HSSM.supported_models
('ddm', 'ddm_sdv', 'full_ddm', 'angle', 'levy', 'ornstein', 'weibull', 'race_no_bias_angle_4', 'ddm_seq2_no_bias', 'lba3', 'lba2')
If we wish to check more detailed information about a given supported model, we can use the accessors get_<model_name>_config under hssm.defaults. For example, we inspect the ddm model metadata below.
hssm.defaults.get_ddm_config()
{'response': ['rt', 'response'], 'list_params': ['v', 'a', 'z', 't'], 'choices': [-1, 1], 'description': 'The Drift Diffusion Model (DDM)', 'likelihoods': {'analytical': {'loglik': <function hssm.likelihoods.analytical.logp_ddm(data: numpy.ndarray, v: float, a: float, z: float, t: float, err: float = 1e-15, k_terms: int = 20, epsilon: float = 1e-15) -> numpy.ndarray>, 'backend': None, 'bounds': {'v': (-inf, inf), 'a': (0.0, inf), 'z': (0.0, 1.0), 't': (0.0, inf)}, 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}}, 'extra_fields': None}, 'approx_differentiable': {'loglik': 'ddm.onnx', 'backend': 'jax', 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}}, 'bounds': {'v': (-3.0, 3.0), 'a': (0.3, 2.5), 'z': (0.0, 1.0), 't': (0.0, 2.0)}, 'extra_fields': None}, 'blackbox': {'loglik': <function hssm.likelihoods.blackbox.hddm_to_hssm.<locals>.outer(data: numpy.ndarray, *args, **kwargs)>, 'backend': None, 'bounds': {'v': (-inf, inf), 'a': (0.0, inf), 'z': (0.0, 1.0), 't': (0.0, inf)}, 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}}, 'extra_fields': None}}}
This dictionary contains quite a bit of information. For purposes of simulating data from a given model, we will highlight two aspects:
- The key list_params provides us with the necessary information to define our theta dictionary.
- The bounds key inside the likelihoods sub-dictionaries provides us with an indication of reasonable parameter values.
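As a small illustration of how these two pieces of information can be used together, the sketch below checks a candidate theta dictionary against a set of bounds before simulating. The bounds are copied by hand from the approx_differentiable entry of the config printed above. The helper `out_of_bounds` is hypothetical and not part of HSSM.

```python
# Bounds copied by hand from the 'approx_differentiable' entry printed above.
lan_bounds = {"v": (-3.0, 3.0), "a": (0.3, 2.5), "z": (0.0, 1.0), "t": (0.0, 2.0)}
list_params = ["v", "a", "z", "t"]


def out_of_bounds(theta, bounds):
    """Return the names of parameters whose values fall outside their bounds."""
    return [
        name
        for name, (lower, upper) in bounds.items()
        if not lower <= theta[name] <= upper
    ]


theta_ok = dict(v=0.5, a=1.5, z=0.5, t=0.5)
theta_bad = dict(v=0.5, a=3.0, z=0.5, t=0.5)  # a lies above the listed upper bound

# Parameters required by the model but absent from the dictionary.
missing = [name for name in list_params if name not in theta_ok]
```

Here `out_of_bounds(theta_bad, lan_bounds)` flags only `a`, while `theta_ok` passes both checks.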
The likelihoods dictionary contains three sub-dictionaries for the ddm model, since all three kinds of likelihood are available: an analytical, an approx_differentiable (LAN) and a blackbox likelihood. For many models, we will be able to access only one or two types of likelihoods.
Using ssm-simulators¶
Internally, HSSM makes use of the ssm-simulators package for forward simulation of models. hssm.simulate_data() functions essentially as a convenience wrapper.
Below we illustrate how to simulate data using the ssm-simulators package directly, to generate a dataset equivalent to the one created above. We will use the third way of passing parameters to the simulator, which is as a parameter matrix.
Notes:
If you pass parameters as a parameter matrix, make sure the column ordering is correct. You can follow the parameter ordering under hssm.defaults.default_model_config['ddm']['list_params'].
This is a minimal example. For more information about the package, check the associated GitHub page.
import numpy as np
import pandas as pd
from ssms.basic_simulators.simulator import simulator
# a changes trial wise
theta_mat = np.zeros((1000, 4))
theta_mat[:, 0] = v_true # v
theta_mat[:, 1] = a_trialwise # a
theta_mat[:, 2] = z_true # z
theta_mat[:, 3] = t_true # t
# simulate data
sim_out_trialwise = simulator(
theta=theta_mat, # parameter_matrix
model="ddm", # specify model (many are included in ssms)
n_samples=1, # number of samples for each set of parameters
# (plays the role of `size` parameter in `hssm.simulate_data`)
)
# Turn into nice dataset
dataset_trialwise = pd.DataFrame(
np.column_stack(
[sim_out_trialwise["rts"][:, 0], sim_out_trialwise["choices"][:, 0]]
),
columns=["rt", "response"],
)
dataset_trialwise
rt | response | |
---|---|---|
0 | 2.126281 | -1.0 |
1 | 3.327770 | 1.0 |
2 | 3.811979 | -1.0 |
3 | 5.271413 | 1.0 |
4 | 3.837374 | 1.0 |
... | ... | ... |
995 | 3.061756 | -1.0 |
996 | 5.948287 | 1.0 |
997 | 1.967743 | 1.0 |
998 | 3.489338 | 1.0 |
999 | 2.352497 | 1.0 |
1000 rows × 2 columns
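Following the note on column ordering above, one way to avoid ordering mistakes is to build the parameter matrix from a dictionary and an explicit list of parameter names. The sketch below uses only NumPy. The helper `build_theta_matrix` is hypothetical, and the order v, a, z, t is assumed to match list_params for the ddm model.

```python
import numpy as np

# Column order assumed to match list_params for the DDM: v, a, z, t.
list_params = ["v", "a", "z", "t"]


def build_theta_matrix(theta, param_order, n_trials):
    """Stack scalar or trial-wise parameters into an (n_trials, n_params) matrix."""
    columns = [
        # Scalars are repeated for every trial; vectors are used as they are.
        np.broadcast_to(np.asarray(theta[name], dtype=float), (n_trials,))
        for name in param_order
    ]
    return np.column_stack(columns)


rng = np.random.default_rng(134)
a_by_trial = rng.normal(loc=2.0, scale=0.3, size=1000)
theta_matrix = build_theta_matrix(
    dict(v=0.5, a=a_by_trial, z=0.5, t=0.5), list_params, n_trials=1000
)
```

The resulting `theta_matrix` has one row per trial and one column per parameter, in the order given by `list_params`.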
We will stick to hssm.simulate_data()
in this tutorial, to keep things simple.
ArviZ for Plotting¶

We use the ArviZ package for most of our plotting needs. ArviZ is a useful aid for plotting when doing anything Bayesian.
It works with HSSM out of the box, by virtue of HSSM's reliance on PyMC for model construction and sampling.
Checking out the ArviZ Documentation is a good idea: it gives you communication superpowers not only for your HSSM results, but also for other libraries in the Bayesian toolkit such as NumPyro or Stan.
We will see ArviZ plots throughout the notebook.
Main Tutorial¶
Initial Dataset¶
Let's proceed to simulate a simple dataset for our first example.
# Specify
param_dict_init = dict(
v=v_true,
a=a_true,
z=z_true,
t=t_true,
)
dataset = hssm.simulate_data(
model="ddm",
theta=param_dict_init,
size=500,
)
dataset
rt | response | |
---|---|---|
0 | 0.841617 | 1.0 |
1 | 3.490512 | 1.0 |
2 | 2.701918 | 1.0 |
3 | 2.934544 | 1.0 |
4 | 3.248816 | 1.0 |
... | ... | ... |
495 | 2.439974 | 1.0 |
496 | 7.056619 | 1.0 |
497 | 1.318811 | 1.0 |
498 | 1.439395 | 1.0 |
499 | 1.198503 | 1.0 |
500 rows × 2 columns
First HSSM Model¶
In this example we will use the analytical likelihood function computed as suggested in this paper.
Instantiate the model¶
To instantiate our HSSM class, in the simplest version, we only need to provide an appropriate dataset.
The dataset is expected to be a pandas.DataFrame with at least two columns, called rt (for reaction time) and response.
Our data simulated above is already in the correct format, so let us try to construct the class.
NOTE:
If you are a user of the HDDM python package, this workflow should seem very familiar.
simple_ddm_model = hssm.HSSM(data=dataset)
Model initialized successfully.
simple_ddm_model
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 500 Parameters: v: Prior: Normal(mu: 0.0, sigma: 2.0) Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
The print() function gives us some basic information about our model, including the number of observations, the parameters in the model and their respective prior settings. We can also create a nice little graphical representation of our model...
Model Graph¶
Since HSSM creates a PyMC Model, we can use the .graph() function to get a graphical representation of the model we created.
simple_ddm_model.graph()
This is the simplest model we can build. The graph above follows plate notation, commonly used for probabilistic graphical models.
- We have our basic parameters (unobserved, white nodes). These are random variables in the model and we want to estimate them.
- Our observed reaction times and choices (SSMRandomVariable, grey node) are fixed (or conditioned on).
- Rounded rectangles provide us with information about the dimensionality of objects.
- Rectangles with sharp edges represent deterministic, but computed quantities (not shown here, but in later models).
This notation is helpful to get a quick overview of the structure of a given model we construct.
The graph()
function of course becomes a lot more interesting and useful for more complicated models!
Sample from the Model¶
We can now call the .sample() function to get posterior samples. The main arguments you may want to change are listed in the function call below.
Importantly, multiple backends are possible. The call below uses sampler="mcmc", which, judging by the sampler log that follows, runs the default PyMC NUTS sampler. The nuts_numpyro and nuts_blackjax backends, which compile the model to a JAX function, are alternatives.
infer_data_simple_ddm_model = simple_ddm_model.sample(
sampler="mcmc", # type of sampler to choose, 'nuts_numpyro',
# 'nuts_blackjax' or default pymc nuts sampler
cores=1, # how many cores to use
chains=2, # how many chains to run
draws=500, # number of draws from the markov chain
tune=1000, # number of burn-in samples
idata_kwargs=dict(log_likelihood=True), # return log likelihood
mp_ctx="spawn",
) # mp_ctx="forkserver")
Using default initvals.
Initializing NUTS using adapt_diag... Sequential sampling (2 chains in 1 job) NUTS: [t, a, z, v]
Sampling 2 chains for 1_000 tune and 500 draw iterations (2_000 + 1_000 draws total) took 12 seconds. There were 2 divergences after tuning. Increase `target_accept` or reparameterize. We recommend running at least 4 chains for robust computation of convergence diagnostics 100%|██████████| 1000/1000 [00:00<00:00, 4798.65it/s]
We sampled from the model, let's look at the output...
type(infer_data_simple_ddm_model)
arviz.data.inference_data.InferenceData
Errr... a closer look might be needed here!
Inference Data / What gets returned from the sampler?¶
The sampler returns an ArviZ InferenceData object.
To understand all the logic behind these objects and how they mesh with the Bayesian Workflow, we refer you to the ArviZ Documentation.
InferenceData is built on top of xarray. The xarray documentation will help you understand in more detail how to manipulate these objects.
But let's take a quick high-level look to understand roughly what we are dealing with here!
infer_data_simple_ddm_model
-
<xarray.Dataset> Size: 20kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (chain, draw) float32 4kB 0.5277 0.4719 0.5123 ... 0.5016 0.4757 a (chain, draw) float32 4kB 1.534 1.488 1.531 1.52 ... 1.55 1.47 1.51 t (chain, draw) float32 4kB 0.5117 0.5139 0.4826 ... 0.5348 0.4617 v (chain, draw) float32 4kB 0.4452 0.5427 0.459 ... 0.4747 0.4836 Attributes: created_at: 2024-12-27T22:26:00.767052+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.4322509765625 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 4MB Dimensions: (chain: 2, draw: 500, __obs__: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: rt,response (chain, draw, __obs__) float64 4MB -1.494 -2.335 ... -1.068 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 126kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 8kB 0.9423 0.7136 ... 0.2158 diverging (chain, draw) bool 1kB False False ... False False energy (chain, draw) float64 8kB 1.038e+03 ... 1.044e+03 energy_error (chain, draw) float64 8kB 0.0855 0.22 ... 1.059 index_in_trajectory (chain, draw) int64 8kB -2 -5 -2 -1 -6 ... -2 -2 5 -4 largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan ... ... process_time_diff (chain, draw) float64 8kB 0.004019 ... 0.004314 reached_max_treedepth (chain, draw) bool 1kB False False ... False False smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan step_size (chain, draw) float64 8kB 0.7159 0.7159 ... 0.7403 step_size_bar (chain, draw) float64 8kB 0.5958 0.5958 ... 0.6545 tree_depth (chain, draw) int64 8kB 3 3 3 2 3 3 3 ... 2 2 2 3 3 3 Attributes: created_at: 2024-12-27T22:26:00.784587+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.4322509765625 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 8kB Dimensions: (__obs__: 500, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 4kB 0 1 2 3 4 ... 496 497 498 499 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 4kB 0... Attributes: created_at: 2024-12-27T22:26:00.787536+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
We see that in our case, infer_data_simple_ddm_model contains four basic types of data (note: this is extensible!):
- posterior
- log_likelihood
- sample_stats
- observed_data
The posterior object contains our traces for each of the parameters in the model. The log_likelihood field contains the trial-wise log-likelihoods for each sample from the posterior. The sample_stats field contains information about the sampler run. This can be important for chain diagnostics, but we will not dwell on this here. Finally, we retrieve our observed_data.
Basic Manipulation¶
Accessing groups and variables¶
infer_data_simple_ddm_model.posterior
<xarray.Dataset> Size: 20kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (chain, draw) float32 4kB 0.5277 0.4719 0.5123 ... 0.5016 0.4757 a (chain, draw) float32 4kB 1.534 1.488 1.531 1.52 ... 1.55 1.47 1.51 t (chain, draw) float32 4kB 0.5117 0.5139 0.4826 ... 0.5348 0.4617 v (chain, draw) float32 4kB 0.4452 0.5427 0.459 ... 0.4747 0.4836 Attributes: created_at: 2024-12-27T22:26:00.767052+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.4322509765625 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
infer_data_simple_ddm_model.posterior.a.head()
<xarray.DataArray 'a' (chain: 2, draw: 5)> Size: 40B array([[1.5337849, 1.4877566, 1.5312983, 1.5203193, 1.4357468], [1.5204736, 1.5176351, 1.5176351, 1.458968 , 1.5212611]], dtype=float32) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 40B 0 1 2 3 4
To simply access the underlying data as a numpy.ndarray, we can use .values (as, e.g., when using pandas.DataFrame objects).
type(infer_data_simple_ddm_model.posterior.a.values)
numpy.ndarray
# infer_data_simple_ddm_model.posterior.a.values
Combine chain and draw dimension¶
When operating directly on the xarray, you will often find it useful to collapse the chain and draw coordinates into a single coordinate.
ArviZ makes this easy via the extract function.
idata_extracted = az.extract(infer_data_simple_ddm_model)
idata_extracted
<xarray.Dataset> Size: 40kB Dimensions: (sample: 1000) Coordinates: * sample (sample) object 8kB MultiIndex * chain (sample) int64 8kB 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1 * draw (sample) int64 8kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (sample) float32 4kB 0.5277 0.4719 0.5123 ... 0.4901 0.5016 0.4757 a (sample) float32 4kB 1.534 1.488 1.531 1.52 ... 1.55 1.47 1.51 t (sample) float32 4kB 0.5117 0.5139 0.4826 ... 0.4342 0.5348 0.4617 v (sample) float32 4kB 0.4452 0.5427 0.459 ... 0.5325 0.4747 0.4836 Attributes: created_at: 2024-12-27T22:26:00.767052+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.4322509765625 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
Since ArviZ really just calls the .stack() method from xarray, here is the corresponding example using the lower-level xarray interface.
infer_data_simple_ddm_model.posterior.stack(sample=("chain", "draw"))
<xarray.Dataset> Size: 40kB Dimensions: (sample: 1000) Coordinates: * sample (sample) object 8kB MultiIndex * chain (sample) int64 8kB 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1 * draw (sample) int64 8kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (sample) float32 4kB 0.5277 0.4719 0.5123 ... 0.4901 0.5016 0.4757 a (sample) float32 4kB 1.534 1.488 1.531 1.52 ... 1.55 1.47 1.51 t (sample) float32 4kB 0.5117 0.5139 0.4826 ... 0.4342 0.5348 0.4617 v (sample) float32 4kB 0.4452 0.5427 0.459 ... 0.5325 0.4747 0.4836 Attributes: created_at: 2024-12-27T22:26:00.767052+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.4322509765625 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
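For intuition about what this stacking does to the underlying arrays, here is a plain NumPy sketch. It assumes a trace stored with shape (chain, draw) and uses random numbers as a stand-in for one parameter's samples. It only mirrors the ordering visible in the output above, not ArviZ's or xarray's internals.

```python
import numpy as np

n_chains, n_draws = 2, 500
rng = np.random.default_rng(0)
samples = rng.normal(size=(n_chains, n_draws))  # stand-in for one parameter's trace

# Collapse (chain, draw) into a single 'sample' axis.
# All draws of chain 0 come first, followed by all draws of chain 1.
stacked = samples.reshape(n_chains * n_draws)

# Recover the (chain, draw) pair behind each stacked sample.
chain_idx, draw_idx = np.unravel_index(np.arange(stacked.size), samples.shape)
```

The arrays `chain_idx` and `draw_idx` play the role of the chain and draw coordinates attached to the stacked sample dimension.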
Making use of ArviZ¶
Working with the InferenceData directly is very helpful if you want to include custom computations in your workflow.
For a basic Bayesian Workflow however, you will often find that the standard functionality available through ArviZ suffices.
Below we provide a few examples of useful ArviZ outputs, which come in handy for analyzing your traces (MCMC samples).
Summary table¶
Let's take a look at a summary table for our posterior.
az.summary(
infer_data_simple_ddm_model,
var_names=[var_name.name for var_name in simple_ddm_model.pymc_model.free_RVs],
)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
t | 0.500 | 0.034 | 0.432 | 0.560 | 0.001 | 0.001 | 716.0 | 643.0 | 1.0 |
a | 1.504 | 0.039 | 1.427 | 1.573 | 0.002 | 0.001 | 651.0 | 628.0 | 1.0 |
z | 0.503 | 0.018 | 0.471 | 0.539 | 0.001 | 0.001 | 655.0 | 739.0 | 1.0 |
v | 0.480 | 0.045 | 0.400 | 0.564 | 0.002 | 0.001 | 757.0 | 697.0 | 1.0 |
This table returns the parameter-wise mean of our posterior and a few extra statistics.
Of these extra statistics, the one-stop shop for flagging convergence issues is the r_hat value, which is reported in the right-most column.
To navigate this statistic, here is a rule of thumb widely used in applied Bayesian statistics:
If you find an r_hat value > 1.01, it warrants investigation.
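To give a feeling for what r_hat measures, the following sketch computes the classic split R-hat with NumPy: it compares the variance between (split) chains to the variance within them. This is a simplified textbook version. To our knowledge ArviZ reports a rank-normalized variant by default, so the numbers need not match az.summary() exactly. The function name `split_r_hat` is our own.

```python
import numpy as np


def split_r_hat(chains):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    _, n_draws = chains.shape
    half = n_draws // 2
    # Split every chain in two so that drift within a chain is also detected.
    split = np.concatenate([chains[:, :half], chains[:, half : 2 * half]], axis=0)
    within = split.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between_over_n = split.mean(axis=1).var(ddof=1)  # B / n: variance of chain means
    var_hat = (half - 1) / half * within + between_over_n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))  # four chains sampling the same distribution
stuck = mixed.copy()
stuck[3] += 3.0  # one chain explores a different region

r_hat_mixed = split_r_hat(mixed)
r_hat_stuck = split_r_hat(stuck)
```

For the well-mixed chains the value stays very close to 1, while the shifted chain pushes it far above the 1.01 threshold.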
Trace plot¶
az.plot_trace(
infer_data_simple_ddm_model, # we exclude the log_likelihood traces here
lines=[(key_, {}, param_dict_init[key_]) for key_ in param_dict_init],
)
plt.tight_layout()
The .sample() function also sets a traces attribute on our HSSM class instance, so instead we could call the plot like so:
az.plot_trace(
simple_ddm_model.traces,
lines=[(key_, {}, param_dict_init[key_]) for key_ in param_dict_init],
);
In this tutorial we are most often going to use the latter way of accessing the traces, but there is no preferred option.
Let's look at a few more plots.
Forest Plot¶
The forest plot is commonly used for a quick visual check of the marginal posteriors. It is very effective for intuitive communication of results.
az.plot_forest(simple_ddm_model.traces)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
Combining Chains¶
By default, chains are separated out into separate caterpillars. However, sometimes, especially if you are looking at a forest plot which includes many posterior parameters at once, you want to declutter and collapse the chains into single caterpillars.
In this case you can combine chains instead, by passing combined=True.
az.plot_forest(simple_ddm_model.traces, combined=True)
array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)
Basic Marginal Posterior Plot¶
Another way to view the marginal posteriors is provided by the plot_posterior() function. It shows the mean and, by default, the $94\%$ HDIs.
az.plot_posterior(simple_ddm_model.traces)
array([<Axes: title={'center': 'z'}>, <Axes: title={'center': 'a'}>, <Axes: title={'center': 't'}>, <Axes: title={'center': 'v'}>], dtype=object)
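The HDI (highest density interval) shown in these plots is the narrowest interval that contains a given fraction of the posterior samples. The sketch below computes it from raw samples with NumPy. The function `hdi` is our own simplified version for unimodal posteriors and is not guaranteed to reproduce ArviZ's values digit for digit. The draws are synthetic stand-ins for a posterior.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Narrowest interval that contains a fraction `prob` of the samples."""
    ordered = np.sort(np.asarray(samples, dtype=float))
    n = ordered.size
    n_inside = int(np.floor(prob * n))  # index offset between the interval ends
    # Width of every candidate interval spanning n_inside sorted samples.
    widths = ordered[n_inside:] - ordered[: n - n_inside]
    start = int(np.argmin(widths))
    return ordered[start], ordered[start + n_inside]


rng = np.random.default_rng(2)
draws = rng.normal(loc=0.5, scale=0.05, size=10_000)  # synthetic 'posterior'
lower, upper = hdi(draws)
coverage = np.mean((draws >= lower) & (draws <= upper))
```

For a symmetric, unimodal posterior like this one the HDI is close to the central 94% interval. For skewed posteriors the two can differ noticeably.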
Especially for parameter recovery studies, you may want to include reference values for the parameters of interest.
You can do so using the ref_val
argument. See the example below:
az.plot_posterior(
simple_ddm_model.traces,
ref_val=[
param_dict_init[var_name]
for var_name in simple_ddm_model.traces.posterior.data_vars
],
)
array([<Axes: title={'center': 'z'}>, <Axes: title={'center': 'a'}>, <Axes: title={'center': 't'}>, <Axes: title={'center': 'v'}>], dtype=object)
Since it is sometimes useful, especially for more complex cases, below is an alternative approach in which we pass ref_val as a dictionary.
az.plot_posterior(
simple_ddm_model.traces,
ref_val={
"v": [{"ref_val": param_dict_init["v"]}],
"a": [{"ref_val": param_dict_init["a"]}],
"z": [{"ref_val": param_dict_init["z"]}],
"t": [{"ref_val": param_dict_init["t"]}],
},
)
array([<Axes: title={'center': 'z'}>, <Axes: title={'center': 'a'}>, <Axes: title={'center': 't'}>, <Axes: title={'center': 'v'}>], dtype=object)
Posterior Pair Plot¶
The posterior pair plot shows us bivariate trace plots and is useful to check for simple parameter tradeoffs that may emerge. The simplest (linear) tradeoff may be a high correlation between two parameters. This can be very helpful in diagnosing sampler issues, for example. If such tradeoffs exist, one often sees extremely wide marginal distributions.
In our ddm example, we see a little bit of a tradeoff between a and t, as well as between v and z, but nothing concerning.
az.plot_pair(
simple_ddm_model.traces,
kind="kde",
reference_values=param_dict_init,
marginals=True,
);
The few plots we showed here are just the beginning: ArviZ has a much broader spectrum of graphs and other convenience functions available. Just check the documentation.
# Calculate the correlation matrix
posterior_correlation_matrix = np.corrcoef(
np.stack(
[idata_extracted[var_].values for var_ in idata_extracted.data_vars.variables]
)
)
num_vars = posterior_correlation_matrix.shape[0]
# Make heatmap
fig, ax = plt.subplots(1, 1)
cax = ax.imshow(posterior_correlation_matrix, cmap="coolwarm", vmin=-1, vmax=1)
fig.colorbar(cax, ax=ax)
ax.set_title("Posterior Correlation Matrix")
# Add ticks
ax.set_xticks(range(posterior_correlation_matrix.shape[0]))
ax.set_xticklabels([var_ for var_ in idata_extracted.data_vars.variables])
ax.set_yticks(range(posterior_correlation_matrix.shape[0]))
ax.set_yticklabels([var_ for var_ in idata_extracted.data_vars.variables])
# Annotate heatmap
for i in range(num_vars):
    for j in range(num_vars):
        ax.text(
            j,
            i,
            f"{posterior_correlation_matrix[i, j]:.2f}",
            ha="center",
            va="center",
            color="black",
        )
plt.show()
HSSM Model based on LAN likelihood¶
With HSSM you can switch between pre-supplied models with a simple change of argument. The type of likelihood that will be accessed might change in the background for you.
Here we see an example in which the underlying likelihood is now a LAN.
We will talk more about different types of likelihood functions and backends later in the tutorial. For now just keep the following in mind:
There are three types of likelihoods:
- analytical
- approx_differentiable
- blackbox
To check which type is used in your HSSM model, simply type:
simple_ddm_model.loglik_kind
'analytical'
Ah... we were using an analytical
likelihood with the DDM model in the last section.
Now let's see something different!
Simulating Angle Data¶
Again, let us simulate a simple dataset. This time we will use the angle model (passed via the model argument to the simulate_data() function).
This model is distinguished from the basic ddm model by an additional theta parameter, which specifies the angle with which the decision boundaries collapse over time.
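To visualize what a collapsing boundary means, the sketch below evaluates a linearly collapsing upper bound over time. The functional form a - t * tan(theta), floored at zero, is an assumption made for illustration. The exact parameterization used by ssm-simulators may differ, so please consult that package for the authoritative definition. The function name `angle_boundary` is our own.

```python
import math


def angle_boundary(t, a, theta):
    """Upper boundary at time t for a linear collapse with angle theta (radians).

    Assumed form for this sketch: a - t * tan(theta), floored at zero.
    The lower boundary would be the mirror image, -angle_boundary(t, a, theta).
    """
    return max(a - t * math.tan(theta), 0.0)


bound_start = angle_boundary(0.0, a=1.5, theta=0.2)  # equals a at time zero
bound_later = angle_boundary(2.0, a=1.5, theta=0.2)  # narrower after two seconds
bound_closed = angle_boundary(10.0, a=1.5, theta=0.2)  # fully collapsed
```

With theta = 0 the bound stays at a for all times, which recovers the basic ddm model.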

DDMs with collapsing bounds have been of significant interest in the theoretical literature, but applications were rare due to a lack of analytical likelihoods. HSSM facilitates inference with such models via our approx_differentiable likelihoods. HSSM ships with a few predefined models based on LANs, but really we don't want to overemphasize those. They reflect, to a great extent, the research interests of our and adjacent labs.
Instead, we encourage the community to contribute to this model reservoir (more on this later).
# Simulate angle data
v_angle_true = 0.5
a_angle_true = 1.5
z_angle_true = 0.5
t_angle_true = 0.2
theta_angle_true = 0.2
param_dict_angle = dict(
v=v_angle_true,
a=a_angle_true,
z=z_angle_true,
t=t_angle_true,
theta=theta_angle_true,
)
lines_list_angle = [(key_, {}, param_dict_angle[key_]) for key_ in param_dict_angle]
dataset_angle = hssm.simulate_data(model="angle", theta=param_dict_angle, size=1000)
We pass a single additional argument to our HSSM
class and set model='angle'
.
model_angle = hssm.HSSM(data=dataset_angle, model="angle")
model_angle
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: angle Response variable: rt,response Likelihood: approx_differentiable Observations: 1000 Parameters: v: Prior: Uniform(lower: -3.0, upper: 3.0) Explicit bounds: (-3.0, 3.0) a: Prior: Uniform(lower: 0.30000001192092896, upper: 3.0) Explicit bounds: (0.3, 3.0) z: Prior: Uniform(lower: 0.10000000149011612, upper: 0.8999999761581421) Explicit bounds: (0.1, 0.9) t: Prior: Uniform(lower: 0.0010000000474974513, upper: 2.0) Explicit bounds: (0.001, 2.0) theta: Prior: Uniform(lower: -0.10000000149011612, upper: 1.2999999523162842) Explicit bounds: (-0.1, 1.3) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
The model graph now shows us an additional parameter theta!
model_angle.graph()
Let's check the type of likelihood that is used under the hood ...
model_angle.loglik_kind
'approx_differentiable'
Ok so here we rely on a likelihood of the approx_differentiable
kind.
As discussed, with the initial set of pre-supplied likelihoods, this implies that we are using a LAN in the background.
jax.config.update("jax_enable_x64", False)
infer_data_angle = model_angle.sample(
sampler="mcmc",
chains=2,
cores=2,
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
mp_ctx="spawn",
)
Using default initvals. Parallel sampling might not work with `jax` backend and the PyMC NUTS sampler on some platforms. Please consider using `nuts_numpyro` or `nuts_blackjax` sampler if that is a problem.
Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [z, t, a, theta, v]
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 74 seconds. We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(model_angle.traces, lines=lines_list_angle)
plt.tight_layout()
Choosing Priors¶
HSSM allows you to specify priors quite freely. If you used HDDM previously, you may feel relieved to read that your hands are now untied!

With HSSM we have multiple routes to priors. But let's first consider a special case:
Fixing a parameter to a given value¶
Assume that instead of fitting all parameters of the DDM, we want to fit only the v (drift) parameter, setting all other parameters to fixed scalar values.

HSSM makes this extremely easy!
param_dict_init
{'v': 0.5, 'a': 1.5, 'z': 0.5, 't': 0.5}
ddm_model_only_v = hssm.HSSM(
data=dataset,
model="ddm",
a=param_dict_init["a"],
t=param_dict_init["t"],
z=param_dict_init["z"],
)
Model initialized successfully.
Since we fix all but one parameter, we estimate only one parameter. This should be reflected in our model graph, where we expect only one free random variable v:
ddm_model_only_v.graph()
ddm_model_only_v.sample(
sampler="mcmc",
chains=2,
cores=2,
draws=500,
tune=500,
idata_kwargs=dict(log_likelihood=False), # no need to return likelihoods here
mp_ctx="spawn",
)
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [v]
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 53 seconds. We recommend running at least 4 chains for robust computation of convergence diagnostics
-
<xarray.Dataset> Size: 8kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: v (chain, draw) float32 4kB 0.5139 0.5005 0.416 ... 0.4835 0.5047 Attributes: created_at: 2024-12-27T22:29:36.245832+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 53.41643691062927 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 126kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 8kB 0.9206 0.9023 ... 0.8293 diverging (chain, draw) bool 1kB False False ... False False energy (chain, draw) float64 8kB 1.032e+03 ... 1.033e+03 energy_error (chain, draw) float64 8kB 0.08268 -0.1014 ... 0.0737 index_in_trajectory (chain, draw) int64 8kB 1 -3 -1 -1 0 ... 0 -1 -1 -1 2 largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan ... ... process_time_diff (chain, draw) float64 8kB 5.4e-05 ... 0.000116 reached_max_treedepth (chain, draw) bool 1kB False False ... False False smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan step_size (chain, draw) float64 8kB 1.255 1.255 ... 0.986 0.986 step_size_bar (chain, draw) float64 8kB 1.24 1.24 ... 1.078 1.078 tree_depth (chain, draw) int64 8kB 1 2 2 2 2 2 2 ... 1 1 1 1 2 2 Attributes: created_at: 2024-12-27T22:29:36.260041+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 53.41643691062927 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 8kB Dimensions: (__obs__: 500, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 4kB 0 1 2 3 4 ... 496 497 498 499 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 4kB 0... Attributes: created_at: 2024-12-27T22:29:36.262883+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
az.plot_trace(
    ddm_model_only_v.traces.posterior, lines=[("v", {}, param_dict_init["v"])]
);
A useful alternative or complement to the trace panel on the right is the rank plot. As a rule of thumb, if the rank histograms of the individual chains look approximately uniform, the chains are mixing well.
az.plot_trace(ddm_model_only_v.traces, kind="rank_bars")
array([[<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}, xlabel='Rank (all chains)', ylabel='Chain'>]], dtype=object)
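To make the rank-plot idea concrete, here is a minimal NumPy sketch of the statistic behind it. The two chains are synthetic stand-ins (independent Normal draws), not output from the model above, and the number of bins is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two hypothetical, well-mixed chains of 500 draws each (stand-ins for MCMC output).
chains = rng.normal(loc=0.5, scale=0.05, size=(2, 500))

# Rank every draw within the pooled sample (0 = smallest draw overall).
pooled = chains.ravel()
ranks = pooled.argsort().argsort().reshape(chains.shape)

# Histogram the pooled ranks separately for each chain.
n_bins = 10
edges = np.linspace(0, pooled.size, n_bins + 1)
counts = np.stack([np.histogram(chain_ranks, bins=edges)[0] for chain_ranks in ranks])

# Each row is one chain; for well-mixed chains the entries hover around 500 / n_bins.
print(counts)
```

If one chain were stuck in a different region of parameter space, its ranks would pile up in the low or high bins instead of spreading out evenly.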
Named priors¶
We can choose any PyMC Distribution to specify a prior for a given parameter.
Even better, if natural parameter bounds are provided, HSSM automatically truncates the prior distribution so that it respects these bounds.
Below is an example in which we specify a Normal prior on the v parameter of the DDM.
We choose a ridiculously low $\sigma$ value to illustrate its regularizing effect on the parameter (just so we see a difference and you are convinced that something changed).
model_normal = hssm.HSSM(
    data=dataset,
    include=[
        {
            "name": "v",
            "prior": {"name": "Normal", "mu": 0, "sigma": 0.01},
        }
    ],
)
Model initialized successfully.
model_normal
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 500 Parameters: v: Prior: Normal(mu: 0.0, sigma: 0.009999999776482582) Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
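The explicit bounds listed in the model summary are what the automatic truncation refers to. As a rough illustration of what truncating a prior to such bounds means, the sketch below draws from a Normal distribution restricted to an interval by simple rejection. The helper `sample_truncated_normal` and the Normal(0.5, 0.5) prior on a (0, 1)-bounded parameter are hypothetical, and this is not how HSSM or PyMC implement truncation internally.

```python
import numpy as np

rng = np.random.default_rng(1)


def sample_truncated_normal(mu, sigma, lower, upper, size, rng):
    """Draw from Normal(mu, sigma) restricted to [lower, upper] by rejection."""
    kept = np.empty(0)
    while kept.size < size:
        draws = rng.normal(mu, sigma, size=size)
        in_bounds = (draws >= lower) & (draws <= upper)
        kept = np.concatenate([kept, draws[in_bounds]])
    return kept[:size]


# Hypothetical example: a Normal(0.5, 0.5) prior on a parameter bounded to (0, 1).
z_prior_draws = sample_truncated_normal(0.5, 0.5, 0.0, 1.0, size=5000, rng=rng)

# All retained draws lie inside the bounds.
print(z_prior_draws.min(), z_prior_draws.max())
```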
infer_data_normal = model_normal.sample(
    sampler="mcmc",
    chains=2,
    cores=2,
    draws=500,
    tune=500,
    idata_kwargs=dict(log_likelihood=False),  # no need to return likelihoods here
    mp_ctx="spawn",
)
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [t, a, z, v]
Output()
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 57 seconds. There were 552 divergences after tuning. Increase `target_accept` or reparameterize. We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(
    model_normal.traces,
    lines=[(key_, {}, param_dict_init[key_]) for key_ in param_dict_init],
)
array([[<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>], [<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>], [<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>]], dtype=object)
Observe how we reused our previous dataset with underlying parameters
- v = 0.5
- a = 1.5
- z = 0.5
- t = 0.2
In contrast to our previous sampler round, in which we used Uniform priors, here the v estimate is shrunk severely toward $0$, and the t and z parameter estimates are strongly biased to make up for this distortion. The sampler also reports 552 divergences after tuning, which is a sign of poor sampler performance.
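The shrinkage toward $0$ can be understood with a much simpler stand-in model. The sketch below uses the textbook Normal-Normal conjugate update (Normal likelihood with known noise, Normal prior on the mean), not the DDM likelihood. The sample mean, noise level and number of observations are hypothetical numbers chosen only to mirror the scale of this example.

```python
def posterior_mean(prior_mu, prior_sigma, data_mean, data_sigma, n):
    """Posterior mean of a Normal mean under a Normal prior (known data_sigma)."""
    prior_precision = 1.0 / prior_sigma**2
    data_precision = n / data_sigma**2
    weighted_sum = prior_precision * prior_mu + data_precision * data_mean
    return weighted_sum / (prior_precision + data_precision)


# Hypothetical data summary: 500 observations with sample mean 0.5 and noise sd 1.0.
tight = posterior_mean(prior_mu=0.0, prior_sigma=0.01, data_mean=0.5, data_sigma=1.0, n=500)
wide = posterior_mean(prior_mu=0.0, prior_sigma=10.0, data_mean=0.5, data_sigma=1.0, n=500)

print(tight)  # pulled almost all the way to the prior mean of 0
print(wide)   # essentially the sample mean
```

With a prior standard deviation of 0.01 the prior precision dominates the data precision, so the posterior mean barely moves away from the prior mean. The same mechanism is at work for v above, only with a more complicated likelihood.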
HSSM Model with Regression¶

Crucial to the scope of HSSM is the ability to link parameters with trial-by-trial covariates via (hierarchical, but more on this later) general linear models.
In this section we explore how HSSM deals with these models. No big surprise here... it's simple!
Case 1: One parameter is a Regression Target¶
Simulating Data¶
Let's first simulate some data in which the trial-by-trial values of the v parameter are driven by a simple linear regression model.
The regression model uses two (random) covariates, x and y, with coefficients of $0.8$ and $0.3$ respectively. The covariates are simulated below.
We set the intercept to $0.3$.
The rest of the parameters are fixed to single values as before.
# Set up trial by trial parameters
v_intercept = 0.3
x = np.random.uniform(-1, 1, size=1000)
v_x = 0.8
y = np.random.uniform(-1, 1, size=1000)
v_y = 0.3
v_reg_v = v_intercept + (v_x * x) + (v_y * y)
# rest
a_reg_v = 1.5
z_reg_v = 0.5
t_reg_v = 0.1
param_dict_reg_v = dict(
    a=a_reg_v,
    z=z_reg_v,
    t=t_reg_v,
    v=v_reg_v,
    v_x=v_x,
    v_y=v_y,
    v_Intercept=v_intercept,
    theta=0.0,
)
# base dataset
dataset_reg_v = hssm.simulate_data(model="ddm", theta=param_dict_reg_v, size=1)
# Adding covariates into the dataframe
dataset_reg_v["x"] = x
dataset_reg_v["y"] = y
Basic Model¶
We now create the HSSM model.
Notice how we set the include argument. It expects a list of dictionaries, one dictionary for each parameter to be specified via a regression model.
Up to four keys can be set in each dictionary:
- The name of the parameter,
- Potentially a prior for each of the regression level parameters ($\beta$'s),
- The regression formula,
- A link function.
If prior and link are omitted, as in the basic example below, HSSM falls back to defaults.
The regression formula follows the syntax of the formulae Python package (as used by the Bambi package for building Bayesian hierarchical regression models).
Bambi forms the main model-construction backend of HSSM.
model_reg_v_simple = hssm.HSSM(
    data=dataset_reg_v, include=[{"name": "v", "formula": "v ~ 1 + x + y"}]
)
Model initialized successfully.
model_reg_v_simple
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Formula: v ~ 1 + x + y Priors: v_Intercept ~ Normal(mu: 2.0, sigma: 3.0) v_x ~ Normal(mu: 0.0, sigma: 0.25) v_y ~ Normal(mu: 0.0, sigma: 0.25) Link: identity Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
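To see what the formula v ~ 1 + x + y amounts to numerically, here is a small NumPy sketch of the corresponding design matrix and linear predictor. It only mirrors the structure of the regression. The covariates are freshly drawn here, and the coefficient values are the ground-truth values from the simulation above. Nothing in it calls HSSM or Bambi.

```python
import numpy as np

rng = np.random.default_rng(2)
n_trials = 1000

# Freshly drawn covariates with the same distribution as in the simulation above.
x = rng.uniform(-1, 1, size=n_trials)
y = rng.uniform(-1, 1, size=n_trials)

# "1 + x + y" corresponds to an intercept column plus one column per covariate.
design = np.column_stack([np.ones(n_trials), x, y])

# Ground-truth coefficients in the order Intercept, x, y.
beta = np.array([0.3, 0.8, 0.3])

# With the identity link, the trial-wise drift rate is the linear predictor itself.
v_trialwise = design @ beta

# Sanity check: least squares on the noiseless drifts returns the coefficients.
beta_hat, *_ = np.linalg.lstsq(design, v_trialwise, rcond=None)
print(beta_hat)
```

In the actual model the coefficients are of course not obtained by least squares. They are sampled jointly with a, z and t through the DDM likelihood.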
Param class¶
As illustrated below, there is an alternative way of specifying the parameter-specific data via the Param class.
model_reg_v_simple_new = hssm.HSSM(
    data=dataset_reg_v, include=[hssm.Param(name="v", formula="v ~ 1 + x + y")]
)
Model initialized successfully.
model_reg_v_simple_new
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Formula: v ~ 1 + x + y Priors: v_Intercept ~ Normal(mu: 2.0, sigma: 3.0) v_x ~ Normal(mu: 0.0, sigma: 0.25) v_y ~ Normal(mu: 0.0, sigma: 0.25) Link: identity Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
model_reg_v_simple.graph()
print(model_reg_v_simple)
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Formula: v ~ 1 + x + y Priors: v_Intercept ~ Normal(mu: 2.0, sigma: 3.0) v_x ~ Normal(mu: 0.0, sigma: 0.25) v_y ~ Normal(mu: 0.0, sigma: 0.25) Link: identity Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Custom Model¶
These were the defaults. With a little extra labor, we can customize, for example, the choice of priors for each parameter in the model.
model_reg_v = hssm.HSSM(
    data=dataset_reg_v,
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
                "x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
                "y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
            },
            "formula": "v ~ 1 + x + y",
            "link": "identity",
        }
    ],
)
Model initialized successfully.
model_reg_v
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Formula: v ~ 1 + x + y Priors: v_Intercept ~ Uniform(lower: -3.0, upper: 3.0) v_x ~ Uniform(lower: -1.0, upper: 1.0) v_y ~ Uniform(lower: -1.0, upper: 1.0) Link: identity Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Notice how v is now set as a regression.
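The model summary lists the link as identity. As a brief aside on what a link function does, the sketch below contrasts the identity link with an inverse-logit transform, which is a common choice when a parameter must stay inside (0, 1). Both helper functions are written here purely for illustration. Which links HSSM offers for a given parameter is not covered by this sketch.

```python
import math


def identity(eta):
    """Identity link: the parameter equals the linear predictor."""
    return eta


def inverse_logit(eta):
    """Map a linear predictor on the real line into the interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-eta))


# The same linear predictor values under both transforms.
for eta in (-3.0, 0.0, 3.0):
    print(eta, identity(eta), round(inverse_logit(eta), 3))
```

For an unbounded parameter such as v the identity link is the natural choice, since any real value of the linear predictor is a valid drift rate.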
infer_data_reg_v = model_reg_v.sample(
    sampler="mcmc",
    chains=2,
    cores=2,
    draws=500,
    tune=500,
    mp_ctx="spawn",
)
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [t, a, z, v_Intercept, v_x, v_y]
Output()
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 78 seconds. We recommend running at least 4 chains for robust computation of convergence diagnostics 100%|██████████| 1000/1000 [00:00<00:00, 2526.84it/s]
infer_data_reg_v
-
<xarray.Dataset> Size: 32kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: z (chain, draw) float32 4kB 0.5122 0.5095 ... 0.5202 0.5141 v_x (chain, draw) float32 4kB 0.7631 0.8532 ... 0.8481 0.7889 t (chain, draw) float32 4kB 0.1354 0.1353 0.1228 ... 0.141 0.095 a (chain, draw) float32 4kB 1.476 1.432 1.473 ... 1.422 1.475 v_y (chain, draw) float32 4kB 0.3061 0.2164 ... 0.3503 0.2618 v_Intercept (chain, draw) float64 8kB 0.3044 0.3164 ... 0.3146 0.3094 Attributes: created_at: 2024-12-27T22:33:42.049881+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 78.35613012313843 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 8MB Dimensions: (chain: 2, draw: 500, __obs__: 1000) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 8MB -1.444 -2.44 ... -1.465 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 126kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 8kB 0.878 0.8049 ... 0.838 diverging (chain, draw) bool 1kB False False ... False False energy (chain, draw) float64 8kB 1.985e+03 ... 1.988e+03 energy_error (chain, draw) float64 8kB 0.1968 0.2409 ... -0.04186 index_in_trajectory (chain, draw) int64 8kB -1 -2 -4 4 -3 ... -2 5 6 -2 5 largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan ... ... process_time_diff (chain, draw) float64 8kB 0.008167 ... 0.005709 reached_max_treedepth (chain, draw) bool 1kB False False ... False False smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan step_size (chain, draw) float64 8kB 0.5894 0.5894 ... 0.6907 step_size_bar (chain, draw) float64 8kB 0.5946 0.5946 ... 0.5695 tree_depth (chain, draw) int64 8kB 3 2 3 3 3 3 3 ... 3 3 3 3 2 3 Attributes: created_at: 2024-12-27T22:33:42.064810+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 78.35613012313843 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 16kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 8kB 0... Attributes: created_at: 2024-12-27T22:33:42.068486+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
# az.plot_forest(model_reg_v.traces)
az.plot_trace(
    model_reg_v.traces,
    var_names=["~v"],
    lines=[(key_, {}, param_dict_reg_v[key_]) for key_ in param_dict_reg_v],
)
array([[<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>], [<Axes: title={'center': 'v_x'}>, <Axes: title={'center': 'v_x'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>], [<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>], [<Axes: title={'center': 'v_y'}>, <Axes: title={'center': 'v_y'}>], [<Axes: title={'center': 'v_Intercept'}>, <Axes: title={'center': 'v_Intercept'}>]], dtype=object)
az.plot_trace(
    model_reg_v.traces,
    var_names=["~v"],
    lines=[(key_, {}, param_dict_reg_v[key_]) for key_ in param_dict_reg_v],
)
plt.tight_layout()
# Looks like parameter recovery was successful
az.summary(model_reg_v.traces, var_names=["~v"])
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| z | 0.508 | 0.013 | 0.485 | 0.533 | 0.000 | 0.000 | 842.0 | 686.0 | 1.00 |
| v_x | 0.810 | 0.048 | 0.718 | 0.899 | 0.002 | 0.001 | 855.0 | 503.0 | 1.00 |
| t | 0.120 | 0.019 | 0.084 | 0.156 | 0.001 | 0.001 | 708.0 | 682.0 | 1.00 |
| a | 1.470 | 0.026 | 1.426 | 1.524 | 0.001 | 0.001 | 880.0 | 711.0 | 1.00 |
| v_y | 0.305 | 0.043 | 0.221 | 0.385 | 0.001 | 0.001 | 1085.0 | 768.0 | 1.01 |
| v_Intercept | 0.317 | 0.033 | 0.258 | 0.384 | 0.001 | 0.001 | 919.0 | 548.0 | 1.01 |
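One simple way to back up the claim of successful parameter recovery is to check whether each ground-truth value falls inside the reported highest density interval. The sketch below does this by hand, with the interval bounds copied from the summary table above and the ground truth taken from the simulation settings. The dictionary names are illustrative only.

```python
# Ground-truth values used when simulating the dataset.
ground_truth = {
    "z": 0.5,
    "v_x": 0.8,
    "t": 0.1,
    "a": 1.5,
    "v_y": 0.3,
    "v_Intercept": 0.3,
}

# (hdi_3%, hdi_97%) bounds copied from the summary table.
hdi = {
    "z": (0.485, 0.533),
    "v_x": (0.718, 0.899),
    "t": (0.084, 0.156),
    "a": (1.426, 1.524),
    "v_y": (0.221, 0.385),
    "v_Intercept": (0.258, 0.384),
}

# A parameter counts as recovered here if its true value lies inside the interval.
covered = {name: low <= ground_truth[name] <= high for name, (low, high) in hdi.items()}
print(covered)
```

Interval coverage from a single simulated dataset is only a rough check, but it is a quick complement to eyeballing the trace plots.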
Case 2: One parameter is a Regression (LAN)¶
We can do the same thing with the angle model.
Note:
Our dataset was generated from the basic DDM, which assumes stable (non-collapsing) bounds, so we expect the theta (angle of linear collapse) parameter to be recovered close to $0$.
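To give some intuition for why theta close to $0$ corresponds to the basic DDM, the sketch below evaluates a linearly collapsing bound over time. It assumes the bound has the form a - t * tan(theta), floored at 0. This parameterization and the helper `angle_upper_bound` are assumptions made for illustration, and the exact definition used by the simulator may differ.

```python
import numpy as np


def angle_upper_bound(t, a, theta):
    """Upper decision bound at time t, assuming a linear collapse with angle theta."""
    return np.maximum(a - t * np.tan(theta), 0.0)


t_grid = np.linspace(0.0, 2.0, 5)

# theta = 0: the bound stays at a for all time points, as in the basic DDM.
bound_flat = angle_upper_bound(t_grid, a=1.5, theta=0.0)

# theta > 0: the bound shrinks linearly as time passes.
bound_collapsing = angle_upper_bound(t_grid, a=1.5, theta=0.3)

print(bound_flat)
print(bound_collapsing)
```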
model_reg_v_angle = hssm.HSSM(
    data=dataset_reg_v,
    model="angle",
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {
                    "name": "Uniform",
                    "lower": -3.0,
                    "upper": 3.0,
                },
                "x": {
                    "name": "Uniform",
                    "lower": -1.0,
                    "upper": 1.0,
                },
                "y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
            },
            "formula": "v ~ 1 + x + y",
            "link": "identity",
        }
    ],
)
Model initialized successfully.
model_reg_v_angle.graph()