Construct Custom Models from simulators and JAX callables¶
import os
import pickle
import arviz as az
import jax.numpy as jnp
import lanfactory as lf
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import hssm
from hssm.config import ModelConfig
Simulate Data¶
As a preamble, we will simulate a simple dataset from the DDM model to use throughout the example.
# simulate some data from the model
obs_ddm = hssm.simulate_data(
theta=dict(v=1.0, a=1.5, t=0.3, z=0.5), model="ddm", size=1000
)
Construct PyMC distribution from simulator and JAX callable¶
Create JAX Log-likelihood Callable¶
What we need is a callable that takes in a matrix that stacks model parameters and data trial-wise, hence with input dimensions $\text{trials} \times (\text{dim}_{\text{parameters}} + \text{dim}_{\text{data}})$. The function should return a vector of length $\text{trials}$ that contains the log-likelihood for each trial.
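For orientation, here is a minimal sketch of a callable with this signature. It is a hypothetical toy Gaussian density on rt, not a real SSM likelihood; the pretrained network we load below plays exactly this role.
import jax
import jax.numpy as jnp
# Toy example of the required signature: (n_trials, n_params + n_data) -> (n_trials,)
@jax.jit
def toy_loglik(inputs):
    v, a, z, t = inputs[:, 0], inputs[:, 1], inputs[:, 2], inputs[:, 3]  # parameter columns
    rt, choice = inputs[:, 4], inputs[:, 5]  # data columns (z and choice unused in this toy)
    mu = t + a / jnp.maximum(jnp.abs(v), 1e-3)  # toy "predicted" rt
    return -0.5 * (rt - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi)  # one log-likelihood per trial
print(toy_loglik(jnp.zeros((10, 6))).shape)  # (10,)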
If we have a JAX function with this signature, we will be able to create a valid PyMC distribution from helper functions provided by hssm.
In this particular example, we will reinstantiate a pretrained LAN via utilities from the lanfactory package. This is not necessary, but illustrates how to make use of the broader ecosystem around HSSM.
# Loaded Net
jax_infer = lf.trainers.MLPJaxFactory(
network_config=pickle.load(
open(
os.path.join(
"data", "jax_models", "ddm", "ddm_jax_lan_network_config.pickle"
),
"rb",
)
),
train=False,
)
jax_logp, _ = jax_infer.make_forward_partial(
seed=42,
input_dim=4 + 2, # n-parameters (v,a,z,t) + n-data (rts and choices)
state=os.path.join("data", "jax_models", "ddm", "ddm_jax_lan_train_state.jax"),
add_jitted=True,
)
passing through identity
Checking the signature of the JAX callable¶
We can test the signature of the JAX callable by passing in a batch of inputs and checking the output shape.
# Testing the signature of the JAX function 1
n_trials = 10
jax_logp(np.tile(np.array([1.0, 1.5, 0.5, 0.3, 1.6, 1.0]), (n_trials, 1)))
passing through identity
Array([[-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026],
       [-0.95924026]], dtype=float64)
# Testing the signature of the JAX function 2
n_dim_model_parameters = 4
n_dim_data = 2
in_ = jnp.zeros((n_trials, n_dim_model_parameters + n_dim_data))
out = jax_logp(in_)
print(out.shape)
passing through identity
(10, 1)
Decorate and wrap a simulator¶
The simulator signature expected by hssm is the following:
A simulator is a callable that takes in a matrix that stacks model parameters trial-wise, hence with input dimensions $\text{trials} \times \text{dim}_{\text{parameters}}$. The function should return a matrix of shape $\text{trials} \times \text{dim}_{\text{data}}$.
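To make this contract concrete, here is a minimal hypothetical simulator obeying that signature; it is a toy generator for illustration only, not a real DDM simulator.
# Toy simulator: (n_trials, n_params) -> (n_trials, 2) matrix of [rt, choice]
def toy_simulator(theta, random_state=None):
    rng = np.random.default_rng(random_state)
    v, a, z, t = theta[:, 0], theta[:, 1], theta[:, 2], theta[:, 3]
    rt = t + rng.exponential(scale=a / np.maximum(np.abs(v), 1e-3))  # toy reaction times
    choice = np.where(rng.uniform(size=theta.shape[0]) < z, 1.0, -1.0)  # toy choices
    return np.column_stack([rt, choice])
print(toy_simulator(np.tile([1.0, 1.5, 0.5, 0.3], (10, 1))).shape)  # (10, 2)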
We will use the decorate_atomic_simulator() utility to annotate the simulator with the necessary metadata.
Again, for illustration, we will use a simulator from the ssm-simulators package as a starting point. You can start from any simulator you like.
from functools import partial
from ssms.basic_simulators.simulator import simulator
from hssm.utils import decorate_atomic_simulator
def sim_wrapper(simulator_fun, theta, model, n_samples, random_state, **kwargs):
"""Wrap a simulator function to match HSSM's expected interface.
Parameters
----------
simulator_fun : callable
The simulator function to wrap
theta : array-like
Model parameters, shape (n_trials, n_parameters)
model : str
Name of the model to simulate
n_samples : int
Number of samples to generate per trial
random_state : int or numpy.random.Generator
Random seed or random number generator
**kwargs
Additional keyword arguments passed to simulator_fun
Returns
-------
array-like
Array of shape (n_trials, 2) containing reaction times and choices
stacked column-wise
"""
out = simulator_fun(
theta=theta,
model=model,
n_samples=n_samples,
random_state=random_state,
**kwargs,
)
return np.column_stack([out["rts"], out["choices"]])
my_wrapped_simulator = partial(
sim_wrapper, simulator_fun=simulator, model="ddm", n_samples=1
)
decorated_simulator = decorate_atomic_simulator(
model_name="ddm", choices=[-1, 1], obs_dim=2
)(my_wrapped_simulator)
Checking the signature of the decorated simulator¶
We can check the signature of the decorated simulator by passing in a batch of parameters and checking the output shape.
decorated_simulator(
theta=np.tile(np.array([1.0, 1.5, 0.5, 0.3]), (10, 1)), random_state=42
)
array([[ 2.12210941,  1.        ],
       [ 2.6890626 ,  1.        ],
       [ 1.28667963,  1.        ],
       [ 1.16390681,  1.        ],
       [ 1.42591691,  1.        ],
       [ 1.14748347,  1.        ],
       [ 1.47392309,  1.        ],
       [ 0.88049531, -1.        ],
       [ 1.78516912,  1.        ],
       [ 1.20256472,  1.        ]])
Create valid PyMC distribution¶
At this point, we have all the ingredients to create a valid PyMC distribution from a few helper functions provided by hssm.
A valid Distribution will need two ingredients:
- A RandomVariable (RV) that encodes the simulator and parameter names.
- A likelihood function, which is a valid PyTensor Op that encodes the log-likelihood of the model.
We will use the make_hssm_rv() helper function to create the RandomVariable and the make_likelihood_callable() helper function to create the likelihood Op. Finally, the make_distribution() helper function will package everything into a valid PyMC distribution.
NOTE:
There are a few helpful settings we can use to customize our distribution. We will not cover all details here, but it is worth highlighting the params_is_reg argument of make_likelihood_callable(). This argument tells PyMC whether a given parameter is a regression parameter or not. Specifically, if a parameter is treated as a regression parameter, the likelihood function is built under the assumption that the parameter is passed trial-wise, i.e. as a vector of length n_trials.
Note that here we set all parameters to be non-regression parameters, since we expect the parameters to be passed as single values in the simple PyMC model below.
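For example, if v were driven by a trial-wise regression, its entry in params_is_reg would be set to True and the resulting likelihood Op would expect v as a vector of length n_trials. A hypothetical sketch (not used below), mirroring the call we make in the next cell:
from hssm.distribution_utils.dist import make_likelihood_callable
# Hypothetical: likelihood Op for a model in which v is a regression (trial-wise) parameter
logp_jax_op_v_reg = make_likelihood_callable(
    loglik=jax_logp,
    loglik_kind="approx_differentiable",
    backend="jax",
    params_is_reg=[True, False, False, False],  # v passed trial-wise; a, z, t as scalars
    params_only=False,
)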
from hssm.distribution_utils.dist import (
make_distribution,
make_hssm_rv,
make_likelihood_callable,
)
# Step 1: Define a pytensor RandomVariable
CustomRV = make_hssm_rv(
simulator_fun=decorated_simulator, list_params=["v", "a", "z", "t"]
)
# Step 2: Define a likelihood function
logp_jax_op = make_likelihood_callable(
loglik=jax_logp,
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=[False, False, False, False],
params_only=False,
)
# Step 3: Define a distribution
CustomDistribution = make_distribution(
rv=CustomRV,
loglik=logp_jax_op,
list_params=["v", "a", "z", "t"],
bounds=dict(v=(-3, 3), a=(0.5, 3.0), z=(0.1, 0.9), t=(0, 2.0)),
)
We can now test the distribution by passing it to a simple PyMC model.
# Test via basic pymc model
with pm.Model() as model:
v = pm.Normal("v", mu=0, sigma=1)
a = pm.Uniform("a", lower=0.5, upper=3.0)
z = pm.Beta("z", alpha=10, beta=10)
t = pm.Weibull("t", alpha=0.5, beta=1.0)
CustomDistribution("custom", v=v, a=a, z=z, t=t, observed=obs_ddm.values)
with model:
idata = pm.sample(draws=1000, tune=200, chains=1, nuts_sampler="numpyro")
passing through identity
sample: 100%|██████████| 1200/1200 [02:28<00:00, 8.07it/s, 23 steps of size 1.37e+00. acc. prob=0.92]
Only one chain was sampled, this makes it impossible to run some convergence checks
az.plot_trace(idata)
array([[<Axes: title={'center': 'a'}>, <Axes: title={'center': 'a'}>], [<Axes: title={'center': 't'}>, <Axes: title={'center': 't'}>], [<Axes: title={'center': 'v'}>, <Axes: title={'center': 'v'}>], [<Axes: title={'center': 'z'}>, <Axes: title={'center': 'z'}>]], dtype=object)
Custom HSSM model from simulator and JAX callable¶
Next, we will create a custom HSSM model from the simulator and JAX callable. After the work we have done above, this is now very straightforward.
The only hssm-specific extra step is to define a ModelConfig object, which bundles all information about the model. We then pass our ModelConfig object to the HSSM class, along with the data and the log-likelihood function, and hssm will take care of the rest.
Importantly, hssm will automatically figure out how to construct the correct likelihood function for the specified model configuration (parameter-wise regression settings, etc.), a step we had to accomplish manually in the code above (see the short regression sketch after the model definition below).
# Define model config
my_custom_model_config = ModelConfig(
response=["rt", "response"],
list_params=["v", "a", "z", "t"],
bounds={
"v": (-2.5, 2.5),
"a": (1.0, 3.0),
"z": (0.0, 0.9),
"t": (0.001, 2),
},
rv=decorated_simulator,
backend="jax",
choices=[-1, 1],
)
# Define the HSSM model
model = hssm.HSSM(
data=obs_ddm,
model="my_new_model", # some name for the model
model_config=my_custom_model_config,
loglik_kind="approx_differentiable", # use the blackbox loglik
loglik=jax_logp,
p_outlier=0,
)
model.graph()
You have specified the `lapse` argument to include a lapse distribution, but `p_outlier` is set to either 0 or None. Your lapse distribution will be ignored.
Model initialized successfully.
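As an aside, if obs_ddm additionally contained a trial-wise covariate (say a hypothetical column x), a regression on v could be requested via the include argument, and hssm would construct the trial-wise likelihood for us automatically. A sketch under that assumption (not run here):
# Hypothetical sketch: regression on v; assumes the data had an extra covariate column "x"
model_reg = hssm.HSSM(
    data=obs_ddm,  # would need a column "x" for this to actually run
    model="my_new_model",
    model_config=my_custom_model_config,
    loglik_kind="approx_differentiable",
    loglik=jax_logp,
    include=[{"name": "v", "formula": "v ~ 1 + x"}],  # v becomes a trial-wise regression parameter
    p_outlier=0,
)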
# Test sampling
model.sample(draws=500, tune=200, nuts_sampler="numpyro", discard_tuned_samples=False)
Using default initvals.
UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
sample: 100%|██████████| 700/700 [00:11<00:00, 59.57it/s, 15 steps of size 3.33e-01. acc. prob=0.90]
sample: 100%|██████████| 700/700 [00:11<00:00, 58.66it/s, 7 steps of size 2.87e-01. acc. prob=0.92]
sample: 100%|██████████| 700/700 [00:11<00:00, 63.43it/s, 3 steps of size 3.11e-01. acc. prob=0.94]
sample: 100%|██████████| 700/700 [00:08<00:00, 87.01it/s, 15 steps of size 3.84e-01. acc. prob=0.93]
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
100%|██████████| 2000/2000 [00:01<00:00, 1166.14it/s]
[Output: an arviz.InferenceData object with four groups: the posterior draws for z, v, a, and t (4 chains × 500 draws), the trial-wise rt,response log-likelihood values per draw, the NUTS sample statistics (acceptance rate, divergences, energy, step size, tree depth), and the observed rt,response data.]
az.plot_trace(model.traces)
plt.tight_layout()
az.plot_pair(model.traces)
plt.tight_layout()
We hope you find it easy to use the above example to leverage hssm to fit your own custom models.