Likelihood functions in HSSM explained
One of the design goals of HSSM is flexibility. It is built from the ground up to support many types of likelihood functions out of the box. For more tailored applications, HSSM provides a convenient toolbox that lets users create their own likelihood functions, which integrate seamlessly with the HSSM class and make for a highly customizable analysis environment. This notebook explains how to use different types of likelihoods with HSSM.
Colab Instructions
If you would like to run this tutorial on Google Colab, please click this link.
Once you are in Colab, follow the installation instructions below and then restart your runtime.
Just uncomment the code in the next code cell and run it!
NOTE:
You may want to switch your runtime to have a GPU or TPU. To do so, go to Runtime > Change runtime type and select the desired hardware accelerator.
Note that if you switch your runtime, you have to follow the installation instructions again.
# If running this on Colab, please uncomment the next line
# !pip install hssm
Load Modules
import numpy as np
import pytensor
import hssm
pytensor.config.floatX = "float32"
help(hssm.simulate_data)
Help on function simulate_data in module hssm.simulator:

simulate_data(model: str, theta: dict[str, ArrayLike] | list[float] | ArrayLike, size: int, random_state: int | None = None, output_df: bool = True, **kwargs) -> numpy.ndarray | pandas.core.frame.DataFrame
    Sample simulated data from specified distributions.

    Parameters
    ----------
    model
        A model name that must be supported in `ssm_simulators`. For a detailed list of supported models, please see all fields in the `model_config` dict [here](https://github.com/AlexanderFengler/ssm-simulators/blob/e09eb2528d885c7b3340516597849fff4d9a5bf8/ssms/config/config.py#L6)
    theta
        Parameters of the process. Can be supplied as dictionary with parameter names as key and np.array or float as values. Can also be supplied as a list or 1D-array, however in this case the order of parameters is important and must match specifications [here](https://github.com/AlexanderFengler/ssm-simulators/blob/e09eb2528d885c7b3340516597849fff4d9a5bf8/ssms/config/config.py#L6). Parameters can be specified 'trial-wise', by supplying 1D arrays of shape `size` to the dictionary, or by supplying a 2D array of shape `(size, n_parameters)` directly.
    size
        The size of the data to be simulated. If `theta` is a 2D ArrayLike, this parameter indicates the size of data to be simulated for each trial.
    random_state : optional
        A random seed for reproducibility.
    output_df : optional
        If True, outputs a DataFrame with column names "rt", "response". Otherwise a 2-column numpy array, by default True.
    kwargs : optional
        Other arguments passed to ssms.basic_simulators.simulator.

    Returns
    -------
    np.ndarray | pd.DataFrame
        An array or DataFrame with simulated data.
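The theta description above allows two trial-wise layouts. The NumPy sketch below builds both for a ddm model. The column order v, a, z, t is an assumption based on the ddm parameter list used elsewhere in this notebook, so check it against the ssm-simulators config linked above before relying on it. The simulate_data calls are left as comments so that the snippet runs without HSSM.

```python
import numpy as np

n_trials = 500
rng = np.random.default_rng(42)

# Layout 1: a dict mapping each parameter name to one value per trial
theta_trialwise = {
    "v": rng.normal(0.5, 0.2, size=n_trials),  # drift rate varies across trials
    "a": np.full(n_trials, 1.5),
    "z": np.full(n_trials, 0.5),
    "t": np.full(n_trials, 0.3),
}

# Layout 2: a 2D array of shape (n_trials, n_parameters).
# Assumed column order for ddm: v, a, z, t.
param_order = ["v", "a", "z", "t"]
theta_2d = np.column_stack([theta_trialwise[name] for name in param_order])

# With HSSM installed, either layout could be passed to the simulator:
# data = hssm.simulate_data(model="ddm", theta=theta_trialwise, size=n_trials)
# data = hssm.simulate_data(model="ddm", theta=theta_2d, size=1)  # `size` draws per trial
```

The dict form is easier to read. The 2D form is what you would build when parameters come from a matrix, for example posterior draws.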
Pre-simulate some data
# Simulate some data
data = hssm.simulate_data(
model="ddm", theta=dict(v=0.5, a=1.5, z=0.5, t=0.3), size=1000
)
data
| | rt | response |
|---|---|---|
| 0 | 1.157311 | 1.0 |
| 1 | 1.297562 | -1.0 |
| 2 | 1.033477 | -1.0 |
| 3 | 4.234457 | 1.0 |
| 4 | 1.737934 | 1.0 |
| ... | ... | ... |
| 995 | 1.558967 | 1.0 |
| 996 | 6.683352 | -1.0 |
| 997 | 3.063884 | 1.0 |
| 998 | 0.858070 | 1.0 |
| 999 | 2.328675 | -1.0 |
1000 rows × 2 columns
Three Kinds of Likelihoods
HSSM supports three kinds of likelihood functions, selected via the loglik_kind parameter to the HSSM class:
- "analytical": These likelihoods are usually closed-form solutions to the actual likelihoods. For example, for ddm models, HSSM provides the analytical likelihoods in Navarro & Fuss (2009). HSSM expects these functions to be written with pytensor, so that pytensor can compile them as part of a computational graph. As such, they are also differentiable.
- "approx_differentiable": These likelihoods are usually neural-network approximations of the actual likelihood functions. These networks can be trained with any popular deep learning framework, such as PyTorch and TensorFlow, and saved as onnx files. HSSM can load the onnx files and translate the neural network with either the jax or the pytensor backend, explained below. The backend option can be supplied via the "backend" field in model_config. This field of model_config is not applicable to other kinds of likelihoods.
  - the jax backend: The basic computations in the likelihood are JAX operations (valid JAX functions), which are wrapped in a pytensor Op. When sampling with the default NUTS sampler in PyMC, this option might be slightly faster but is more prone to compatibility issues, especially during parallel sampling, due to how JAX handles parallelism. The preferred usage of this backend is with the nuts_numpyro and nuts_blackjax (experimental) samplers, where JAX support is native and performance is optimized.
  - the pytensor backend: The basic computations in the likelihood are pytensor operations (valid pytensor functions). When sampling with the default NUTS sampler in PyMC, this option allows for maximum compatibility. It is not recommended when using JAX-based samplers.
- "blackbox": Use this option for "black box" likelihoods that are not differentiable. These likelihoods are typically Python Callables that cannot be directly integrated into a pytensor computational graph. hssm will wrap these Callables in a pytensor Op so they can be part of the graph.
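For reference, here is one standard form of the Navarro & Fuss (2009) result behind the "analytical" kind. It is the large-time series for the first-passage time density at the lower boundary, with drift $v$, boundary separation $a$, relative starting point $w$, and decision time $t$ (response time minus non-decision time). Here $a$ denotes the full boundary separation, which may not match the scaling of HSSM's own a parameter. In practice the infinite sum is truncated, and Navarro & Fuss also give a small-time series that converges faster for small $t$:

$$
f(t \mid v, a, w) = \frac{\pi}{a^{2}} \exp\left(-v a w - \frac{v^{2} t}{2}\right) \sum_{k=1}^{\infty} k \exp\left(-\frac{k^{2} \pi^{2} t}{2 a^{2}}\right) \sin(k \pi w)
$$

The density at the upper boundary follows by symmetry: $f_{\text{upper}}(t \mid v, a, w) = f(t \mid -v, a, 1 - w)$. HSSM's pytensor implementation may organize the computation differently.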
Default vs. Custom Likelihoods
HSSM provides many default likelihood functions out of the box. The supported likelihoods are:
- For the analytical kind: ddm and ddm_sdv models.
- For the approx_differentiable kind: ddm, ddm_sdv, angle, levy, ornstein, weibull, race_no_bias_angle_4 and ddm_seq2_no_bias models.
- For the blackbox kind: ddm, ddm_sdv and full_ddm models.
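The list above is effectively a lookup table. The sketch below transcribes it into a plain Python dict to make the rule explicit: a model/kind pair found in the table gets a default likelihood, and anything else is treated as a custom model for which you supply loglik yourself. The two helpers, has_default_likelihood and available_kinds, are hypothetical and not HSSM functions.

```python
# Hypothetical lookup table transcribed from the list above; not an HSSM API.
DEFAULT_LIKELIHOODS = {
    "analytical": {"ddm", "ddm_sdv"},
    "approx_differentiable": {
        "ddm", "ddm_sdv", "angle", "levy", "ornstein", "weibull",
        "race_no_bias_angle_4", "ddm_seq2_no_bias",
    },
    "blackbox": {"ddm", "ddm_sdv", "full_ddm"},
}


def has_default_likelihood(model, loglik_kind):
    """True if HSSM ships a default likelihood of this kind for this model."""
    return model in DEFAULT_LIKELIHOODS.get(loglik_kind, set())


def available_kinds(model):
    """All likelihood kinds with a default for `model`, in table order."""
    return [kind for kind, models in DEFAULT_LIKELIHOODS.items() if model in models]
```

For example, available_kinds("angle") returns only "approx_differentiable", so requesting an analytical angle model means providing your own loglik.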
For a model that has default likelihood functions, only the model argument needs to be specified.
ddm_model_analytical = hssm.HSSM(data, model="ddm")
Model initialized successfully.
ddm_model_analytical
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: analytical
Observations: 1000
Parameters:
v:
    Prior: Normal(mu: 0.0, sigma: 2.0)
    Explicit bounds: (-inf, inf)
a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
t:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)
Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
ddm_model_analytical.graph()
The ddm and ddm_sdv models have both analytical and approx_differentiable likelihoods. If loglik_kind is not specified, the analytical likelihood is used. We can, however, specify the loglik_kind argument directly for a given model. If a default likelihood of that kind is available, HSSM will switch to it automatically.
ddm_model_approx_diff = hssm.HSSM(
data, model="ddm", loglik_kind="approx_differentiable"
)
Model initialized successfully.
While the model graph looks the same:
ddm_model_approx_diff.graph()
We can check that the model now uses a different kind of likelihood by printing the model string:
ddm_model_approx_diff
Hierarchical Sequential Sampling Model
Model: ddm
Response variable: rt,response
Likelihood: approx_differentiable
Observations: 1000
Parameters:
v:
    Prior: Uniform(lower: -3.0, upper: 3.0)
    Explicit bounds: (-3.0, 3.0)
a:
    Prior: Uniform(lower: 0.30000001192092896, upper: 2.5)
    Explicit bounds: (0.3, 2.5)
z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)
t:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, 2.0)
Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 20.0)
Note that the Likelihood field now reads "approx_differentiable". Another simple way to check this is to access the loglik_kind attribute of our HSSM model.
ddm_model_approx_diff.loglik_kind
'approx_differentiable'
Overriding default likelihoods
Sometimes a likelihood other than the default version is preferred. In that case, you can supply a likelihood function directly to the loglik parameter. We will discuss acceptable likelihood function types in a moment.
For illustration, we load the basic analytical DDM likelihood that ships with HSSM and supply it manually to our HSSM model class.
from hssm.likelihoods.analytical import logp_ddm
ddm_model_analytical_override = hssm.HSSM(
data, model="ddm", loglik_kind="analytical", loglik=logp_ddm
)
Model initialized successfully.
HSSM automatically constructed our model with the likelihood function we provided. We can now take posterior samples as usual.
idata = ddm_model_analytical_override.sample(draws=500, tune=500, chains=2)
Using default initvals.
Auto-assigning NUTS sampler... Initializing NUTS using adapt_diag... Multiprocess sampling (2 chains in 2 jobs) NUTS: [t, z, a, v]
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 7 seconds. We recommend running at least 4 chains for robust computation of convergence diagnostics 100%|██████████| 1000/1000 [00:00<00:00, 2612.86it/s]
ddm_model_analytical_override._inference_obj
posterior
<xarray.Dataset> Size: 20kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: t (chain, draw) float32 4kB 0.3409 0.3232 0.3109 ... 0.3639 0.3557 v (chain, draw) float32 4kB 0.61 0.5658 0.5585 ... 0.5517 0.5479 0.54 z (chain, draw) float32 4kB 0.4808 0.4753 0.471 ... 0.5016 0.5026 a (chain, draw) float32 4kB 1.483 1.493 1.486 ... 1.476 1.405 1.459 Attributes: created_at: 2024-12-25T23:42:48.582666+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2 sampling_time: 7.497994899749756 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.14.0
log_likelihood
<xarray.Dataset> Size: 8MB Dimensions: (chain: 2, draw: 500, __obs__: 1000) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 8MB -0.893 -2.625 ... -3.209 Attributes: modeling_interface: bambi modeling_interface_version: 0.14.0
sample_stats
<xarray.Dataset> Size: 126kB Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 8kB 0.9552 0.9626 ... 0.9021 diverging (chain, draw) bool 1kB False False ... False False energy (chain, draw) float64 8kB 1.995e+03 ... 1.995e+03 energy_error (chain, draw) float64 8kB 0.08009 ... -0.8272 index_in_trajectory (chain, draw) int64 8kB 4 -2 2 5 -1 ... -4 -3 4 -3 4 largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan ... ... process_time_diff (chain, draw) float64 8kB 0.007094 ... 0.007121 reached_max_treedepth (chain, draw) bool 1kB False False ... False False smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan step_size (chain, draw) float64 8kB 0.6009 0.6009 ... 0.7046 step_size_bar (chain, draw) float64 8kB 0.6293 0.6293 ... 0.6216 tree_depth (chain, draw) int64 8kB 3 3 2 4 2 3 3 ... 2 3 3 3 3 3 Attributes: created_at: 2024-12-25T23:42:48.590093+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2 sampling_time: 7.497994899749756 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.14.0
observed_data
<xarray.Dataset> Size: 16kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 8kB 1... Attributes: created_at: 2024-12-25T23:42:48.592276+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2 modeling_interface: bambi modeling_interface_version: 0.14.0
Using Custom Likelihoods
If you specify a model with a kind of likelihood that is not included in the list above, HSSM treats it as a custom model with a custom likelihood. In this case, you will need to specify your entire model. The procedure to specify a custom model is as follows:
1. Specify a model string. It can be any string that helps identify the model. If it is not one of the model strings supported in the ssm_simulators package (see full list here), you will need to supply a RandomVariable class to model_config, as detailed below. Otherwise, you can still perform MCMC sampling, but sampling from the posterior predictive distribution will raise a ValueError.
2. Specify a model_config. It typically contains the following fields:
   - "list_params": Required if your model string is not one of ddm, ddm_sdv, full_ddm, angle, levy, ornstein, weibull, race_no_bias_angle_4 or ddm_seq2_no_bias. A list of str indicating the parameters of the model. The order of the parameters in this list is important, because values for each parameter will be passed to the likelihood function in this order.
   - "backend": Optional. Only used when loglik_kind is approx_differentiable and an onnx file is supplied for the likelihood approximation network (LAN). Valid values are "jax" or "pytensor". It determines whether the LAN in ONNX should be converted to "jax" or "pytensor". If not provided, jax will be used for maximum performance.
   - "default_priors": Optional. A dict indicating the default priors for each parameter.
   - "bounds": Optional. A dict of (lower, upper) tuples indicating the acceptable boundaries for each parameter. In the case of LAN, these bounds are the training boundaries.
   - "rv": Optional. Can be a RandomVariable class containing the user's own rng_fn function for sampling from the distribution that the user is supplying. If not supplied, HSSM will automatically generate a RandomVariable using the simulator identified by model from the ssm_simulators package. If model is not supported in ssm_simulators, a warning will be raised letting the user know that sampling from the RandomVariable will result in errors.
3. Specify loglik and loglik_kind.
4. Specify parameter priors in include.
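To make the requirements above concrete, here is a hypothetical validation helper, check_model_config, written in plain Python. It is not part of HSSM, which runs its own checks and possibly does so differently. It only encodes the rules listed above:
- list_params is needed for model strings outside the known list.
- bounds should refer to declared parameters.
- backend matters only for approx_differentiable.

```python
# Hypothetical helper encoding the rules above; HSSM's own validation may differ.
KNOWN_MODELS = {
    "ddm", "ddm_sdv", "full_ddm", "angle", "levy", "ornstein", "weibull",
    "race_no_bias_angle_4", "ddm_seq2_no_bias",
}
VALID_BACKENDS = {"jax", "pytensor"}


def check_model_config(model, model_config, loglik_kind):
    """Return list_params (the order in which values reach loglik), or None
    for a known model that relies on HSSM's built-in parameter list."""
    params = model_config.get("list_params")
    if params is None:
        if model not in KNOWN_MODELS:
            raise ValueError(f"'list_params' is required for custom model {model!r}")
        return None
    bounds = model_config.get("bounds", {})
    unknown = set(bounds) - set(params)
    if unknown:
        raise ValueError(f"bounds given for undeclared parameters: {sorted(unknown)}")
    for name, (lower, upper) in bounds.items():
        if not lower < upper:
            raise ValueError(f"invalid bounds for {name!r}: ({lower}, {upper})")
    backend = model_config.get("backend")
    # "backend" is only meaningful for approx_differentiable; ignored otherwise
    if loglik_kind == "approx_differentiable" and backend is not None:
        if backend not in VALID_BACKENDS:
            raise ValueError(f"backend must be one of {sorted(VALID_BACKENDS)}")
    return list(params)


# Usage: parameter values are passed positionally in list_params order
config = {
    "list_params": ["v", "a", "z", "t", "theta"],
    "bounds": {"v": (-3.0, 3.0), "theta": (-0.1, 1.3)},
    "backend": "jax",
}
order = check_model_config("my_model", config, "approx_differentiable")
values = {"v": 0.5, "a": 1.5, "z": 0.5, "t": 0.3, "theta": 0.2}
positional_args = [values[name] for name in order]  # [0.5, 1.5, 0.5, 0.3, 0.2]
```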
NOTE:
default_priors and bounds in model_config specify default priors and bounds for the model. Actual priors and bounds should be provided via the include list and will override these defaults.
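One way to picture the precedence described in this note is the hypothetical resolve_priors function below. It is a mental model, not HSSM's internal logic. It assumes that include entries are dicts with "name" and "prior" keys. It also assumes that default_priors take precedence over the Uniform priors derived from bounds.

```python
# Hypothetical illustration of prior precedence; not HSSM's internal code.
def resolve_priors(model_config, include=()):
    """Return {param: prior_spec}; entries in `include` win over defaults."""
    priors = {}
    # 1. Bounds become Uniform (uninformative) priors by default
    for name, (lower, upper) in model_config.get("bounds", {}).items():
        priors[name] = {"name": "Uniform", "lower": lower, "upper": upper}
    # 2. Assumption: default_priors replace the bound-derived defaults
    priors.update(model_config.get("default_priors", {}))
    # 3. Anything in `include` overrides both (assumed {"name", "prior"} format)
    for spec in include:
        priors[spec["name"]] = spec["prior"]
    return priors


config = {"bounds": {"v": (-3.0, 3.0), "a": (0.3, 3.0)}}
include = [{"name": "v", "prior": {"name": "Normal", "mu": 0.0, "sigma": 1.0}}]
resolved = resolve_priors(config, include)
```

Here v ends up with the Normal prior from include, while a keeps the Uniform prior built from its bounds.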
Below are a few examples:
# An angle model with a user-supplied analytical likelihood function.
# `custom_angle_logp` is a placeholder for your own pytensor-based
# log-likelihood with the signature described below.
# Because `model` is known, no `list_params` needs to be provided.
custom_angle_model = hssm.HSSM(
    data,
    model="angle",
    model_config={
        # bounds will be used to create Uniform (uninformative) priors by default
        # if priors are not supplied in `include`.
        "bounds": {
            "v": (-3.0, 3.0),
            "a": (0.3, 3.0),
            "z": (0.1, 0.9),
            "t": (0.001, 2.0),
            "theta": (-0.1, 1.3),
        },
    },
    loglik=custom_angle_logp,
    loglik_kind="analytical",
)
# A fully customized model with a custom likelihood function.
# Because `model` is not known, `list_params` needs to be provided.
# `MyRV` and "my_model.onnx" are placeholders for your own
# RandomVariable class and ONNX file.
my_custom_model = hssm.HSSM(
    data,
    model="my_model",
    model_config={
        "list_params": ["v", "a", "z", "t", "theta"],
        # bounds will be used to create Uniform (uninformative) priors by default
        # if priors are not supplied in `include`.
        "bounds": {
            "v": (-3.0, 3.0),
            "a": (0.3, 3.0),
            "z": (0.1, 0.9),
            "t": (0.001, 2.0),
            "theta": (-0.1, 1.3),
        },
        # "default_priors": {...},  # usually no need to supply this.
        "rv": MyRV,  # provide a RandomVariable class if posterior predictive sampling is needed.
    },
    loglik="my_model.onnx",  # Can be a path to an onnx model.
    loglik_kind="approx_differentiable",
    include=[...],  # parameter priors go here
)
Supported types of likelihoods
When default likelihoods are not used, custom likelihoods are supplied via the loglik argument to HSSM. Depending on which loglik_kind is used, loglik supports different types of Python objects:
- Type[pm.Distribution]: Supports all loglik_kinds. You can pass to loglik any subclass of pm.Distribution that represents the underlying top-level distribution of the model. It has to be the class itself, not an instance of the class.
- Op: Supports all loglik_kinds. You can pass a pytensor Op (an instance, not the class itself). In this case HSSM will create a top-level pm.Distribution, which calls this Op in its logp function to compute the log-likelihood.
- Callable: Supports all loglik_kinds. You can use any Python Callable as well. When loglik_kind is blackbox, HSSM will wrap it in a pytensor Op and create a top-level pm.Distribution with it. Otherwise, HSSM will assume that this Python callable is created with pytensor and is thus differentiable.
- str or PathLike: Only supported when loglik_kind is approx_differentiable. The str or PathLike indicates the path to an onnx file that represents the neural network for likelihood approximation. In the case of str, if the path is not found locally, HSSM will also look for the onnx file in the official HuggingFace repo. An error is thrown when the onnx file is not found.
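The rules above can be summarized as a small dispatch function. The loglik_category function below is hypothetical and simplified. It cannot tell a pytensor Op apart from an ordinary callable without importing pytensor, and it does not check that a class actually subclasses pm.Distribution.

```python
import os


def loglik_category(loglik, loglik_kind):
    """Classify `loglik` following the rules above (hypothetical, simplified)."""
    # Paths to onnx files are only valid for approx_differentiable
    if isinstance(loglik, (str, os.PathLike)):
        if loglik_kind != "approx_differentiable":
            raise ValueError("onnx paths require loglik_kind='approx_differentiable'")
        return "onnx_path"
    # Classes are checked before callables, since classes are callable too.
    # In HSSM this should be a pm.Distribution subclass (not verified here).
    if isinstance(loglik, type):
        return "distribution_class"
    # Op instances are also callable; this sketch does not distinguish them
    if callable(loglik):
        if loglik_kind == "blackbox":
            return "callable_wrapped_in_op"
        return "differentiable_callable"  # assumed to be built from pytensor
    raise TypeError(f"unsupported loglik type: {type(loglik).__name__}")
```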
Note
When using Op and Callable types of likelihoods, they need to have this signature:
def logp_fn(data, *args):
    ...
where data is a 2-column numpy array and *args stands for the model parameters, passed in the order in which they appear in list_params. For example, if a model's list_params is ["v", "a", "z", "t"], then the Op or Callable should at least look like this:
def logp_fn(data, v, a, z, t):
    ...
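Following the signature above, here is a NumPy sketch of a callable logp_fn(data, v, a, z, t) for the DDM, built on the large-time Navarro & Fuss (2009) series. It is an illustration, not HSSM's logp_ddm, and it makes these assumptions:
- The data columns are rt and response, with 1 for the upper boundary and -1 for the lower.
- z is the relative starting point.
- a is half the boundary separation, mirroring the 2 * a conversion in the blackbox example below.

Because it is plain NumPy rather than pytensor, it could only be used with loglik_kind="blackbox". The truncated series also converges slowly for very small decision times, where a production version would switch to the small-time series.

```python
import numpy as np


def ddm_lower_density(tt, v, a, w, n_terms=100):
    """Large-time series (Navarro & Fuss, 2009) for the lower-boundary density.

    tt: decision times (> 0); v: drift; a: boundary separation;
    w: relative starting point in (0, 1). Arrays broadcast per trial.
    """
    # Column vector of series indices so the sum broadcasts over trials
    k = np.arange(1, n_terms + 1)[:, None]
    series = np.sum(
        k * np.exp(-(k**2) * np.pi**2 * tt / (2 * a**2)) * np.sin(k * np.pi * w),
        axis=0,
    )
    return np.pi / a**2 * np.exp(-v * a * w - v**2 * tt / 2) * series


def logp_fn(data, v, a, z, t, n_terms=100, eps=1e-29):
    """Per-trial DDM log-likelihood with the fun(data, v, a, z, t) signature."""
    data = np.asarray(data, dtype=np.float64)
    rt, response = data[:, 0], data[:, 1]
    ones = np.ones_like(rt)
    tt = rt - t * ones  # decision time
    separation = 2.0 * a * ones  # assumption: a is half the boundary separation
    upper = response > 0
    # Upper-boundary density = lower-boundary density with v -> -v, z -> 1 - z
    v_eff = np.where(upper, -v, v) * ones
    w_eff = np.where(upper, 1.0 - z, z) * ones
    density = np.zeros_like(rt)
    valid = tt > 0  # RTs at or below the non-decision time keep zero density
    density[valid] = ddm_lower_density(
        tt[valid], v_eff[valid], separation[valid], w_eff[valid], n_terms
    )
    # Clip before taking logs: truncation can leave tiny negative values
    return np.log(np.maximum(density, eps))
```

A quick sanity check is that the densities at both boundaries, integrated over decision time, should sum to about 1.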
Using blackbox likelihoods
HSSM also supports "black box" likelihood functions, which are assumed not to be differentiable. When loglik_kind is blackbox, HSSM will by default switch to an MCMC sampler that does not use gradients. Below is an example showing how to use a blackbox likelihood function. We use a log-likelihood function for ddm written in Cython to show that you can use any function or computation inside this function, as long as the function itself has the signature defined above. See here for the function definition.
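When the blackbox model is sampled, PyMC assigns its Slice step method to each parameter (see the "CompoundStep >Slice" lines in the sampling output). The sketch below is a minimal univariate slice sampler with stepping out and shrinkage, following Neal (2003). It illustrates why such a sampler works with a black box: it only ever evaluates the log density and never its gradient. It is a simplified illustration, not PyMC's implementation.

```python
import numpy as np


def slice_sample(logp, x0, n_samples, width=1.0, max_steps=50, seed=None):
    """Univariate slice sampler with stepping out and shrinkage (Neal, 2003).

    Only evaluations of `logp` (an unnormalized log density) are needed.
    """
    rng = np.random.default_rng(seed)
    x = float(x0)
    samples = np.empty(n_samples)
    for i in range(n_samples):
        # 1. Slice height in log space: log(u * p(x)) with u ~ Uniform(0, 1)
        log_y = logp(x) + np.log(rng.uniform())
        # 2. Step out: randomly place an interval of size `width` around x, then
        #    widen it; the random split of the step budget keeps the chain reversible
        left = x - width * rng.uniform()
        right = left + width
        j = int(np.floor(max_steps * rng.uniform()))
        k = (max_steps - 1) - j
        while j > 0 and logp(left) > log_y:
            left -= width
            j -= 1
        while k > 0 and logp(right) > log_y:
            right += width
            k -= 1
        # 3. Shrink: propose uniformly; on rejection, pull in the endpoint
        #    on the proposal's side of x
        while True:
            proposal = rng.uniform(left, right)
            if logp(proposal) > log_y:
                x = proposal
                break
            if proposal < x:
                left = proposal
            else:
                right = proposal
        samples[i] = x
    return samples


# Usage: sample a standard normal using only its unnormalized log density
draws = slice_sample(lambda x: -0.5 * x**2, x0=0.0, n_samples=5000, seed=0)
```

PyMC's CompoundStep applies a step like this to one parameter at a time, which matches the per-parameter Slice lines in the sampling output.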
import bambi as bmb
import hddm_wfpt
# Define a function with the fun(data, *args) signature
def my_blackbox_loglik(data, v, a, z, t, err=1e-8):
    """Create a blackbox log-likelihood function for the DDM model.

    Note the function signature: the first argument must be the data, and the
    remaining arguments are the parameters to be estimated. The function must
    return the log-likelihood of the data given the parameters.

    Parameters
    ----------
    data : np.ndarray
        A 2D array with columns for the RT and choice of each trial.
    """
    # Encode the choice in the sign of the RT
    signed_rt = data[:, 0] * data[:, 1]
    n_rows = signed_rt.shape[0]
    # Our function expects inputs as float64, but they are not guaranteed to
    # come in as such --> we type convert. Each parameter is broadcast to one
    # value per trial; the variability terms passed as zeros are switched off.
    return hddm_wfpt.wfpt.wiener_logp_array(
        np.float64(signed_rt),
        (np.ones(n_rows) * v).astype(np.float64),
        np.zeros(n_rows),
        (np.ones(n_rows) * 2 * a).astype(np.float64),  # a is doubled for hddm_wfpt
        (np.ones(n_rows) * z).astype(np.float64),
        np.zeros(n_rows),
        (np.ones(n_rows) * t).astype(np.float64),
        np.zeros(n_rows),
        err,
        1,
    )
# Create the model with my_blackbox_loglik
model = hssm.HSSM(
data=data,
model="ddm",
loglik=my_blackbox_loglik,
loglik_kind="blackbox",
model_config={
"bounds": {
"v": (-10.0, 10.0),
"a": (0.0, 4.0),
"z": (0.0, 1.0),
"t": (0.0, 2.0),
}
},
t=bmb.Prior("Uniform", lower=0.0, upper=2.0),
)
Model initialized successfully.
model.graph()
sample = model.sample()
Using default initvals.
Multiprocess sampling (4 chains in 4 jobs) CompoundStep >Slice: [t] >Slice: [z] >Slice: [a] >Slice: [v]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 24 seconds. 100%|██████████| 4000/4000 [00:01<00:00, 2112.50it/s]
model.summary()
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| t | 0.360 | 0.020 | 0.323 | 0.397 | 0.001 | 0.000 | 1055.0 | 1789.0 | 1.0 |
| v | 0.527 | 0.034 | 0.460 | 0.589 | 0.001 | 0.001 | 1285.0 | 2216.0 | 1.0 |
| z | 0.498 | 0.013 | 0.473 | 0.523 | 0.000 | 0.000 | 1058.0 | 1626.0 | 1.0 |
| a | 1.445 | 0.026 | 1.394 | 1.495 | 0.001 | 0.000 | 1577.0 | 1922.0 | 1.0 |
model.plot_trace()
Using the blackbox interface provides maximum flexibility on the user side. We hope you will find it useful!