Saving and loading models¶
In this short how-to tutorial, we show how to save an HSSM model instance and its inference results to disk and then re-instantiate the model from the saved files.
Load data and instantiate HSSM model¶
In [1]:
import hssm

cav_data = hssm.load_data("cavanagh_theta")

basic_hssm_model = hssm.HSSM(
    data=cav_data,
    process_initvals=True,
    link_settings="log_logit",
    model="angle",
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + C(stim)",
        }
    ],
)
Model initialized successfully.
In [2]:
basic_hssm_model.sample(
    sampler="nuts_numpyro",
    tune=100,
    draws=100,
    chains=2,
)
Using default initvals.
/Users/afengler/Library/CloudStorage/OneDrive-Personal/proj_hssm/HSSM/.venv/lib/python3.11/site-packages/pymc/sampling/jax.py:475: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  pmap_numpyro = MCMC(
sample: 100%|██████████| 200/200 [01:06<00:00, 3.01it/s, 127 steps of size 3.16e-02. acc. prob=0.96]
sample: 100%|██████████| 200/200 [00:57<00:00, 3.47it/s, 127 steps of size 2.55e-02. acc. prob=0.92]
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
100%|██████████| 200/200 [00:00<00:00, 306.78it/s]
Out[2]:
arviz.InferenceData with four groups:

posterior (chain: 2, draw: 100, v_C(stim)_dim: 2): v_C(stim), z, t, a, v_Intercept, theta
posterior_predictive (chain: 2, draw: 100, __obs__: 3988): rt,response
sample_stats (chain: 2, draw: 100): acceptance_rate, step_size, diverging, energy, n_steps, tree_depth, lp
observed_data (__obs__: 3988, rt,response_extra_dim_0: 2): rt,response
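The sampler output above warns that the rhat statistic exceeds 1.01 for some parameters, which is expected with only 100 tuning steps, 100 draws, and 2 chains. For intuition, the split-R-hat diagnostic that ArviZ computes from the traces can be sketched in plain NumPy. This is a minimal sketch of the formula, not part of the HSSM API; `split_rhat`, `good`, and `bad` are illustrative names.

```python
import numpy as np

def split_rhat(draws: np.ndarray) -> float:
    """Split R-hat for an array of draws shaped (chains, draws)."""
    # Split each chain in half to also detect within-chain trends
    n = draws.shape[1] // 2
    chains = np.concatenate([draws[:, :n], draws[:, n:2 * n]], axis=0)
    m, n = chains.shape
    W = chains.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (n - 1) / n * W + B / n            # pooled variance estimate
    return float(np.sqrt(var_hat / W))

rng = np.random.default_rng(0)
good = rng.normal(0.0, 1.0, size=(2, 100))       # two well-mixed chains
bad = good + np.array([[0.0], [3.0]])            # chains stuck in different regions

print(f"well-mixed chains:  rhat = {split_rhat(good):.2f}")
print(f"disagreeing chains: rhat = {split_rhat(bad):.2f}")
```

Well-mixed chains yield a value near 1, while chains exploring different regions push it well above the 1.01 threshold flagged in the warning.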
Variational inference (VI)¶
Instead of MCMC, we can also fit the same model with variational inference via the vi() method.
In [3]:
basic_hssm_model.vi(method="advi", niter=5000)
Using MCMC starting point defaults.
Finished [100%]: Average Loss = 7,614.1
Out[3]:
arviz.InferenceData with two groups:

posterior (chain: 1, draw: 1000, v_C(stim)_dim: 2): v_C(stim), z, t, a, v_Intercept, theta
observed_data (__obs__: 3988, rt,response_extra_dim_0: 2): rt,response
Saving and loading the model¶
In [4]:
basic_hssm_model.save_model(model_name="test_model")
We use the defaults here, which save the model and its inference results to the hssm_models/test_model/ directory inside your current working directory.
Up to three files are saved in the model directory:

- model.pkl: the model instance
- traces.nc: the MCMC traces
- vi_traces.nc: the VI traces
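The .pkl extension suggests that the model instance is stored as a Python pickle (the .nc files are NetCDF-serialized ArviZ InferenceData). The pickle half of that round trip can be sketched with the standard library alone; here a plain dict is a hypothetical stand-in for the fitted model, and the directory layout merely mirrors the one save_model produces.

```python
import pickle
import tempfile
from pathlib import Path

# A plain dict standing in for the fitted model object (hypothetical)
model_stub = {"model": "angle", "formula": "v ~ 1 + C(stim)"}

# Mirror the hssm_models/<model_name>/ layout inside a temp directory
model_dir = Path(tempfile.mkdtemp()) / "hssm_models" / "test_model"
model_dir.mkdir(parents=True)

# Serialize, then deserialize, the stand-in model
with open(model_dir / "model.pkl", "wb") as f:
    pickle.dump(model_stub, f)

with open(model_dir / "model.pkl", "rb") as f:
    restored = pickle.load(f)

print(restored == model_stub)  # True
```

As with any pickle, the saved file is tied to the Python environment that wrote it, so loading works best with matching package versions.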
We can now load the model from the directory we just created, using the HSSM classmethod load_model.
In [5]:
loaded_model = hssm.HSSM.load_model(path="hssm_models/test_model")
Model initialized successfully.
With this simple workflow, your models are portable across sessions and machines.