Getting Started
This tutorial demonstrates how to quickly get started with the HSSM package. We will cover the following steps:
- How to create a model
- How to create some simple simulated data
- How to specify parameters
- How to specify parameters with regressions
- How to use ArviZ to summarize and visualize the traces.
Colab Instructions
If you would like to run this tutorial on Google Colab, please click this link.
Once you are in the colab, follow the installation instructions below and then restart your runtime.
Just uncomment the code in the next code cell and run it!
NOTE:
You may want to switch your runtime to have a GPU or TPU. To do so, go to Runtime > Change runtime type and select the desired hardware accelerator.
Note that if you switch your runtime you have to follow the installation instructions again.
# !pip install hssm
Import Modules
import numpy as np
import hssm
%matplotlib inline
%config InlineBackend.figure_format='retina'
Simulating a dataset
The `hssm.simulate_data()` function generates data for most SSM types. Here we simulate some data from a Drift Diffusion Model (DDM) with known true parameter values.
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.5]
dataset = hssm.simulate_data(
    model="ddm",
    theta=[v_true, a_true, z_true, t_true],
    size=1000,
)
dataset
| | rt | response |
---|---|---|
0 | 2.138525 | 1.0 |
1 | 4.298813 | 1.0 |
2 | 2.722723 | 1.0 |
3 | 1.338940 | 1.0 |
4 | 1.915012 | -1.0 |
... | ... | ... |
995 | 2.367986 | 1.0 |
996 | 1.623948 | 1.0 |
997 | 1.893260 | 1.0 |
998 | 1.184971 | 1.0 |
999 | 4.209857 | 1.0 |
1000 rows × 2 columns
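The NumPy sketch below shows how DDM data like the table above can be generated with Euler-Maruyama steps, which makes the meaning of `theta=[v, a, z, t]` concrete. It is not the simulator behind `hssm.simulate_data()`, and HSSM's internal parameterization may differ in details. The function name `simulate_ddm_euler`, the unit diffusion noise, the step size `dt`, and the mapping of the upper bound to response 1 are all assumptions made for illustration.

```python
import numpy as np


def simulate_ddm_euler(v, a, z, t, n_trials, dt=1e-3, max_time=20.0, seed=None):
    """Simulate (rt, response) pairs from a basic DDM via Euler-Maruyama steps.

    Illustrative assumptions: unit diffusion noise, absolute start at z * a,
    upper bound at a (response 1) and lower bound at 0 (response -1).
    Trials that reach max_time without hitting a bound get response NaN.
    """
    rng = np.random.default_rng(seed)
    max_steps = int(max_time / dt)
    sqrt_dt = np.sqrt(dt)
    rts = np.empty(n_trials)
    responses = np.empty(n_trials)
    for i in range(n_trials):
        x = z * a  # starting point of the evidence accumulator
        steps = 0
        while 0.0 < x < a and steps < max_steps:
            # Drift plus Gaussian noise scaled by sqrt(dt)
            x += v * dt + sqrt_dt * rng.standard_normal()
            steps += 1
        rts[i] = t + steps * dt  # non-decision time + decision time
        if x >= a:
            responses[i] = 1.0
        elif x <= 0.0:
            responses[i] = -1.0
        else:
            responses[i] = np.nan  # timed out without reaching a bound
    return rts, responses


rts, responses = simulate_ddm_euler(0.5, 1.5, 0.5, 0.5, n_trials=200, seed=1)
print(rts[:5], responses[:5])
```

With a positive drift `v`, most simulated responses hit the upper bound. Every `rt` exceeds the non-decision time `t`.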
Model specification
1. DDM using defaults
We begin with a simple example. The only information required to create a model in HSSM is a dataset.
A dataset in HSSM is typically a pandas DataFrame with at least `rt` and `response` columns, which contain response times and choices, respectively. Right now, `response` only accepts the values 1 and -1.
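You can check these data requirements before building a model. The helper below, `check_ddm_data`, is hypothetical and not part of HSSM. It only tests the two conditions stated above. It works with a pandas DataFrame or a plain dictionary of columns, because both support `in` and `data["column"]`.

```python
def check_ddm_data(data):
    """Check the minimal data requirements described above (hypothetical helper).

    `data` can be a pandas DataFrame or any mapping from column names to
    sequences, since both support `in` and `data["column"]`.
    """
    missing = [col for col in ("rt", "response") if col not in data]
    if missing:
        raise ValueError(f"missing required column(s): {missing}")
    invalid = sorted({float(r) for r in data["response"] if r not in (1, -1)})
    if invalid:
        raise ValueError(f"response must be 1 or -1, found {invalid}")
    return True


print(check_ddm_data({"rt": [2.14, 1.92], "response": [1.0, -1.0]}))  # True
```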
If none of the optional parameters is provided, HSSM assumes that we are modeling a classical DDM with `v`, `a`, `z`, and `t` as its parameters. HSSM also provides a default analytical likelihood function and some uninformative priors. All of these can be overridden by user inputs.
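The classic DDM has an analytical likelihood. The sketch below illustrates one; it is not HSSM's actual code. It evaluates the large-time series for the first-passage-time density from Navarro and Fuss (2009), under these assumptions:

- `a` is the boundary separation.
- `z` is the relative starting point.
- The diffusion noise is 1.
- Upper-bound hits are coded as response 1.

The upper-bound density comes from mirroring the process (v to -v, z to 1 - z). Practical implementations also switch to a small-time series for decision times near zero, where the large-time series converges slowly.

```python
import numpy as np


def wiener_density_lower(tau, v, a, z, n_terms=200):
    """First-passage-time density at the lower bound (large-time series).

    tau: decision time(s) > 0, i.e. rt minus non-decision time t
    v: drift rate, a: boundary separation, z: relative starting point in (0, 1)
    """
    tau = np.asarray(tau, dtype=float)[..., None]  # add an axis for the series index k
    k = np.arange(1, n_terms + 1)
    series = np.sum(
        k * np.exp(-(k**2) * np.pi**2 * tau / (2 * a**2)) * np.sin(k * np.pi * z),
        axis=-1,
    )
    return np.pi / a**2 * np.exp(-v * a * z - v**2 * tau[..., 0] / 2) * series


def wiener_density_upper(tau, v, a, z, n_terms=200):
    # Upper-bound hits are lower-bound hits of the mirrored process (v -> -v, z -> 1 - z)
    return wiener_density_lower(tau, -v, a, 1.0 - z, n_terms)


def ddm_likelihood(rt, response, v, a, z, t):
    """Likelihood of (rt, response) pairs, with response 1 = upper and -1 = lower bound."""
    tau = np.asarray(rt, dtype=float) - t
    safe_tau = np.clip(tau, 1e-12, None)  # avoid evaluating the series at tau <= 0
    upper = wiener_density_upper(safe_tau, v, a, z)
    lower = wiener_density_lower(safe_tau, v, a, z)
    density = np.where(np.asarray(response) == 1, upper, lower)
    return np.where(tau > 0, density, 0.0)  # zero density before the non-decision time


# Probability of a lower-bound response, by integrating the density (trapezoid rule)
v, a, z = 0.5, 1.5, 0.5
tau = np.arange(1, 10001) * 2e-3  # decision times from 2 ms to 20 s
f_lower = wiener_density_lower(tau, v, a, z)
p_lower = np.sum((f_lower[1:] + f_lower[:-1]) / 2 * np.diff(tau))
print(p_lower)
```

Integrating the lower-bound density over decision time should recover the probability of a lower-bound response. In closed form that probability is (exp(-2vaz) - exp(-2va)) / (1 - exp(-2va)), which is about 0.321 for the true values used in this tutorial.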
Note
From HSSM v0.1.2 on, lapse distributions will be enabled by default, with `p_outlier` fixed to 0.05. You can set `p_outlier` to 0 or `None` to disable lapse distributions.
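A lapse distribution can be thought of as a two-component mixture. The sketch below assumes the common form (1 - p_outlier) * f_model + p_outlier * f_lapse(rt). Here f_lapse is the Uniform(0, 20) density shown in the model printout below. HSSM's exact implementation may differ, for example in how it treats the response. The function `lapse_mixture_likelihood` is hypothetical.

```python
def lapse_mixture_likelihood(model_density, rt, p_outlier=0.05,
                             lapse_lower=0.0, lapse_upper=20.0):
    """Mix a model likelihood with a uniform lapse density on rt (illustrative sketch).

    model_density: likelihood of one (rt, response) pair under the SSM
    p_outlier: lapse probability; 0 turns the lapse component off
    """
    if not 0.0 <= p_outlier <= 1.0:
        raise ValueError("p_outlier must lie in [0, 1]")
    in_support = lapse_lower <= rt <= lapse_upper
    lapse_density = 1.0 / (lapse_upper - lapse_lower) if in_support else 0.0
    return (1.0 - p_outlier) * model_density + p_outlier * lapse_density


# A very slow trial that the SSM itself considers (almost) impossible
print(lapse_mixture_likelihood(model_density=0.0, rt=15.0))  # 0.05 * (1 / 20)
print(lapse_mixture_likelihood(model_density=0.8, rt=1.2, p_outlier=0.0))  # 0.8
```

With `p_outlier` above zero, a single trial that the model deems nearly impossible cannot push the log-likelihood to minus infinity. With `p_outlier=0`, the mixture reduces to the model density.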
simple_ddm_model = hssm.HSSM(data=dataset)
simple_ddm_model
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Prior: Normal(mu: 0.0, sigma: 2.0) Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Visualizing the model
If you have graphviz installed on your machine, you can also visualize the model. Run the code below only if graphviz is installed.
# Requires graphviz to be installed
simple_ddm_model.graph()
Performing MCMC sampling
Similar to PyMC, once the model is created, HSSM provides a `sample()` method to perform MCMC sampling. By default, it uses PyMC's NUTS sampler. Other samplers are also available, as we will see shortly.
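PyMC's NUTS sampler is gradient-based and much more efficient than simpler methods. A simpler algorithm still shows what the output looks like. The sketch below swaps in random-walk Metropolis, which is a different technique, and runs it on a toy one-dimensional target.

It returns draws with the same (chain, draw) layout reported in the sampling output below. The `n_tune` iterations here are simply discarded, with none of the step-size adaptation that NUTS performs during tuning.

```python
import numpy as np


def random_walk_metropolis(log_density, init, n_chains=4, n_tune=1000, n_draws=1000,
                           step_size=0.3, seed=0):
    """Random-walk Metropolis for a 1-D target (illustration only, not NUTS).

    Returns an array of shape (n_chains, n_draws). The n_tune warm-up
    iterations are discarded without any adaptation.
    """
    rng = np.random.default_rng(seed)
    draws = np.empty((n_chains, n_draws))
    for chain in range(n_chains):
        x = init + step_size * rng.normal()  # jitter each chain's starting point
        logp = log_density(x)
        for i in range(n_tune + n_draws):
            proposal = x + step_size * rng.normal()
            logp_proposal = log_density(proposal)
            # Accept with probability min(1, p(proposal) / p(x))
            if np.log(rng.uniform()) < logp_proposal - logp:
                x, logp = proposal, logp_proposal
            if i >= n_tune:
                draws[chain, i - n_tune] = x
    return draws


# Toy target: a normal "posterior" for v centred at 0.6 with sd 0.035
def log_density(v):
    return -0.5 * ((v - 0.6) / 0.035) ** 2


draws = random_walk_metropolis(log_density, init=0.5, step_size=0.05)
print(draws.shape, draws.mean(), draws.std())
```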
simple_ddm_model.sample()
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [z, t, a, v]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 19 seconds. 100%|██████████| 4000/4000 [00:01<00:00, 2698.98it/s]
<xarray.Dataset> Size: 136kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: v (chain, draw) float64 32kB 0.5497 0.5603 0.5725 ... 0.6046 0.6327 t (chain, draw) float64 32kB 0.5142 0.5036 0.5201 ... 0.5427 0.5165 z (chain, draw) float64 32kB 0.4838 0.4935 0.4967 ... 0.5005 0.4871 a (chain, draw) float64 32kB 1.455 1.444 1.481 ... 1.523 1.454 1.467 Attributes: created_at: 2025-01-01T22:56:22.506497+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 18.824692964553833 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 32MB Dimensions: (chain: 4, draw: 1000, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 32MB -1.413 -2.947 ... -2.917 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 496kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: (12/17) acceptance_rate (chain, draw) float64 32kB 0.9386 0.9401 ... 0.9572 diverging (chain, draw) bool 4kB False False ... False False energy (chain, draw) float64 32kB 1.896e+03 ... 1.891e+03 energy_error (chain, draw) float64 32kB 0.2122 -0.2901 ... 0.05422 index_in_trajectory (chain, draw) int64 32kB -3 -1 -2 -3 1 ... 2 2 -3 -3 largest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan ... ... process_time_diff (chain, draw) float64 32kB 0.007661 ... 0.007419 reached_max_treedepth (chain, draw) bool 4kB False False ... False False smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan step_size (chain, draw) float64 32kB 0.7494 0.7494 ... 0.8027 step_size_bar (chain, draw) float64 32kB 0.632 0.632 ... 0.6422 tree_depth (chain, draw) int64 32kB 3 3 2 4 2 3 ... 3 3 3 3 3 3 Attributes: created_at: 2025-01-01T22:56:22.519958+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 18.824692964553833 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-01-01T22:56:22.522503+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
Visualizing the traces with ArviZ
As with `pm.sample()`, the result of `model.sample()` is an `az.InferenceData` object, which can be used with the ArviZ package. The traces from the most recent call to `sample()` are stored in the `model.traces` property for easy access. Here we use the model's `summary()` and `plot_trace()` methods, which wrap the ArviZ functions `az.summary()` and `az.plot_trace()`, to generate a summary table and diagnostic plots for the samples.
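Among the diagnostics in the summary table, `r_hat` compares between-chain and within-chain variance. Values close to 1.0 suggest that the chains agree. The NumPy sketch below computes the classic split R-hat, which is a simplification: ArviZ reports a rank-normalized version by default, so its numbers can differ slightly.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain in two so that drift within a chain also inflates R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)    # B: between-chain variance
    pooled = (n - 1) / n * within + between / n     # pooled variance estimate
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))
bad = good + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain stuck elsewhere
print(split_rhat(good), split_rhat(bad))
```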
simple_ddm_model.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v | 0.623 | 0.035 | 0.555 | 0.686 | 0.001 | 0.0 | 2889.0 | 2946.0 | 1.0 |
t | 0.521 | 0.020 | 0.484 | 0.561 | 0.000 | 0.0 | 2601.0 | 2822.0 | 1.0 |
z | 0.490 | 0.014 | 0.463 | 0.515 | 0.000 | 0.0 | 2760.0 | 2545.0 | 1.0 |
a | 1.476 | 0.027 | 1.425 | 1.528 | 0.001 | 0.0 | 2695.0 | 2712.0 | 1.0 |
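The `hdi_3%` and `hdi_97%` columns bound a 94% highest-density interval, which is ArviZ's default. The NumPy sketch below uses the usual sorted-samples approach: it finds the narrowest window that contains 94% of the draws. It runs on synthetic draws that loosely mimic the posterior of `a` above, not on the actual trace.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Narrowest interval containing `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float).ravel())
    n = x.size
    span = int(np.floor(prob * n))  # number of index steps covered by the interval
    widths = x[span:] - x[: n - span]  # widths of all candidate intervals
    start = int(np.argmin(widths))
    return float(x[start]), float(x[start + span])


rng = np.random.default_rng(0)
samples = rng.normal(loc=1.476, scale=0.027, size=4000)
print(hdi(samples))
```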
simple_ddm_model.plot_trace();
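Congratulations! You have just created and sampled from your first model in HSSM! Parameter recovery looks reasonable. The 94% HDIs for `a`, `z`, and `t` contain the true values (1.5, 0.5, and 0.5). The posterior mean of `v` (0.623) is somewhat above its true value of 0.5.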
2. Specifying different model types
The default model in HSSM is the classic DDM, but HSSM supports many other model types. Below is the list of currently supported models. We will add new models over time, and HSSM users can also contribute models. Please click here for our contribution guidelines.
- ddm
- ddm_sdv
- full_ddm
- angle
- levy
- ornstein
- weibull
- race_no_bias_angle_4
- ddm_seq2_no_bias
For more information about these models, please see here.
To use a different model type, pass its name to the `model` parameter of the `HSSM` class constructor.
If `model` is one of the supported models above, HSSM also provides default specifications for it, so no other specification is necessary.
If `model` is not one of the above, HSSM assumes that you are using a custom model. We will cover custom models in more detail in a separate tutorial.
Below is an example of specifying the `angle` model type.
# Simulate data for an angle model
angle_data = hssm.simulate_data(
    model="angle",
    theta=[0.5, 1.5, 0.5, 0.5, 0.3],  # true values
    size=1000,
)
angle_model = hssm.HSSM(data=angle_data, model="angle")
angle_model
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: angle Response variable: rt,response Likelihood: approx_differentiable Observations: 1000 Parameters: v: Prior: Uniform(lower: -3.0, upper: 3.0) Explicit bounds: (-3.0, 3.0) a: Prior: Uniform(lower: 0.3, upper: 3.0) Explicit bounds: (0.3, 3.0) z: Prior: Uniform(lower: 0.1, upper: 0.9) Explicit bounds: (0.1, 0.9) t: Prior: Uniform(lower: 0.001, upper: 2.0) Explicit bounds: (0.001, 2.0) theta: Prior: Uniform(lower: -0.1, upper: 1.3) Explicit bounds: (-0.1, 1.3) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
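The fifth `angle` parameter, `theta`, controls how quickly the decision bounds collapse over time. The sketch below assumes the linear form a(tau) = a - tau * tan(theta). To our understanding, this is how the angle bound is defined in the ssm-simulators package. `angle_boundary` is a hypothetical helper, and flooring the bound at zero is our own choice.

```python
import numpy as np


def angle_boundary(tau, a, theta):
    """Bound height at decision time tau for a linearly collapsing ("angle") bound.

    Assumed form: a(tau) = a - tau * tan(theta), floored at zero.
    theta = 0 recovers the constant bound of the classic DDM.
    """
    tau = np.asarray(tau, dtype=float)
    return np.maximum(a - tau * np.tan(theta), 0.0)


taus = np.array([0.0, 1.0, 2.0, 5.0])
print(angle_boundary(taus, a=1.5, theta=0.3))  # bound shrinks over time
print(angle_boundary(taus, a=1.5, theta=0.0))  # constant bound, as in the DDM
print(1.5 / np.tan(0.3))  # time at which the bound reaches zero
```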
`angle` models, by default, use an approximate differentiable likelihood function that relies on JAX for likelihood computation. At the moment, JAX and Jupyter handle parallelism differently, so parallel sampling is not available for this type of likelihood computation. There are ways to get around this, but for now we simply rely on the default sampler for this likelihood type, as noted in the code comment below.
Note: the `sample()` method internally calls bambi's `fit()` method, which in turn calls PyMC's `sample()` function. This means that HSSM's `sample()` method accepts most parameters accepted by those two functions, so you can use it much like `pm.sample()`.
# Unless otherwise specified, we default to the `nuts_numpyro` sampler
# for "approx_differentiable" likelihoods.
angle_model.sample()
Using default initvals.
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
100%|██████████| 4000/4000 [00:04<00:00, 869.89it/s]
<xarray.Dataset> Size: 168kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: theta (chain, draw) float64 32kB 0.3024 0.2963 0.3347 ... 0.3651 0.373 v (chain, draw) float64 32kB 0.4196 0.4099 0.4478 ... 0.4249 0.4194 z (chain, draw) float64 32kB 0.5282 0.528 0.5147 ... 0.5295 0.5169 t (chain, draw) float64 32kB 0.5346 0.532 0.4893 ... 0.464 0.4651 a (chain, draw) float64 32kB 1.476 1.468 1.565 ... 1.56 1.633 1.64 Attributes: created_at: 2025-01-01T22:58:11.077741+00:00 arviz_version: 0.19.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 104.771827 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 32MB Dimensions: (chain: 4, draw: 1000, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 32MB -1.526 -0.8813 ... -0.7195 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 204kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 0.9751 0.9826 ... 0.9353 0.9994 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 1.479e+03 1.48e+03 ... 1.478e+03 lp (chain, draw) float64 32kB 1.478e+03 ... 1.478e+03 n_steps (chain, draw) int64 32kB 15 7 15 15 31 15 ... 7 31 15 31 15 step_size (chain, draw) float64 32kB 0.1867 0.1867 ... 0.2378 0.2378 tree_depth (chain, draw) int64 32kB 4 3 4 4 5 4 3 4 ... 5 5 3 5 4 5 4 Attributes: created_at: 2025-01-01T22:58:11.100235+00:00 arviz_version: 0.19.0 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-01-01T22:58:11.101098+00:00 arviz_version: 0.19.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 104.771827 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
angle_model.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v | 0.435 | 0.044 | 0.354 | 0.517 | 0.001 | 0.001 | 1903.0 | 2071.0 | 1.0 |
z | 0.521 | 0.013 | 0.497 | 0.547 | 0.000 | 0.000 | 1772.0 | 2092.0 | 1.0 |
a | 1.578 | 0.064 | 1.455 | 1.693 | 0.002 | 0.001 | 1336.0 | 1801.0 | 1.0 |
theta | 0.341 | 0.028 | 0.288 | 0.396 | 0.001 | 0.001 | 1461.0 | 2105.0 | 1.0 |
t | 0.487 | 0.029 | 0.435 | 0.541 | 0.001 | 0.001 | 1403.0 | 1668.0 | 1.0 |
angle_model.plot_trace();
3. Specifying priors: the non-regression case
Next, let's take a look at how to specify priors in the non-regression case. In HSSM, parameters can be specified in two ways:
- through the `include` parameter, or
- through a shortcut.
3.1 Specifying priors through the include parameter
The `include` parameter accepts a list of dictionaries or `hssm.Param` objects. The two are equivalent: the content of each dictionary is passed as arguments to the `hssm.Param` class during model creation, so choosing between them is mostly a matter of preference. We recommend `hssm.Param` objects because some IDEs can then suggest the available options. In the non-regression case, each dictionary typically looks like this:
{
    "name": "v",
    "prior": {
        "name": "Uniform",
        "lower": -5.0,
        "upper": 5.0,
    },
    "bounds": (-10.0, 10.0),
}
This is equivalent to writing:
hssm.Param(
    "v",
    prior=dict(name="Uniform", lower=-5.0, upper=5.0),
    bounds=(-10.0, 10.0),
)
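The two forms are interchangeable because the dictionary's contents are passed to `hssm.Param` as arguments. The sketch below shows this with `ParamSpec`, a hypothetical dataclass standing in for `hssm.Param`; the real class accepts more options than shown. Unpacking a dictionary with `**` produces the same object as passing keyword arguments directly.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class ParamSpec:
    """Hypothetical stand-in for hssm.Param, used only to illustrate dict unpacking."""

    name: str
    prior: Any = None
    bounds: Optional[Tuple[float, float]] = None


as_dict = {
    "name": "v",
    "prior": {"name": "Uniform", "lower": -5.0, "upper": 5.0},
    "bounds": (-10.0, 10.0),
}
from_dict = ParamSpec(**as_dict)  # dictionary keys become keyword arguments
direct = ParamSpec("v", prior=dict(name="Uniform", lower=-5.0, upper=5.0), bounds=(-10.0, 10.0))
print(from_dict == direct)  # True
```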
The `name` field is the name of the parameter being specified.
The `prior` field specifies the prior distribution. There are two ways to do this:
- A dictionary with the `name` of the distribution (typically capitalized) and the parameters you would set if you were specifying the distribution in PyMC. For example, if you would like to specify `pm.Normal(mu=0.0, sigma=1.0)` as the prior, then in HSSM this prior dictionary would be:
{
    "name": "Normal",  # Note that it is capitalized
    "mu": 0.0,
    "sigma": 1.0,
}
or, using the `dict` function:
dict(name="Normal", mu=0.0, sigma=1.0)
- An `hssm.Prior` object. This is exactly how you would specify priors using bambi. In fact, `hssm.Prior` is a subclass of `bmb.Prior` and for the most part can be used interchangeably with it. To specify the same normal prior as above, you would write:
hssm.Prior("Normal", mu=0.0, sigma=1.0)
The `bounds` field accepts a tuple of floats indicating the lower and upper bounds of the parameter.
Fixing parameters: sometimes you might want to fix the value of a parameter. You can do so by passing that value to the `prior` field of the dictionary. In the following example, the parameter `v` is fixed to 0.5.
{
    "name": "v",
    "prior": 0.5,
}
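A fixed value is just a number in the `prior` field, whereas a distribution is a dictionary or an `hssm.Prior`. A fixed parameter is not sampled. In the example below, `t` is fixed to 0.5, so the sampler log lists only `z, a, v` under NUTS. The hypothetical helper `split_fixed_and_free` sketches this distinction for include-style dictionaries. It is not HSSM's internal logic.

```python
def split_fixed_and_free(specs):
    """Separate fixed parameters (numeric prior) from parameters that need sampling.

    Hypothetical helper: `specs` is a list of include-style dictionaries.
    """
    fixed, free = {}, []
    for spec in specs:
        prior = spec.get("prior")
        if isinstance(prior, (int, float)) and not isinstance(prior, bool):
            fixed[spec["name"]] = float(prior)
        else:
            free.append(spec["name"])
    return fixed, free


specs = [
    {"name": "v", "prior": {"name": "Normal", "mu": 0.0, "sigma": 2.0}},
    {"name": "a", "prior": {"name": "Uniform", "lower": 0.01, "upper": 5.0}},
    {"name": "z", "prior": {"name": "Uniform", "lower": 0.0, "upper": 1.0}},
    {"name": "t", "prior": 0.5},
]
fixed, free = split_fixed_and_free(specs)
print(fixed, free)  # {'t': 0.5} ['v', 'a', 'z']
```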
Now let's make this concrete with an example:
# A Normal prior for `v` without explicit bounds
param_v = {
    "name": "v",
    "prior": {
        "name": "Normal",
        "mu": 0.0,
        "sigma": 2.0,
    },
}

# A Uniform prior for `a`, specified with `hssm.Param` and the `dict` function
param_a = hssm.Param(
    "a",
    prior=dict(
        name="Uniform",
        lower=0.01,
        upper=5,
    ),
    bounds=(0, np.inf),
)

# A Uniform prior for `z` over (0, 1) set using `hssm.Prior`.
# Bounds are not set, so the existing default bounds will be used.
param_z = {"name": "z", "prior": hssm.Prior("Uniform", lower=0.0, upper=1.0)}

# A fixed value for `t`
param_t = {"name": "t", "prior": 0.5}

example_ddm_model = hssm.HSSM(
    data=dataset,
    model="ddm",
    include=[
        param_v,
        param_a,
        param_z,
        param_t,
    ],
)
example_ddm_model
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Prior: Normal(mu: 0.0, sigma: 2.0) Explicit bounds: (-inf, inf) a: Prior: Uniform(lower: 0.01, upper: 5.0) Explicit bounds: (0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: 0.5 Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
example_ddm_model.sample()
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [z, a, v] /Users/afengler/miniconda3/envs/hssm519/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
/Users/afengler/miniconda3/envs/hssm519/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds. 100%|██████████| 4000/4000 [00:01<00:00, 2229.85it/s]
<xarray.Dataset> Size: 104kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: z (chain, draw) float64 32kB 0.4848 0.4887 0.4816 ... 0.4839 0.4778 v (chain, draw) float64 32kB 0.6765 0.6443 0.6463 ... 0.6343 0.693 a (chain, draw) float64 32kB 1.478 1.5 1.517 ... 1.498 1.478 1.484 Attributes: created_at: 2025-01-01T22:58:30.612048+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.191963911056519 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 32MB Dimensions: (chain: 4, draw: 1000, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 32MB -1.349 -3.008 ... -2.932 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 496kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: (12/17) acceptance_rate (chain, draw) float64 32kB 0.4948 0.8589 ... 0.6956 diverging (chain, draw) bool 4kB False False ... False False energy (chain, draw) float64 32kB 1.891e+03 ... 1.893e+03 energy_error (chain, draw) float64 32kB 0.7037 -0.6844 ... 0.5542 index_in_trajectory (chain, draw) int64 32kB 1 5 5 -2 2 1 ... 3 3 -2 -4 1 largest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan ... ... process_time_diff (chain, draw) float64 32kB 0.001533 ... 0.005752 reached_max_treedepth (chain, draw) bool 4kB False False ... False False smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan step_size (chain, draw) float64 32kB 0.9622 0.9622 ... 0.9208 step_size_bar (chain, draw) float64 32kB 0.7741 0.7741 ... 0.7759 tree_depth (chain, draw) int64 32kB 1 3 3 2 2 3 ... 2 3 3 2 3 3 Attributes: created_at: 2025-01-01T22:58:30.624651+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 12.191963911056519 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.0
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-01-01T22:58:30.627251+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
example_ddm_model.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
z | 0.484 | 0.012 | 0.460 | 0.506 | 0.000 | 0.0 | 2194.0 | 2510.0 | 1.0 |
v | 0.633 | 0.033 | 0.567 | 0.694 | 0.001 | 0.0 | 2384.0 | 2783.0 | 1.0 |
a | 1.493 | 0.023 | 1.451 | 1.537 | 0.000 | 0.0 | 3163.0 | 3002.0 | 1.0 |
3.2 Specifying priors using the shortcut
HSSM also supports a syntax very similar to PyMC: you can specify a prior directly by passing it to `hssm.HSSM` as a keyword argument named after the parameter. This is convenient when the prior is simple. Below is an example that is almost equivalent to the one above:
# All ways to specify priors mentioned above are supported in the shortcut syntax
shortcut_ddm_model = hssm.HSSM(
    data=dataset,
    model="ddm",
    v={"name": "Normal", "mu": 0.0, "sigma": 2.0},
    a=dict(name="Uniform", lower=0.01, upper=5),
    z=hssm.Prior("Uniform", lower=0.01, upper=1.0),
    t=0.5,
)
shortcut_ddm_model
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Prior: Normal(mu: 0.0, sigma: 2.0) Explicit bounds: (-inf, inf) a: Prior: Uniform(lower: 0.01, upper: 5.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.01, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: 0.5 Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Note that the shortcut syntax also supports specifying bounds. This feature will be more polished in a future update, so we skip it for now.
4. Specifying priors: the regression case
Built on top of bambi, HSSM uses an `lmer`-like syntax that makes it straightforward to specify regressions.
Parameters that are targets of regressions are also specified using dictionaries in `include`. Below is an example of such a dictionary.
{
    "name": "v",
    "formula": "v ~ 1 + x + y",
    "prior": {
        # All ways to specify priors in the non-regression case
        # work the same way here.
        "Intercept": {"name": "Uniform", "lower": -10.0, "upper": 10.0},
        "x": dict(name="Normal", mu=0, sigma=1),
        "y": hssm.Prior("HalfNormal", sigma=0.5),
    },
    "link": "identity",
    "bounds": (-10.0, 10.0),
}
In the regression case, `name` and `bounds` are specified exactly as in the non-regression case. The regression formula is written much like a formula for the `lmer` function from the lme4 package in R, so users with experience with R formulas should find this syntax familiar. Here, the formula means that the parameter `v` is regressed on the variables `x` and `y`, which must be columns of the DataFrame passed to `hssm.HSSM`. The `1` explicitly specifies an intercept for the regression.
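In matrix form, `v ~ 1 + x + y` means v_i = b_Intercept + b_x * x_i + b_y * y_i for each trial i. The NumPy sketch below builds the corresponding design matrix conceptually. It uses toy values for `x` and `y`, and takes the coefficients from the simulation later in this section.

```python
import numpy as np

# Two hypothetical trials with predictor values x and y
x = np.array([1.0, -2.0])
y = np.array([0.0, 4.0])

# Design matrix for v ~ 1 + x + y: a column of ones (intercept), then x and y
X = np.column_stack([np.ones_like(x), x, y])

# Coefficients in the order (Intercept, x, y), as in the simulation later on
beta = np.array([1.5, 0.8, 0.3])

v = X @ beta  # one drift rate per trial: 1.5 + 0.8 * x + 0.3 * y
print(v)
```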
In addition to the `formula`, users typically need to specify priors for the regression coefficients. This is done in the `prior` field of the dictionary. Instead of holding a prior for the parameter itself, each key now names a regression coefficient and holds the prior for that coefficient. If priors are not specified, HSSM uses default priors generated by Bambi.
Users might also want to specify a `link` function for generalized linear models. If left unspecified, the `identity` link function is used.
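A link function connects the linear predictor to the parameter: the parameter equals the inverse link applied to X @ beta. The sketch below shows three common choices. Their names mirror Bambi's link names, but we have not confirmed here that HSSM accepts each of them for every parameter. `parameter_from_predictor` is a hypothetical helper.

```python
import numpy as np

# Inverse link functions map the linear predictor eta = X @ beta to the parameter scale
INVERSE_LINKS = {
    "identity": lambda eta: eta,
    "log": np.exp,  # keeps the parameter positive, e.g. for a or t
    "logit": lambda eta: 1.0 / (1.0 + np.exp(-eta)),  # keeps it in (0, 1), e.g. for z
}


def parameter_from_predictor(X, beta, link="identity"):
    """Compute trial-wise parameter values under a given link (illustrative)."""
    eta = np.asarray(X, dtype=float) @ np.asarray(beta, dtype=float)
    return INVERSE_LINKS[link](eta)


X = np.column_stack([np.ones(3), np.array([-2.0, 0.0, 2.0])])
beta = np.array([0.0, 1.0])
for link in INVERSE_LINKS:
    print(link, parameter_from_predictor(X, beta, link))
```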
Now let's see an example of a regression:
# Generate some fake simulation data
intercept = 1.5
x = np.random.uniform(-5.0, 5.0, size=1000)
y = np.random.uniform(-5.0, 5.0, size=1000)
v = intercept + 0.8 * x + 0.3 * y
true_values = np.column_stack([v, np.repeat([[1.5, 0.5, 0.5]], axis=0, repeats=1000)])
true_values.shape
(1000, 4)
dataset_reg_v = hssm.simulate_data(
    model="ddm",
    theta=true_values,
    size=1,  # Generate one data point for each of the 1000 sets of true values
)
dataset_reg_v["x"] = x
dataset_reg_v["y"] = y
dataset_reg_v
| | rt | response | x | y |
---|---|---|---|---|
0 | 0.740289 | 1.0 | 1.819151 | -4.117494 |
1 | 0.850780 | 1.0 | 2.491206 | 4.171911 |
2 | 0.843184 | 1.0 | 2.147904 | -2.180605 |
3 | 0.833531 | 1.0 | 2.570652 | -0.476937 |
4 | 0.997305 | -1.0 | -2.800995 | -1.066227 |
... | ... | ... | ... | ... |
995 | 1.523809 | 1.0 | -0.781649 | 3.039424 |
996 | 1.733002 | 1.0 | -0.322806 | 2.131486 |
997 | 1.454640 | 1.0 | 1.452501 | -3.462405 |
998 | 2.021644 | 1.0 | 0.933907 | -4.275842 |
999 | 1.974865 | 1.0 | -2.507812 | 4.814151 |
1000 rows × 4 columns
model_reg_v = hssm.HSSM(
    data=dataset_reg_v,
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + x + y",
            "prior": {
                "Intercept": {"name": "Uniform", "lower": -0.001, "upper": 0.5},
                "x": dict(name="Uniform", lower=0.0, upper=1.0),
                "y": hssm.Prior("Uniform", lower=0.0, upper=1.0),
            },
            "link": "identity",
        }
    ],
)
model_reg_v
Model initialized successfully.
Hierarchical Sequential Sampling Model Model: ddm Response variable: rt,response Likelihood: analytical Observations: 1000 Parameters: v: Formula: v ~ 1 + x + y Priors: v_Intercept ~ Uniform(lower: -0.001, upper: 0.5) v_x ~ Uniform(lower: 0.0, upper: 1.0) v_y ~ Uniform(lower: 0.0, upper: 1.0) Link: identity Explicit bounds: (-inf, inf) a: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) z: Prior: Uniform(lower: 0.0, upper: 1.0) Explicit bounds: (0.0, 1.0) t: Prior: HalfNormal(sigma: 2.0) Explicit bounds: (0.0, inf) Lapse probability: 0.05 Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
model_reg_v.initvals
{'z': array(0.5), 't': array(0.025), 'a': array(1.5), 'v_Intercept': array(0.), 'v_x': array(0.5), 'v_y': array(0.5)}
# Requires graphviz to be installed
model_reg_v.graph()
trace_reg_v = model_reg_v.sample()
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [z, t, a, v_Intercept, v_x, v_y] /Users/afengler/miniconda3/envs/hssm519/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
/Users/afengler/miniconda3/envs/hssm519/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 25 seconds. 100%|██████████| 4000/4000 [00:01<00:00, 2687.22it/s]
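# Compare with the true values used in the simulation: v_Intercept=1.5, v_x=0.8,
# v_y=0.3, a=1.5, z=0.5, t=0.5. Note that the Uniform(-0.001, 0.5) prior set for
# v_Intercept above excludes the true intercept of 1.5.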
model_reg_v.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v_Intercept | 0.548 | 0.004 | 0.540 | 0.554 | 0.000 | 0.0 | 4540.0 | 3515.0 | 1.0 |
v_x | 0.661 | 0.020 | 0.626 | 0.699 | 0.000 | 0.0 | 3951.0 | 3240.0 | 1.0 |
z | 0.616 | 0.008 | 0.600 | 0.632 | 0.000 | 0.0 | 4129.0 | 3041.0 | 1.0 |
a | 1.315 | 0.032 | 1.259 | 1.379 | 0.001 | 0.0 | 3679.0 | 2938.0 | 1.0 |
v_y | 0.271 | 0.015 | 0.243 | 0.299 | 0.000 | 0.0 | 4134.0 | 3066.0 | 1.0 |
t | 0.554 | 0.008 | 0.538 | 0.568 | 0.000 | 0.0 | 3413.0 | 2848.0 | 1.0 |
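A quick way to check recovery is to compare each true value with its 94% HDI. The plain-Python sketch below uses the numbers copied from the table above. In a live session, you could read the same columns from the DataFrame returned by `summary()`, assuming it is the usual ArviZ summary table. Every true value falls outside its HDI here. The Uniform(-0.001, 0.5) prior chosen for `v_Intercept` also excludes the true intercept of 1.5.

```python
# True generating values for dataset_reg_v, from the simulation cell above
true_params = {"v_Intercept": 1.5, "v_x": 0.8, "v_y": 0.3, "a": 1.5, "z": 0.5, "t": 0.5}

# (mean, hdi_3%, hdi_97%) copied from the summary table above
estimates = {
    "v_Intercept": (0.548, 0.540, 0.554),
    "v_x": (0.661, 0.626, 0.699),
    "z": (0.616, 0.600, 0.632),
    "a": (1.315, 1.259, 1.379),
    "v_y": (0.271, 0.243, 0.299),
    "t": (0.554, 0.538, 0.568),
}

covered = {}
for name, (mean, lower, upper) in estimates.items():
    truth = true_params[name]
    covered[name] = lower <= truth <= upper
    print(f"{name:<12} true={truth:<4} mean={mean:<6} true value in 94% HDI: {covered[name]}")
```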
model_reg_v.plot_trace();