Construct Custom Models from simulators and JAX callables¶
import os
import pickle
import arviz as az
import jax.numpy as jnp
import lanfactory as lf
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import hssm
from hssm.config import ModelConfig
Simulate Data¶
As a preamble, we simulate a simple dataset from the DDM to use throughout the example.
# simulate some data from the model
SEED = 123

obs_ddm = hssm.simulate_data(
    theta=dict(v=0.40, a=1.25, t=0.2, z=0.5),
    model="ddm",
    size=500,
    random_state=SEED,
)
Construct PyMC distribution from simulator and JAX callable¶
Create JAX Log-likelihood Callable¶
What we need is a callable that takes in a matrix that stacks model parameters and data trial-wise, hence with input dimensions $\text{trials} \times (\text{dim}_{\text{parameters}} + \text{dim}_{\text{data}})$. The function should return a vector of length $\text{trials}$ that contains the log-likelihood for each trial.
If we have a JAX function with this signature, we can proceed to create a valid PyMC distribution from helper functions provided by hssm.
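For orientation, here is a minimal sketch of a callable with this signature. It is a toy Gaussian stand-in rather than a real DDM likelihood, and the function name and column layout (four parameters followed by rt and choice) are assumptions made purely for illustration.
import jax.numpy as jnp


def toy_trialwise_loglik(inputs: jnp.ndarray) -> jnp.ndarray:
    """Toy log-likelihood: (n_trials, n_parameters + n_data) -> (n_trials,)."""
    # assumed column layout: v, a, z, t, rt, choice
    v, t = inputs[:, 0], inputs[:, 3]
    rt = inputs[:, 4]
    # toy density: rt ~ Normal(v + t, 1); ignores a, z, and the choice column
    return -0.5 * (rt - (v + t)) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi)


print(toy_trialwise_loglik(jnp.zeros((5, 6))).shape)  # (5,)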
In this particular example, we will re-instantiate a pretrained LAN via utilities from the lanfactory package. This is not necessary, but it illustrates how to make use of the broader ecosystem around HSSM.
Note: lanfactory is an optional dependency of hssm and can be installed using the command pip install hssm[notebook]. Alternatively, you can install it directly with pip install lanfactory.
# Load the pretrained network
jax_infer = lf.trainers.MLPJaxFactory(
    network_config=pickle.load(
        open(
            os.path.join(
                "data", "jax_models", "ddm", "ddm_jax_lan_network_config.pickle"
            ),
            "rb",
        )
    ),
    train=False,
)

jax_mlp_forward, _ = jax_infer.make_forward_partial(
    seed=42,
    input_dim=4 + 2,  # n-parameters (v, a, z, t) + n-data (rts and choices)
    state=os.path.join("data", "jax_models", "ddm", "ddm_jax_lan_train_state.jax"),
    add_jitted=True,
)
Checking the signature of the JAX callable¶
We can test the signature of the JAX callable by passing in a batch of inputs and checking the output shape.
# Testing the signature of the JAX function
n_dim_model_parameters = 4
n_dim_data = 2
n_trials = 1

in_ = jnp.zeros((n_trials, n_dim_model_parameters + n_dim_data))
out = jax_mlp_forward(in_)
print(out.shape)
(1, 1)
from hssm.distribution_utils.jax import make_jax_single_trial_logp_from_network_forward

jax_logp = make_jax_single_trial_logp_from_network_forward(
    jax_forward_fn=jax_mlp_forward,
    params_only=False,
)

# Test call
jax_logp(
    jnp.zeros((2)),
    jnp.array([1.0]),
    jnp.array([1.5]),
    jnp.array([0.5]),
    jnp.array([0.3]),
)
Array(-20.01882519, dtype=float64)
Decorate and wrap a simulator¶
The simulator signature expected by hssm is the following:
A simulator is a callable that takes in a matrix that stacks model parameters trial-wise, hence with input dimensions $\text{trials} \times |\text{parameters}|$. The function should return a matrix of shape $\text{trials} \times |\text{data}|$.
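As a minimal sketch of this signature, consider the toy generator below. It is not a real DDM simulator; the function name, the column layout, and the random_state keyword are assumptions chosen only to mirror the calls further below.
import numpy as np


def toy_simulator(theta: np.ndarray, random_state=None) -> np.ndarray:
    """Toy simulator: (n_trials, n_parameters) -> (n_trials, n_data)."""
    rng = np.random.default_rng(random_state)
    v, a, z, t = theta[:, 0], theta[:, 1], theta[:, 2], theta[:, 3]
    # toy response times and choices; not the actual DDM data-generating process
    rt = t + rng.exponential(scale=a / np.abs(v))
    choice = np.where(rng.uniform(size=theta.shape[0]) < z, 1.0, -1.0)
    return np.stack([rt, choice], axis=-1)


print(toy_simulator(np.tile([1.0, 1.5, 0.5, 0.2], (10, 1)), random_state=42).shape)  # (10, 2)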
We will use the decorate_atomic_simulator() utility to annotate the simulator with the necessary metadata, and the hssm_sim_wrapper() function from the ssm-simulators package to make our simulator ready for use inside a PyTensor RandomVariable later.
If you check the signature of the resulting rv_ready_simulator, you should find no problem shoe-horning your own custom simulator into the corresponding behavior: decorate_atomic_simulator() is just a Python decorator that you can use around any function.
You can start from any simulator you like; we use the one from the ssm-simulators package for convenience.
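For instance, decorating a hand-written simulator such as the toy_simulator sketched above would follow the same pattern (a hypothetical example; the metadata values are placeholders):
from ssms.hssm_support import decorate_atomic_simulator

# decorate_atomic_simulator only attaches metadata; the wrapped callable keeps
# its original signature
decorated_toy_simulator = decorate_atomic_simulator(
    model_name="toy_model", choices=[-1, 1], obs_dim=2
)(toy_simulator)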
from functools import partial

from ssms.basic_simulators.simulator import simulator
from ssms.hssm_support import decorate_atomic_simulator, hssm_sim_wrapper

rv_ready_simulator = partial(
    hssm_sim_wrapper,
    simulator_fun=simulator,
    model="ddm",
    n_replicas=1,  # AF-TODO: n_replicas should default to 1 instead of being required
)

# We decorate the simulator to attach some metadata that HSSM can use
decorated_simulator = decorate_atomic_simulator(
    model_name="ddm", choices=[-1, 1], obs_dim=2
)(rv_ready_simulator)
Checking the signature of the decorated simulator¶
We can check the signature of the decorated simulator by passing in a batch of parameters and checking the output shape.
decorated_simulator(
    theta=np.tile(np.array([1.0, 1.5, 0.5, 0.2]), (10, 1)), random_state=42
)
array([[ 2.02210951,  1.        ],
       [ 2.58906269,  1.        ],
       [ 1.1866796 ,  1.        ],
       [ 1.06390691,  1.        ],
       [ 1.32591701,  1.        ],
       [ 1.04748344,  1.        ],
       [ 1.37392318,  1.        ],
       [ 0.78049529, -1.        ],
       [ 1.6851691 ,  1.        ],
       [ 1.10256469,  1.        ]])
Create valid PyMC distribution¶
At this point, we have all the ingredients to create a valid PyMC distribution from a few helper functions provided by hssm.
A valid Distribution needs two ingredients:
- A RandomVariable (RV) that encodes the simulator and parameter names.
- A likelihood function, which is a valid PyTensor Op that encodes the log-likelihood of the model.
We will use the make_hssm_rv() helper function to create the RandomVariable and the make_likelihood_callable() helper function to create the likelihood Op.
Finally, the make_distribution() helper function will package everything into a valid PyMC distribution.
NOTE:
There are a few helpful settings which we can use to customize our distribution. We will not cover all the details here; however, it is worth highlighting the params_is_reg argument in make_likelihood_callable(). This argument tells PyMC whether a parameter is a regression parameter or not. Specifically, if a parameter is treated as a regression parameter, the likelihood function is built to assume that the parameter is passed trial-wise, i.e. as a vector of length n_trials.
Note that below we flag v as a regression parameter and a, z, and t as non-regression parameters, because v will be passed trial-wise (as a vector of length n_trials) in the simple PyMC model below, while the remaining parameters are passed as single values.
from hssm.distribution_utils.dist import (
    make_distribution,
    make_hssm_rv,
    make_likelihood_callable,
)

# Step 1: Define a pytensor RandomVariable
CustomRV = make_hssm_rv(
    simulator_fun=decorated_simulator,
    list_params=["v", "a", "z", "t"],
)

# Step 2: Define a likelihood function
logp_jax_op = make_likelihood_callable(
    loglik=jax_logp,
    loglik_kind="approx_differentiable",
    backend="jax",
    params_is_reg=[True, False, False, False],
    params_only=False,
)

# Step 3: Define a distribution
CustomDistribution = make_distribution(
    rv=CustomRV,
    loglik=logp_jax_op,
    list_params=["v", "a", "z", "t"],
    bounds=dict(v=(-3, 3), a=(0.5, 3.0), z=(0.1, 0.9), t=(0, 2.0)),
)
passing here: params_only: False params_is_reg: [True, False, False, False]
We can now test the distribution by passing it to a simple PyMC model.
from pytensor import tensor as pt

# Test via a basic PyMC model
with pm.Model() as model_pymc:
    v = pm.Normal("v", mu=0, sigma=1)
    a = pm.Weibull("a", alpha=2.0, beta=1.2)
    z = pm.Beta("z", alpha=10, beta=10)
    t = pm.Weibull("t", alpha=1.5, beta=0.5)

    # We define `v` as a vector of length `n_trials`
    # to conform to the expected signature of the likelihood function
    v_det = pm.Deterministic("v_det", v * pt.ones(obs_ddm.shape[0]))

    CustomDistribution("custom", observed=obs_ddm.values, v=v_det, a=a, z=z, t=t)

pm.model_to_graphviz(model_pymc)
max_shape: (500,) size: (np.int64(500),)
with model_pymc:
    idata = pm.sample(
        draws=500,
        tune=500,
        chains=2,
        nuts_sampler="numpyro",
    )
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.11/site-packages/pymc/sampling/jax.py:475: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  pmap_numpyro = MCMC(
sample: 100%|██████████| 1000/1000 [00:12<00:00, 82.15it/s, 3 steps of size 4.10e-01. acc. prob=0.84]
sample: 100%|██████████| 1000/1000 [00:10<00:00, 93.81it/s, 3 steps of size 3.90e-01. acc. prob=0.89]
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.summary(idata)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
v | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
a | 1.275 | 0.028 | 1.221 | 1.325 | 0.001 | 0.001 | 668.0 | 466.0 | 1.00 |
z | 0.472 | 0.014 | 0.447 | 0.499 | 0.001 | 0.001 | 426.0 | 327.0 | 1.01 |
t | 0.184 | 0.021 | 0.145 | 0.222 | 0.001 | 0.001 | 544.0 | 508.0 | 1.00 |
v_det[0] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
v_det[495] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
v_det[496] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
v_det[497] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
v_det[498] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
v_det[499] | 0.508 | 0.022 | 0.467 | 0.549 | 0.001 | 0.001 | 506.0 | 457.0 | 1.00 |
504 rows × 9 columns
az.plot_trace(idata)
plt.tight_layout()
Custom HSSM model from simulator and JAX callable¶
Next, we will create a custom HSSM model from the simulator and the JAX callable. After the work we have done above, this is now very straightforward.
The only hssm-specific extra step is to define a ModelConfig object, which bundles all information about the model.
Then we pass our ModelConfig object to the HSSM class, along with the data and the log-likelihood function, and hssm will take care of the rest.
Importantly, hssm will automatically work out how to construct the correct likelihood function for the specified model configuration (parameter-wise regression settings, etc.), a step we had to accomplish manually in the code above.
# Define model config
my_custom_model_config = ModelConfig(
    response=["rt", "response"],
    list_params=["v", "a", "z", "t"],
    bounds={
        "v": (-2.5, 2.5),
        "a": (1.0, 3.0),
        "z": (0.0, 0.9),
        "t": (0.001, 2),
    },
    rv=decorated_simulator,
    backend="jax",
    choices=[-1, 1],
)
# Define the HSSM model
model = hssm.HSSM(
    data=obs_ddm,
    model="my_new_model",  # some name for the model
    model_config=my_custom_model_config,
    loglik_kind="approx_differentiable",  # use the LAN-based approximate likelihood
    loglik=jax_logp,
    p_outlier=0,
)

model.graph()
You have specified the `lapse` argument to include a lapse distribution, but `p_outlier` is set to either 0 or None. Your lapse distribution will be ignored.
passing here: params_only: False params_is_reg: [True, False, False, False]
Model initialized successfully.
max_shape: (500,) size: (np.int64(500),)
# Test sampling
model.sample(
    draws=500,
    tune=200,
    nuts_sampler="numpyro",
    chains=2,
    cores=2,
    discard_tuned_samples=False,
)
Using default initvals.
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.11/site-packages/pymc/sampling/jax.py:475: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  pmap_numpyro = MCMC(
sample: 100%|██████████| 700/700 [00:08<00:00, 84.96it/s, 7 steps of size 4.33e-01. acc. prob=0.88]
sample: 100%|██████████| 700/700 [00:07<00:00, 97.64it/s, 3 steps of size 4.85e-01. acc. prob=0.90]
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
100%|██████████| 1000/1000 [00:00<00:00, 1287.66it/s]
The call returns an arviz.InferenceData object with the following groups:
- posterior: v, z, t, a (2 chains × 500 draws)
- posterior_predictive: rt,response (2 chains × 500 draws × 500 observations)
- sample_stats: acceptance_rate, step_size, diverging, energy, n_steps, tree_depth, lp
- observed_data: rt,response (500 observations × 2)
az.plot_trace(model.traces)
plt.tight_layout()
az.plot_pair(model.traces)
plt.tight_layout()
We hope the example above makes it easy to leverage hssm to fit your own custom models.