Attach trial-wise parameters to idata post hoc¶
HSSM automatically cleans up the idata object returned from samplers to avoid passing around unnecessarily large objects (a decision taken after observing many inadvertent out-of-memory errors).
If desired, we can instead attach the trial-wise parameters to the idata object ourselves, using the add_likelihood_parameters_to_idata
function.
This quick tutorial shows you how to do this.
Load Modules¶
import pandas as pd
import hssm
Simulate Data¶
# Condition 1
condition_1 = hssm.simulate_data(
model="ddm", theta=dict(v=0.5, a=1.5, z=0.5, t=0.1), size=500
)
# Condition 2
condition_2 = hssm.simulate_data(
model="ddm", theta=dict(v=1.0, a=1.5, z=0.5, t=0.1), size=500
)
condition_1["condition"] = "C1"
condition_2["condition"] = "C2"
data = pd.concat([condition_1, condition_2]).reset_index(drop=True)
Define Model and Sample¶
model = hssm.HSSM(
model="ddm",
data=data,
include=[
{
"name": "v",
"formula": "v ~ 1 + condition",
}
],
)
idata = model.sample(sampler="mcmc", tune=500, draws=500)
Model initialized successfully. Using default initvals.
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [t, z, a, v_Intercept, v_condition]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 17 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
Checking the idata object¶
idata
-
<xarray.Dataset> Size: 84kB
Dimensions: (chain: 4, draw: 500, v_condition_dim: 1)
Coordinates: chain, draw, v_condition_dim ('C2')
Data variables: a (chain, draw), v_condition (chain, draw, v_condition_dim), t (chain, draw), v_Intercept (chain, draw), z (chain, draw)
-
<xarray.Dataset> Size: 16MB
Dimensions: (chain: 4, draw: 500, __obs__: 1000)
Data variables: rt,response (chain, draw, __obs__)
-
<xarray.Dataset> Size: 264kB
Dimensions: (chain: 4, draw: 500)
Data variables (18 sampler statistics): step_size, n_steps, divergences, reached_max_treedepth, tree_depth, energy_error, step_size_bar, perf_counter_diff, ...
-
<xarray.Dataset> Size: 24kB
Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2)
Data variables: rt,response (__obs__, rt,response_extra_dim_0)
Note that none of the parameters attached to the posterior carry the observation coordinate (__obs__). However, looking at the model graph, we clearly see that for variables represented as regression targets, intermediate deterministics are in fact computed, and these do have dimensionality __obs__.
model.graph()
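Conceptually, the trial-wise deterministic for v follows from the treatment coding implied by `v ~ 1 + condition`, with C1 as the reference level: the intercept alone for C1 trials, and the intercept plus the C2 effect for C2 trials. Here is a minimal numpy sketch of that computation (not HSSM's internal implementation; the coefficient values and trial labels are illustrative, not taken from this run):

```python
import numpy as np

# Sketch of the trial-wise deterministic behind the model graph.
v_intercept = 0.52                 # stands in for one draw of v_Intercept
v_condition_c2 = 0.54              # stands in for one draw of v_condition['C2']

condition = np.array(["C1", "C2", "C1", "C2"])  # hypothetical trial labels
dummy_c2 = (condition == "C2").astype(float)    # design-matrix column for C2

# One v value per trial: intercept for C1, intercept + effect for C2
v_trial = v_intercept + v_condition_c2 * dummy_c2
```

Because one such value exists per trial, the resulting deterministic carries the __obs__ dimension.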
Computing Trial-Wise Parameters¶
We can use the add_likelihood_parameters_to_idata
function to recompute the trial-wise deterministics that are part of our model graph and include them in our idata
object.
Attach Trial-Wise Parameters¶
idata_trialwise = model.add_likelihood_parameters_to_idata(idata)
Checking the idata object¶
idata_trialwise
-
<xarray.Dataset> Size: 16MB
Dimensions: (chain: 4, draw: 500, v_condition_dim: 1, __obs__: 1000)
Coordinates: chain, draw, v_condition_dim ('C2'), __obs__
Data variables: a (chain, draw), v_condition (chain, draw, v_condition_dim), t (chain, draw), v_Intercept (chain, draw), z (chain, draw), v (chain, draw, __obs__)
-
<xarray.Dataset> Size: 16MB
Dimensions: (chain: 4, draw: 500, __obs__: 1000)
Data variables: rt,response (chain, draw, __obs__)
-
<xarray.Dataset> Size: 264kB
Dimensions: (chain: 4, draw: 500)
Data variables (18 sampler statistics): step_size, n_steps, divergences, reached_max_treedepth, tree_depth, energy_error, step_size_bar, perf_counter_diff, ...
-
<xarray.Dataset> Size: 24kB
Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2)
Data variables: rt,response (__obs__, rt,response_extra_dim_0)
Note how the idata now includes the v parameter, which carries the __obs__ coordinate. Computing these trial-wise parameters serves us well for plotting and other post-hoc analyses, where we would otherwise struggle without some error-prone pre-computation. See e.g. the tutorial on hierarchical Variational Inference for an example.
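As a usage sketch for such post-hoc analyses, per-trial posterior summaries reduce over the chain and draw dimensions. Here a random toy array stands in for idata_trialwise.posterior["v"].values (the real posterior would of course not be regenerated like this):

```python
import numpy as np

# Toy stand-in for idata_trialwise.posterior["v"].values,
# with shape (chain: 4, draw: 500, __obs__: 1000)
rng = np.random.default_rng(0)
v = rng.normal(loc=0.75, scale=0.1, size=(4, 500, 1000))

v_mean = v.mean(axis=(0, 1))                             # posterior mean per trial
v_lo, v_hi = np.percentile(v, [2.5, 97.5], axis=(0, 1))  # 95% interval per trial
```

On the real object, the equivalent xarray call is idata_trialwise.posterior["v"].mean(dim=("chain", "draw")).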