Likelihood functions in HSSM explained¶
One of the design goals of HSSM is its flexibility. It is built from the ground up to support many types of likelihood functions out of the box. For more tailored applications, HSSM provides a convenient toolbox that allows users to create their own likelihood functions, which integrate seamlessly with the HSSM class and enable a highly customizable analysis environment. This notebook explains how to use the different types of likelihoods with HSSM.
Colab Instructions¶
If you would like to run this tutorial on Google Colab, please click this link.
Once you are in Colab, follow the installation instructions below and then restart your runtime.
Just uncomment the code in the next code cell and run it!
NOTE:
You may want to switch your runtime to have a GPU or TPU. To do so, go to Runtime > Change runtime type and select the desired hardware accelerator.
Note that if you switch your runtime you have to follow the installation instructions again.
# If running this on Colab, please uncomment the next line
# !pip install hssm
Load Modules¶
import numpy as np
import pytensor
import hssm
help(hssm.simulate_data)
Help on function simulate_data in module hssm.simulator:
simulate_data(model: str, theta: Union[dict[str, Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]], list[float], numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], size: int, random_state: int | None = None, output_df: bool = True, **kwargs) -> numpy.ndarray | pandas.core.frame.DataFrame
Sample simulated data from specified distributions.
Parameters
----------
model
A model name that must be supported in `ssm_simulators`. For a detailed list of
supported models, please see all fields in the `model_config` dict
[here](https://github.com/AlexanderFengler/ssm-simulators/blob
/e09eb2528d885c7b3340516597849fff4d9a5bf8/ssms/config/config.py#L6)
theta
Parameters of the process. Can be supplied as dictionary with parameter names as
key and np.array or float as values. Can also be supplied as a list or 1D-array,
however in this case the order of parameters is important and must match
specifications [here](https://github.com/AlexanderFengler/
ssm-simulators/blob/e09eb2528d885c7b3340516597849fff4d9a5bf8/ssms/config/config.py#L6).
Parameters can be specified 'trial-wise', by supplying 1D arrays of shape
`size` to the dictionary, or by supplying a 2D array
of shape `(size, n_parameters)` directly.
size
The size of the data to be simulated. If `theta` is a 2D ArrayLike, this
parameter indicates the size of data to be simulated for each trial.
random_state : optional
A random seed for reproducibility.
output_df : optional
If True, outputs a DataFrame with column names "rt", "response". Otherwise a
2-column numpy array, by default True.
kwargs : optional
Other arguments passed to ssms.basic_simulators.simulator.
Returns
-------
np.ndarray | pd.DataFrame
An array or DataFrame with simulated data.
Pre-simulate some data¶
# Simulate some data
data = hssm.simulate_data(
model="ddm", theta=dict(v=0.5, a=1.5, z=0.5, t=0.3), size=1000
)
data
| | rt | response |
|---|---|---|
| 0 | 1.045665 | 1.0 |
| 1 | 2.170658 | 1.0 |
| 2 | 1.184247 | 1.0 |
| 3 | 4.467735 | 1.0 |
| 4 | 1.289009 | 1.0 |
| ... | ... | ... |
| 995 | 1.803433 | 1.0 |
| 996 | 2.017609 | 1.0 |
| 997 | 3.286063 | 1.0 |
| 998 | 1.840125 | 1.0 |
| 999 | 2.363211 | 1.0 |
1000 rows × 2 columns
Three Kinds of Likelihoods¶
HSSM supports three kinds of likelihood functions, selected via the loglik_kind parameter of the HSSM class:

- `"analytical"`: These likelihoods are usually closed-form solutions to the actual likelihoods. For example, for `ddm` models, HSSM provides the analytical likelihood from Navarro & Fuss (2009). HSSM expects these functions to be written with `pytensor`, so that they can be compiled by `pytensor` as part of a computational graph. As such, they are also differentiable.
- `"approx_differentiable"`: These likelihoods are usually approximations of the actual likelihood functions with neural networks. These networks can be trained with any popular deep learning framework, such as `PyTorch` or `TensorFlow`, and saved as `onnx` files. HSSM can load the `onnx` files and translate the network into either the `jax` or the `pytensor` backend (see the sketch after this list). The backend option can be supplied via the `"backend"` field in `model_config`; this field of `model_config` is not applicable to other kinds of likelihoods.
    - the `jax` backend: The basic computations in the likelihood are JAX operations (valid `JAX` functions), which are wrapped in a `pytensor` `Op`. When sampling using the default NUTS sampler in `PyMC`, this option might be slightly faster but is more prone to compatibility issues, especially during parallel sampling, due to how `JAX` handles parallelism. The preferred usage of this backend is together with the `nuts_numpyro` and `blackjax` (experimental) samplers, where JAX support is native and performance is optimized.
    - the `pytensor` backend: The basic computations in the likelihood are `pytensor` operations (valid `pytensor` functions). When sampling using the default NUTS sampler in `PyMC`, this option allows for maximum compatibility. Not recommended when using `JAX`-based samplers.
- `"blackbox"`: Use this option for "black box" likelihoods that are not differentiable. These likelihoods are typically `Callable`s in Python that cannot be directly integrated into a `pytensor` computational graph. `hssm` will wrap these `Callable`s in a `pytensor` `Op` so they can be part of the graph.
Default vs. Custom Likelihoods¶
HSSM provides many default likelihood functions out-of-the-box. The supported likelihoods are:
- For the `analytical` kind: `ddm` and `ddm_sdv` models.
- For the `approx_differentiable` kind: `ddm`, `ddm_sdv`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4`, and `ddm_seq2_no_bias` models.
- For the `blackbox` kind: `ddm`, `ddm_sdv`, and `full_ddm` models.
For a model that has default likelihood functions, only the model argument needs to be specified.
ddm_model_analytical = hssm.HSSM(data, model="ddm")
Model initialized successfully.
ddm_model_analytical
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: analytical
Observations: 1000
Parameters:
v:
Prior: Normal(mu: 0.0, sigma: 2.0)
Explicit bounds: (-inf, inf)
a:
Prior: HalfNormal(sigma: 2.0)
Explicit bounds: (0.0, inf)
z:
Prior: Uniform(lower: 0.0, upper: 1.0)
Explicit bounds: (0.0, 1.0)
t:
Prior: HalfNormal(sigma: 2.0)
Explicit bounds: (0.0, inf)
Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
ddm_model_analytical.graph()
The ddm and ddm_sdv models have both analytical and approx_differentiable likelihoods. If loglik_kind is not specified, the analytical likelihood is used. We can, however, specify the loglik_kind argument directly for a given model, and if a default likelihood of that kind is available, it will be used automatically.
ddm_model_approx_diff = hssm.HSSM(
data, model="ddm", loglik_kind="approx_differentiable"
)
Model initialized successfully.
While the model graph looks the same:
ddm_model_approx_diff.graph()
We can check that a different kind of likelihood is now being used by printing the model:
ddm_model_approx_diff
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: approx_differentiable
Observations: 1000
Parameters:
v:
Prior: Uniform(lower: -3.0, upper: 3.0)
Explicit bounds: (-3.0, 3.0)
a:
Prior: Uniform(lower: 0.3, upper: 2.5)
Explicit bounds: (0.3, 2.5)
z:
Prior: Uniform(lower: 0.0, upper: 1.0)
Explicit bounds: (0.0, 1.0)
t:
Prior: HalfNormal(sigma: 2.0)
Explicit bounds: (0.0, 2.0)
Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Note how under the Likelihood rubric, it now says "approx_differentiable". Another simple way to check this is to access the loglik_kind attribute of our HSSM model.
ddm_model_approx_diff.loglik_kind
'approx_differentiable'
Overriding default likelihoods¶
Sometimes a likelihood other than the default version is preferred. In that case, you can supply a likelihood function directly to the loglik parameter. We will discuss acceptable likelihood function types in a moment.
For illustration, we load the basic analytical DDM likelihood that ships with HSSM and supply it manually to our HSSM model class.
from hssm.likelihoods.analytical import logp_ddm
ddm_model_analytical_override = hssm.HSSM(
data, model="ddm", loglik_kind="analytical", loglik=logp_ddm
)
Model initialized successfully.
HSSM automatically constructed our model with the likelihood function we provided. We can now take posterior samples as usual.
idata = ddm_model_analytical_override.sample(draws=500, tune=500, chains=2)
Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [z, a, t, v] /opt/homebrew/Cellar/python@3.11/3.11.12/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Output()
/opt/homebrew/Cellar/python@3.11/3.11.12/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 9 seconds. We recommend running at least 4 chains for robust computation of convergence diagnostics 100%|██████████| 1000/1000 [00:00<00:00, 2803.20it/s]
ddm_model_analytical_override._inference_obj
-
<xarray.Dataset> Size: 36kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: z (chain, draw) float64 8kB 0.4924 0.5124 0.5089 ... 0.5106 0.5249 a (chain, draw) float64 8kB 1.449 1.497 1.541 ... 1.475 1.461 1.431 v (chain, draw) float64 8kB 0.5545 0.5242 0.5237 ... 0.566 0.436 t (chain, draw) float64 8kB 0.3371 0.2973 0.273 ... 0.3674 0.3316 Attributes: created_at: 2025-09-27T00:09:25.419589+00:00 arviz_version: 0.22.0 inference_library: pymc inference_library_version: 5.25.1 sampling_time: 8.755553007125854 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0 -
<xarray.Dataset> Size: 8MB Dimensions: (chain: 2, draw: 500, __obs__: 1000) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 8MB -0.8813 -1.571 ... -1.814 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0 -
<xarray.Dataset> Size: 134kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/18) step_size (chain, draw) float64 8kB 0.5059 0.5059 ... 0.5671 n_steps (chain, draw) float64 8kB 7.0 7.0 3.0 ... 7.0 3.0 3.0 energy (chain, draw) float64 8kB 1.983e+03 ... 1.987e+03 reached_max_treedepth (chain, draw) bool 1kB False False ... False False max_energy_error (chain, draw) float64 8kB 0.5403 0.4229 ... -0.5541 tree_depth (chain, draw) int64 8kB 3 3 2 2 2 3 3 ... 2 3 3 3 2 2 ... ... lp (chain, draw) float64 8kB -1.981e+03 ... -1.984e+03 perf_counter_start (chain, draw) float64 8kB 1.435e+06 ... 1.435e+06 step_size_bar (chain, draw) float64 8kB 0.625 0.625 ... 0.6339 perf_counter_diff (chain, draw) float64 8kB 0.006602 ... 0.003229 divergences (chain, draw) int64 8kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 diverging (chain, draw) bool 1kB False False ... False False Attributes: created_at: 2025-09-27T00:09:25.454929+00:00 arviz_version: 0.22.0 inference_library: pymc inference_library_version: 5.25.1 sampling_time: 8.755553007125854 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0 -
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-09-27T00:09:25.458147+00:00 arviz_version: 0.22.0 inference_library: pymc inference_library_version: 5.25.1 modeling_interface: bambi modeling_interface_version: 0.15.0
Using Custom Likelihoods¶
If you are specifying a model with a kind of likelihood that's not included in the list above, then HSSM considers that you are using a custom model with custom likelihoods. In this case, you will need to specify your entire model. Below is the procedure to specify a custom model:
1. Specify a `model` string. It can be any string that helps identify the model, but if it is not one of the model strings supported in the `ssm_simulators` package (see the full list here), you will need to supply a `RandomVariable` class to `model_config`, as detailed below. Otherwise, you can still perform MCMC sampling, but sampling from the posterior predictive distribution will raise a ValueError.
2. Specify a `model_config`. It typically contains the following fields:
    - `"list_params"`: Required if your `model` string is not one of `ddm`, `ddm_sdv`, `full_ddm`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4`, and `ddm_seq2_no_bias`. A list of `str` indicating the parameters of the model. The order in which the parameters are specified in this list is important. Values for each parameter will be passed to the likelihood function in this order.
    - `"backend"`: Optional. Only used when `loglik_kind` is `approx_differentiable` and an onnx file is supplied for the likelihood approximation network (LAN). Valid values are `"jax"` or `"pytensor"`. It determines whether the LAN in ONNX should be converted to `"jax"` or `"pytensor"`. If not provided, `jax` will be used for maximum performance.
    - `"default_priors"`: Optional. A `dict` indicating the default priors for each parameter.
    - `"bounds"`: Optional. A `dict` of `(lower, upper)` tuples indicating the acceptable boundaries for each parameter. In the case of LANs, these bounds are the training boundaries.
    - `"rv"`: Optional. Can be a `RandomVariable` class containing the user's own `rng_fn` function for sampling from the distribution that the user is supplying. If not supplied, HSSM will automatically generate a `RandomVariable` using the simulator identified by `model` from the `ssm_simulators` package. If `model` is not supported in `ssm_simulators`, a warning will be raised letting the user know that sampling from the `RandomVariable` will result in errors.
3. Specify `loglik` and `loglik_kind`.
4. Specify parameter priors in `include`.
NOTE:
`default_priors` and `bounds` in `model_config` specify the default priors and bounds for the model. Actual priors and bounds should be provided via the `include` list and will override these defaults (see the sketch after the examples below).
Below are a few examples:
# An angle model with an analytical likelihood function.
# Because `model` is known, no `list_params` needs to be provided.
custom_angle_model = hssm.HSSM(
data,
model="angle",
model_config={
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 3.0),
"z": (0.1, 0.9),
"t": (0.001, 2.0),
"theta": (-0.1, 1.3),
} # bounds will be used to create Uniform (uninformative) priors by default
# if priors are not supplied in `include`.
},
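# NOTE: `custom_angle_logp` below is a placeholder for a user-defined
# analytical likelihood (e.g. a function written with pytensor operations);
# it is not defined in this notebook.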
loglik=custom_angle_logp,
loglik_kind="analytical",
)
# A fully customized model with a custom likelihood function.
# Because `model` is not known, a `list_params` needs to be provided.
my_custom_model = hssm.HSSM(
data,
model="my_model",
model_config={
"list_params": ["v", "a", "z", "t", "theta"],
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 3.0),
"z": (0.1, 0.9),
"t": (0.001, 2.0),
"theta": (-0.1, 1.3),
},  # bounds will be used to create Uniform (uninformative) priors by default
# if priors are not supplied in `include`.
"default_priors": ...,  # usually no need to supply this.
"rv": MyRV,  # provide a RandomVariable class if posterior predictive sampling is needed.
},
loglik="my_model.onnx", # Can be a path to an onnx model.
loglik_kind="approx_differentiable",
include=[...]
)
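To make the note above concrete, here is a minimal sketch (with a hypothetical prior choice) of how an entry in `include` overrides the uninformative prior implied by the bounds in `model_config`:

# Sketch: a prior supplied in `include` overrides the Uniform prior implied
# by `model_config["bounds"]` for that parameter.
angle_model_with_prior = hssm.HSSM(
    data,
    model="angle",
    model_config={"bounds": {"theta": (-0.1, 1.3)}},
    include=[
        hssm.Param(
            "theta",
            prior={"name": "Normal", "mu": 0.2, "sigma": 0.5},
        )
    ],
)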
Supported types of likelihoods¶
When default likelihoods are not used, custom likelihoods are supplied via the loglik argument to HSSM. Depending on which loglik_kind is used, loglik supports different types of Python objects:

- `Type[pm.Distribution]`: Supports all `loglik_kind`s. You can pass any subclass of `pm.Distribution` to `loglik`, representing the underlying top-level distribution of the model. It has to be the class itself instead of an instance of the class.
- `Op`: Supports all `loglik_kind`s. You can pass a `pytensor` `Op` (an instance instead of the class itself), in which case HSSM will create a top-level `pm.Distribution` that calls this `Op` in its `logp` function to compute the log-likelihood.
- `Callable`: Supports all `loglik_kind`s. You can use any Python callable as well. When `loglik_kind` is `blackbox`, HSSM will wrap it in a `pytensor` `Op` and create a top-level `pm.Distribution` with it. Otherwise, HSSM will assume that this Python callable is written with `pytensor` and is thus differentiable.
- `str` or `Pathlike`: Only supported when `loglik_kind` is `approx_differentiable`. The `str` or `Pathlike` indicates the path to an `onnx` file representing the neural network for likelihood approximation. In the case of `str`, if the path is not found locally, HSSM will also look for the `onnx` file in the official HuggingFace repo. An error is thrown when the `onnx` file is not found.
Note
When using Op and Callable types of likelihoods, they need to have this signature:
def logp_fn(data, *):
...
where data is a 2-column numpy array and * represents named arguments in the order of the parameters in list_params. For example, if a model's list_params is ["v", "a", "z", "t"], then the Op or Callable should at least look like this:
def logp_fn(data, v, a, z, t):
...
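To make the signature requirement concrete, below is a toy, purely illustrative sketch of a differentiable Callable: the data comes first, followed by the parameters in `list_params` order. The density is made up just to show the mechanics; it is not a real SSM likelihood.

import pytensor.tensor as pt

# Toy differentiable likelihood written with pytensor operations.
# Returns an elementwise log-likelihood, one value per trial.
def toy_logp(data, v, a, z, t):
    rt = data[:, 0]
    mu = a * v + t  # not a real DDM quantity, just a placeholder
    sigma = z
    return -0.5 * ((rt - mu) / sigma) ** 2 - pt.log(sigma) - 0.5 * pt.log(2 * np.pi)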
Using blackbox likelihoods¶
HSSM also supports "black box" likelihood functions, which are assumed not to be differentiable. When loglik_kind is blackbox, HSSM will by default switch to an MCMC sampler that does not require gradients. Below is an example showing how to use a blackbox likelihood function. We use a log-likelihood function for the DDM written in Cython to show that any function or computation can be used inside the likelihood, as long as the function itself has the signature defined above. See here for the function definition.
import bambi as bmb
import hddm_wfpt
# Define a function with fun(data, *) signature
def my_blackbox_loglik(data, v, a, z, t, err=1e-8):
"""Create a blackbox log-likelihood function for the DDM model.
Note the function signature: the first argument must be the data, and the
remaining arguments are the parameters to be estimated. The function must
return the log-likelihood of the data given the parameters.
Parameters
----------
data : np.ndarray
A 2D array with columns for the RT and choice of each trial.
"""
data = data[:, 0] * data[:, 1]
data_nrows = data.shape[0]
# Our function expects inputs as float64, but they are not guaranteed to
# come in as such --> we type convert
return hddm_wfpt.wfpt.wiener_logp_array(
np.float64(data),
(np.ones(data_nrows) * v).astype(np.float64),
np.ones(data_nrows) * 0,
(np.ones(data_nrows) * 2 * a).astype(np.float64),
(np.ones(data_nrows) * z).astype(np.float64),
np.ones(data_nrows) * 0,
(np.ones(data_nrows) * t).astype(np.float64),
np.ones(data_nrows) * 0,
err,
1,
)
# Create the model with my_blackbox_loglik
model = hssm.HSSM(
data=data,
model="ddm",
loglik=my_blackbox_loglik,
loglik_kind="blackbox",
model_config={
"bounds": {
"v": (-10.0, 10.0),
"a": (0.0, 4.0),
"z": (0.0, 1.0),
"t": (0.0, 2.0),
}
},
t=bmb.Prior("Uniform", lower=0.0, upper=2.0),
)
Model initialized successfully.
model.graph()
sample = model.sample()
Using default initvals.
Multiprocess sampling (4 chains in 4 jobs) CompoundStep >Slice: [z] >Slice: [a] >Slice: [t] >Slice: [v] /opt/homebrew/Cellar/python@3.11/3.11.12/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Output()
/opt/homebrew/Cellar/python@3.11/3.11.12/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = os.fork()
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 25 seconds. 100%|██████████| 4000/4000 [00:01<00:00, 2103.47it/s]
model.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| z | 0.510 | 0.013 | 0.485 | 0.534 | 0.000 | 0.0 | 1182.0 | 2018.0 | 1.0 |
| a | 1.460 | 0.027 | 1.410 | 1.510 | 0.001 | 0.0 | 1812.0 | 2348.0 | 1.0 |
| t | 0.334 | 0.021 | 0.297 | 0.375 | 0.001 | 0.0 | 1363.0 | 2095.0 | 1.0 |
| v | 0.515 | 0.033 | 0.452 | 0.574 | 0.001 | 0.0 | 1368.0 | 2141.0 | 1.0 |
model.plot_trace()
Using the blackbox interface provides maximum flexibility on the user side. We hope you will find it useful!