Build HSSM models starting from ONNX files
In this tutorial we build an HSSM model directly from an ONNX file. For our purposes, the ONNX file format provides a convenient translation layer from deep learning frameworks into a common representation, from which we can reconstruct a computation graph for use with PyMC.
import os
import matplotlib.pyplot as plt
import numpy as np
import hssm
Loading the network
# Networks
network_path = os.path.join("data", "race_3_no_bias_lan_no_batch.onnx")
The network we load here does not have dynamic input dimensions, which prevents us from batching computations.
Rather than fixing this behind the scenes and handing you an already-adjusted network, we provide a snippet below that shows how to rectify the situation yourself.
import onnx
import onnxruntime as ort

# Load model from path
onnx_model = onnx.load(network_path)

# Change input and output dimensions to be dynamic to allow for batching
# (in case this is not already done)
for input_tensor in onnx_model.graph.input:
    dim_proto = input_tensor.type.tensor_type.shape.dim[0]
    if dim_proto.dim_param != "None":
        dim_proto.dim_param = "None"

for output_tensor in onnx_model.graph.output:
    dim_proto = output_tensor.type.tensor_type.shape.dim[0]
    if dim_proto.dim_param != "None":
        dim_proto.dim_param = "None"

input_name = onnx_model.graph.input[0].name

# Please uncomment the below line to save the adjusted model
# onnx.save(onnx_model, "test_files/race_3_no_bias_lan_batch.onnx")
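As a quick sanity check (our addition, not part of the original snippet), we can print the first dimension of each input and output; after the adjustment it should show the symbolic name rather than a fixed integer size.

# Sanity check: the first dimension of every input and output should now
# carry the symbolic name "None" instead of a fixed integer, marking it
# as dynamic and therefore batchable
for tensor in list(onnx_model.graph.input) + list(onnx_model.graph.output):
    dim_proto = tensor.type.tensor_type.shape.dim[0]
    print(tensor.name, "->", dim_proto.dim_param or dim_proto.dim_value)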
Armed with the corrected network, let's test inference speed on a batch of $1000$ trials.
# Load the batch-ready model
ort_session = ort.InferenceSession("data/race_3_no_bias_lan_batch.onnx")

# Test inference speed
import time

start = time.time()
for i in range(100):
    ort_session.run(
        None, {input_name: np.random.uniform(size=(1000, 8)).astype(np.float32)}
    )
end = time.time()
print(f"Time taken: {(end - start) / 100} seconds")
Time taken: 0.0005616998672485352 seconds
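For contrast, a small sketch (our addition) that feeds the same $1000$ trials through the network one row at a time illustrates what the dynamic batch dimension buys us; expect this to run considerably slower than the batched pass above.

# Same workload, but one forward pass per trial instead of one batched pass
single_rows = np.random.uniform(size=(1000, 1, 8)).astype(np.float32)

start = time.time()
for row in single_rows:
    ort_session.run(None, {input_name: row})
end = time.time()
print(f"Time taken (unbatched): {end - start} seconds")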
Defining the Likelihood
The network we loaded corresponds to a LAN for a race model with three choice alternatives. This model has three drift parameters v0, v1, v2, a boundary parameter a, a starting point bias z, and a non-decision time t. Data from this model come in the usual rt, choice format.
We use this to construct a simple blackbox likelihood function below. This likelihood function takes the respective data and model parameters as arguments. The function body shapes these inputs into a matrix and performs a batched forward pass through the loaded network via onnxruntime.
def my_blackbox_race_model(data, v0, v1, v2, a, t, z):
    """Calculate log-likelihood for a 3-choice race model.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (n_trials, 2) containing response times in the first
        column and choices (0, 1, or 2) in the second column
    v0 : float
        Drift rate for accumulator 0
    v1 : float
        Drift rate for accumulator 1
    v2 : float
        Drift rate for accumulator 2
    a : float
        Decision threshold/boundary
    t : float
        Non-decision time
    z : float
        Starting point bias

    Returns
    -------
    np.ndarray
        Array of log-likelihood values for each trial
    """
    data_nrows = data.shape[0]
    # Tile the scalar parameters across trials, append the rt and choice
    # columns, and cast to an (n_trials, 8) float32 matrix for the network
    data = np.vstack(
        [np.full(data_nrows, param_) for param_ in [v0, v1, v2, a, t, z]]
        + [data[:, 0], data[:, 1]]
    ).T.astype(np.float32)
    return ort_session.run(None, {input_name: data})[0].squeeze()
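Before wiring this into HSSM, a quick smoke test (our addition, with arbitrary parameter values) confirms that the function returns one log-likelihood value per trial.

# Two dummy (rt, choice) trials: the output should have shape (2,)
dummy_data = np.array([[0.5, 0.0], [0.7, 2.0]])
print(my_blackbox_race_model(dummy_data, 1.0, 0.5, 0.25, 1.5, 0.3, 0.5).shape)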
Simulate example data
# Set parameters
v0 = 1.0
v1 = 0.5
v2 = 0.25
a = 1.5
t = 0.3
z = 0.5
# simulate some data from the model
obs_race3 = hssm.simulate_data(
    theta=dict(v0=v0, v1=v1, v2=v2, a=a, t=t, z=z), model="race_no_bias_3", size=1000
)
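It is worth a quick look at the simulated data before testing the likelihood; hssm.simulate_data returns a DataFrame with rt and response columns.

# Inspect the simulated trials and the distribution of responses
print(obs_race3.head())
print(obs_race3["response"].value_counts())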
Test Likelihood Outputs
# Test that outputs are reasonable
for choice in [0, 1, 2]:
    rts = np.linspace(0, 20, 1000)
    choices = np.repeat(choice, 1000)
    data = np.vstack([rts, choices]).T
    out = my_blackbox_race_model(data, v0, v1, v2, a, t, z)
    plt.plot(rts, np.exp(out), label=f"choice: {choice}")
plt.legend()
plt.show()
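The three curves are defective densities, so beyond eyeballing the plot we can check numerically (our addition, using the trapezoidal rule) that their masses sum to roughly one; small deviations are expected since the LAN is only an approximation.

# The choice-wise defective densities should jointly integrate to ~1
total_mass = 0.0
rts = np.linspace(0, 20, 1000)
for choice in [0, 1, 2]:
    data = np.vstack([rts, np.repeat(choice, 1000)]).T
    density = np.exp(my_blackbox_race_model(data, v0, v1, v2, a, t, z))
    total_mass += np.trapz(density, rts)
print(f"Total probability mass across choices: {total_mass:.3f}")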
Build HSSM Model
We can now build a simple HSSM model that takes in our new blackbox likelihood.
model = hssm.HSSM(
    data=obs_race3,
    model="race_no_bias_3",  # some name for the model
    model_config={
        "list_params": ["v0", "v1", "v2", "a", "z", "t"],
        "bounds": {
            "v0": (0.0, 2.5),
            "v1": (0.0, 2.5),
            "v2": (0.0, 2.5),
            "a": (1.0, 3.0),
            "z": (0.0, 0.9),
            "t": (0.001, 2),
        },
    },  # minimal specification of model parameters and parameter bounds
    loglik_kind="blackbox",  # use the blackbox loglik
    loglik=my_blackbox_race_model,
    choices=[0, 1, 2],  # list the legal choice options
    z=0.5,  # fix z to its true value instead of sampling it
    p_outlier=0,  # disable the outlier/lapse process
)
You have specified the `lapse` argument to include a lapse distribution, but `p_outlier` is set to either 0 or None. Your lapse distribution will be ignored. Model initialized successfully.
model.graph()
model.sample(draws=500, tune=200, discard_tuned_samples=False)
Using default initvals.
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Slice: [v2]
>Slice: [t]
>Slice: [a]
>Slice: [v1]
>Slice: [v0]
Sampling 4 chains for 200 tune and 500 draw iterations (800 + 2_000 draws total) took 28 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
(Output truncated: model.sample() returns an ArviZ InferenceData object whose repr lists the posterior (t, v2, v0, a, v1; 4 chains × 500 draws), trial-wise log-likelihoods, sample statistics, the observed data, and the corresponding warmup groups from the 200 tuning draws.)
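Given the rhat and effective-sample-size warnings above, it is prudent to look at a numerical summary before interpreting the traces (a standard ArviZ call, our addition; in practice you would likely increase draws and tune).

import arviz as az

# Convergence diagnostics: r_hat near 1 and healthy ess values indicate
# well-mixed chains; the warnings above suggest sampling longer
print(az.summary(model.traces, var_names=["v0", "v1", "v2", "a", "t"]))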
import arviz as az
az.plot_trace(model.traces, var_names=["~v0_mean"])
plt.tight_layout()
/Users/afengler/miniconda3/envs/hssm516/lib/python3.11/site-packages/arviz/utils.py:142: UserWarning: Items starting with ~: ['v0_mean'] have not been found and will be ignored warnings.warn(
az.plot_pair(model.traces, var_names=["~v0_mean"])
plt.tight_layout()
/Users/afengler/miniconda3/envs/hssm516/lib/python3.11/site-packages/arviz/utils.py:142: UserWarning: Items starting with ~: ['v0_mean'] have not been found and will be ignored warnings.warn(
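Finally, a short parameter-recovery check (our addition) compares posterior means against the generating values from the simulation step; z was fixed at its true value, so it is not sampled.

# Compare posterior means to the data-generating parameter values
true_values = {"v0": v0, "v1": v1, "v2": v2, "a": a, "t": t}
posterior = model.traces.posterior
for name, true_val in true_values.items():
    est = posterior[name].mean().item()
    print(f"{name}: true = {true_val:.2f}, posterior mean = {est:.2f}")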