Custom models from onnx files
Build HSSM models starting from ONNX files
In this tutorial we build an HSSM model directly from an ONNX file. For our purposes, the ONNX file format provides a convenient translation layer: networks trained in different deep learning frameworks are exported to a common representation, from which we can then reconstruct the computation graph for use through PyMC.
import os
import matplotlib.pyplot as plt
import numpy as np
import hssm
Loading the network
# Networks
network_path = os.path.join("data", "race_3_no_bias_lan_no_batch.onnx")
The network we load here does not have dynamic input dimensions, which prevents us from batching computations.
Rather than fixing this behind the scenes and loading an already corrected network, we provide a snippet below that shows how to rectify the situation.
import onnx
import onnxruntime as ort

# Load model from path
onnx_model = onnx.load(network_path)

# Change input and output dimensions to be dynamic to allow for batching
# (in case this is not already done)
for input_tensor in onnx_model.graph.input:
    dim_proto = input_tensor.type.tensor_type.shape.dim[0]
    if dim_proto.dim_param != "None":
        dim_proto.dim_param = "None"

for output_tensor in onnx_model.graph.output:
    dim_proto = output_tensor.type.tensor_type.shape.dim[0]
    if dim_proto.dim_param != "None":
        dim_proto.dim_param = "None"

input_name = onnx_model.graph.input[0].name
# Uncomment the line below to save the adjusted model
# (the path matches the file loaded in the next cell)
# onnx.save(onnx_model, "data/race_3_no_bias_lan_batch.onnx")
Armed with the corrected network, let's test inference speed on a data-batch of $1000$ trials.
# Load the batch-ready model
ort_session = ort.InferenceSession("data/race_3_no_bias_lan_batch.onnx")

# Test inference speed
import time

start = time.time()
for _ in range(100):
    ort_session.run(
        None, {input_name: np.random.uniform(size=(1000, 8)).astype(np.float32)}
    )
end = time.time()
print(f"Time taken: {(end - start) / 100} seconds")
Time taken: 0.00040030956268310546 seconds
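The loop above also times the random-number generation and any first-call overhead. A minimal sketch of a slightly more careful measurement follows, assuming a few warm-up calls are acceptable and the input batch can be created once outside the loop. The helper `mean_seconds_per_call` and the matrix-multiply stand-in for the forward pass are hypothetical and are not part of HSSM or onnxruntime.

```python
import time

import numpy as np


def mean_seconds_per_call(fn, n_repeats=100, n_warmup=5):
    """Return the mean wall-clock time of fn() after a few warm-up calls."""
    for _ in range(n_warmup):
        fn()  # warm-up calls are not timed
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn()
    return (time.perf_counter() - start) / n_repeats


# Create the batch once so that only the forward pass is timed
batch = np.random.uniform(size=(1000, 8)).astype(np.float32)

# Hypothetical stand-in for the network: a single matrix multiply
weights = np.random.uniform(size=(8, 1)).astype(np.float32)


def stand_in_forward():
    return batch @ weights


elapsed = mean_seconds_per_call(stand_in_forward)
print(f"Time taken: {elapsed:.2e} seconds per call")
```

To time the actual network, the stand-in can be swapped for `lambda: ort_session.run(None, {input_name: batch})`.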
Defining the Likelihood
The network we loaded corresponds to a LAN for a Race model with three choice alternatives.
This model has three drift parameters v0, v1, v2, a boundary parameter a, a starting point bias z and a non-decision-time t.
Data from this model has the usual rt, choice format.
We use this to construct a simple blackbox likelihood function below. This likelihood function takes the respective data and model parameters as arguments.
The function body shapes these input arguments into a matrix and performs a batched forward pass through the loaded network via onnxruntime.
def my_blackbox_race_model(data, v0, v1, v2, a, t, z):
    """Calculate log-likelihood for a 3-choice race model.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (n_trials, 2) containing response times in first column
        and choices (0, 1, or 2) in second column
    v0 : float
        Drift rate for accumulator 0
    v1 : float
        Drift rate for accumulator 1
    v2 : float
        Drift rate for accumulator 2
    a : float
        Decision threshold/boundary
    t : float
        Non-decision time
    z : float
        Starting point bias

    Returns
    -------
    np.ndarray
        Array of log-likelihood values for each trial
    """
    data_nrows = data.shape[0]
    data = np.vstack(
        [np.full(data_nrows, param_) for param_ in [v0, v1, v2, a, t, z]]
        + [data[:, 0], data[:, 1]]
    ).T.astype(np.float32)
    return ort_session.run(None, {input_name: data})[0].squeeze()
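The stacking step can be checked in isolation, without loading the network. The sketch below assumes the same layout as the function above (six parameter columns followed by the rt and choice columns). The helper `build_network_input` and the toy data are hypothetical and only serve to inspect the resulting shape and dtype.

```python
import numpy as np


def build_network_input(data, *params):
    """Stack scalar parameters and (rt, choice) data into one float32 matrix."""
    n_trials = data.shape[0]
    # Repeat each scalar parameter once per trial
    param_cols = [np.full(n_trials, p) for p in params]
    # Columns: one per parameter, then rt, then choice
    return np.column_stack(param_cols + [data[:, 0], data[:, 1]]).astype(np.float32)


# Toy data: three trials with (rt, choice) columns
toy_data = np.array([[0.8, 0.0], [1.2, 2.0], [0.6, 1.0]])
net_input = build_network_input(toy_data, 1.0, 0.5, 0.25, 1.5, 0.3, 0.5)

print(net_input.shape)
print(net_input.dtype)
```

Each row of the matrix is one trial, so a batch of trials becomes a single forward pass rather than one network call per trial.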
Simulate example data
# Set parameters
v0 = 1.0
v1 = 0.5
v2 = 0.25
a = 1.5
t = 0.3
z = 0.5
# simulate some data from the model
obs_race3 = hssm.simulate_data(
theta=dict(v0=v0, v1=v1, v2=v2, a=a, t=t, z=z), model="race_no_bias_3", size=1000
)
Test Likelihood Outputs
# Test that outputs are reasonable
for choice in [0, 1, 2]:
    rts = np.linspace(0, 20, 1000)
    choices = np.repeat(choice, 1000)
    data = np.vstack([rts, choices]).T
    out = my_blackbox_race_model(data, v0, v1, v2, a, t, z)
    plt.plot(rts, np.exp(out), label=f"choice: {choice}")
plt.legend()
plt.show()
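Beyond the visual check, one can also verify numerically that the choice-wise densities sum to roughly one when integrated over response time. The sketch below uses a trapezoidal rule on the same kind of grid. The function `stand_in_loglik` is a hypothetical analytic replacement for the network-backed likelihood (fixed choice probabilities times a Gamma(2, 1) response time density), chosen only so that the example runs without the ONNX file.

```python
import numpy as np

# Hypothetical choice probabilities for the stand-in likelihood
CHOICE_PROBS = np.array([0.5, 0.3, 0.2])


def stand_in_loglik(data):
    """Log of P(choice) * Gamma(2, 1) density in rt (stand-in, not a LAN)."""
    rts = np.maximum(data[:, 0], 1e-12)  # avoid log(0) at rt = 0
    choices = data[:, 1].astype(int)
    return np.log(CHOICE_PROBS[choices]) + np.log(rts) - rts


def total_mass(loglik_fn, n_choices=3, rt_max=20.0, n_grid=1000):
    """Integrate exp(loglik) over rt for each choice and sum the results."""
    rts = np.linspace(0, rt_max, n_grid)
    total = 0.0
    for choice in range(n_choices):
        data = np.column_stack([rts, np.full(n_grid, choice)])
        density = np.exp(loglik_fn(data))
        # Trapezoidal rule written out explicitly
        total += np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(rts))
    return total


mass = total_mass(stand_in_loglik)
print(f"Total probability mass: {mass:.4f}")
```

For the tutorial's likelihood, the same check would read `total_mass(lambda d: my_blackbox_race_model(d, v0, v1, v2, a, t, z))`. Since a LAN is an approximation, the result should be close to one but need not match it exactly.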
Build HSSM Model
We can now build a simple HSSM model that takes in our new blackbox likelihood.
model = hssm.HSSM(
    data=obs_race3,
    model="race_no_bias_3",  # some name for the model
    model_config={
        "list_params": ["v0", "v1", "v2", "a", "z", "t"],
        "bounds": {
            "v0": (0.0, 2.5),
            "v1": (0.0, 2.5),
            "v2": (0.0, 2.5),
            "a": (1.0, 3.0),
            "z": (0.0, 0.9),
            "t": (0.001, 2),
        },
    },  # minimal specification of model parameters and parameter bounds
    loglik_kind="blackbox",  # use the blackbox loglik
    loglik=my_blackbox_race_model,
    choices=[0, 1, 2],  # list the legal choice options
    z=0.5,
    p_outlier=0,
)
You have specified the `lapse` argument to include a lapse distribution, but `p_outlier` is set to either 0 or None. Your lapse distribution will be ignored. Model initialized successfully.
model.graph()
model.sample(draws=500, tune=200, discard_tuned_samples=False)
Using default initvals.
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Slice: [v1]
>Slice: [t]
>Slice: [a]
>Slice: [v2]
>Slice: [v0]
Sampling 4 chains for 200 tune and 500 draw iterations (800 + 2_000 draws total) took 33 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.12/site-packages/pymc/pytensorf.py:958: FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
  warnings.warn(
100%|██████████| 2000/2000 [00:01<00:00, 1959.13it/s]
<xarray.Dataset> Size: 84kB
Dimensions:  (chain: 4, draw: 500)
Coordinates:
  * chain  (chain) int64 32B 0 1 2 3
  * draw  (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    a  (chain, draw) float64 16kB 1.406 1.409 1.413 ... 1.402 1.465 1.494
    v1  (chain, draw) float64 16kB 0.4684 0.4786 0.4823 ... 0.4551 0.5755
    t  (chain, draw) float64 16kB 0.301 0.3007 0.3026 ... 0.2979 0.2961
    v0  (chain, draw) float64 16kB 0.9083 0.9667 0.9723 ... 1.015 1.008
    v2  (chain, draw) float64 16kB 0.209 0.1775 0.1519 ... 0.3455 0.2646
Attributes:
    created_at:  2025-07-13T13:38:55.187187+00:00
    arviz_version:  0.21.0
    inference_library:  pymc
    inference_library_version:  5.21.1
    sampling_time:  33.21286725997925
    tuning_steps:  200
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0

<xarray.Dataset> Size: 16MB
Dimensions:  (chain: 4, draw: 500, __obs__: 1000)
Coordinates:
  * chain  (chain) int64 32B 0 1 2 3
  * draw  (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
  * __obs__  (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
Data variables:
    rt,response  (chain, draw, __obs__) float64 16MB -0.3937 0.2195 ... -0.7151
Attributes:
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0

<xarray.Dataset> Size: 164kB
Dimensions:  (chain: 4, draw: 500, nstep_in_dim_0: 5, nstep_out_dim_0: 5)
Coordinates:
  * chain  (chain) int64 32B 0 1 2 3
  * draw  (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
  * nstep_in_dim_0  (nstep_in_dim_0) int64 40B 0 1 2 3 4
  * nstep_out_dim_0  (nstep_out_dim_0) int64 40B 0 1 2 3 4
Data variables:
    nstep_in  (chain, draw, nstep_in_dim_0) int64 80kB 5 8 5 0 ... 7 0 3
    nstep_out  (chain, draw, nstep_out_dim_0) int64 80kB 0 1 0 1 ... 0 1 0
Attributes:
    created_at:  2025-07-13T13:38:55.199850+00:00
    arviz_version:  0.21.0
    inference_library:  pymc
    inference_library_version:  5.21.1
    sampling_time:  33.21286725997925
    tuning_steps:  200
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0

<xarray.Dataset> Size: 24kB
Dimensions:  (__obs__: 1000, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__  (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response  (__obs__, rt,response_extra_dim_0) float64 16kB ...
Attributes:
    created_at:  2025-07-13T13:38:55.201998+00:00
    arviz_version:  0.21.0
    inference_library:  pymc
    inference_library_version:  5.21.1
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0

<xarray.Dataset> Size: 6MB
Dimensions:  (chain: 4, draw: 200, __obs__: 1000)
Coordinates:
  * chain  (chain) int64 32B 0 1 2 3
  * draw  (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199
  * __obs__  (__obs__) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    v1  (chain, draw) float64 6kB 1.466e-10 2.5 2.5 ... 0.3742 0.621 0.5092
    t  (chain, draw) float64 6kB 0.1457 0.1819 0.1776 ... 0.3038 0.296
    a  (chain, draw) float64 6kB 3.0 2.976 2.977 ... 1.421 1.448 1.441
    v2  (chain, draw) float64 6kB 1.329 1.675 1.623 ... 0.2861 0.2813
    v0  (chain, draw) float64 6kB 2.5 2.496 2.496 ... 1.105 0.9855 1.004
    v0_mean  (chain, draw, __obs__) float64 6MB 2.5 2.5 2.5 ... 1.004 1.004
Attributes:
    created_at:  2025-07-13T13:38:55.197877+00:00
    arviz_version:  0.21.0
    inference_library:  pymc
    inference_library_version:  5.21.1
    sampling_time:  33.21286725997925
    tuning_steps:  200
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0

<xarray.Dataset> Size: 66kB
Dimensions:  (chain: 4, draw: 200, nstep_in_dim_0: 5, nstep_out_dim_0: 5)
Coordinates:
  * chain  (chain) int64 32B 0 1 2 3
  * draw  (draw) int64 2kB 0 1 2 3 4 5 6 ... 194 195 196 197 198 199
  * nstep_in_dim_0  (nstep_in_dim_0) int64 40B 0 1 2 3 4
  * nstep_out_dim_0  (nstep_out_dim_0) int64 40B 0 1 2 3 4
Data variables:
    nstep_in  (chain, draw, nstep_in_dim_0) int64 32kB 0 0 0 0 ... 3 3 2
    nstep_out  (chain, draw, nstep_out_dim_0) int64 32kB 344 3 131 ... 0 0
Attributes:
    created_at:  2025-07-13T13:38:55.200992+00:00
    arviz_version:  0.21.0
    inference_library:  pymc
    inference_library_version:  5.21.1
    sampling_time:  33.21286725997925
    tuning_steps:  200
    modeling_interface:  bambi
    modeling_interface_version:  0.15.0
import arviz as az
az.plot_trace(model.traces, var_names=["~v0_mean"])
plt.tight_layout()
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.12/site-packages/arviz/utils.py:146: UserWarning: Items starting with ~: ['v0_mean'] have not been found and will be ignored warnings.warn(
az.plot_pair(model.traces, var_names=["~v0_mean"])
plt.tight_layout()
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.12/site-packages/arviz/utils.py:146: UserWarning: Items starting with ~: ['v0_mean'] have not been found and will be ignored warnings.warn(