Attach trial-wise parameters to idata¶
HSSM automatically trims the idata object returned from samplers, to avoid passing around unnecessarily large objects (a decision taken after observing many inadvertent out-of-memory errors).
If desired, we can instead attach the trial-wise parameters to the idata object ourselves, using the add_likelihood_parameters_to_idata
function.
This quick tutorial shows you how to do this.
Load Modules¶
import pandas as pd
import hssm
Simulate Data¶
# Condition 1
condition_1 = hssm.simulate_data(
    model="ddm", theta=dict(v=0.5, a=1.5, z=0.5, t=0.1), size=500
)
# Condition 2
condition_2 = hssm.simulate_data(
    model="ddm", theta=dict(v=1.0, a=1.5, z=0.5, t=0.1), size=500
)
condition_1["condition"] = "C1"
condition_2["condition"] = "C2"
data = pd.concat([condition_1, condition_2]).reset_index(drop=True)
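As a quick sanity check, the concatenated data should contain 500 rows per condition with a fresh integer index. The sketch below uses plain pandas with placeholder `rt`/`response` columns standing in for the HSSM-simulated datasets:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Stand-ins for the two simulated datasets (placeholder values, not DDM draws)
condition_1 = pd.DataFrame({"rt": rng.uniform(0.2, 2.0, 500), "response": 1.0})
condition_2 = pd.DataFrame({"rt": rng.uniform(0.2, 2.0, 500), "response": 1.0})
condition_1["condition"] = "C1"
condition_2["condition"] = "C2"

# reset_index(drop=True) discards the duplicated 0..499 indices
data = pd.concat([condition_1, condition_2]).reset_index(drop=True)

print(data.shape)                        # (1000, 3)
print(data["condition"].value_counts())  # 500 rows each for C1 and C2
```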
Define Model and Sample¶
model = hssm.HSSM(
    model="ddm",
    data=data,
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + condition",
        }
    ],
)
idata = model.sample(sampler="mcmc", tune=500, draws=500)
Model initialized successfully. Using default initvals.
Initializing NUTS using adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [a, t, z, v_Intercept, v_condition]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 15 seconds.
Checking the idata object¶
idata
-
<xarray.Dataset> Size: 84kB Dimensions: (chain: 4, draw: 500, v_condition_dim: 1) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * v_condition_dim (v_condition_dim) <U2 8B 'C2' Data variables: v_Intercept (chain, draw) float64 16kB 0.6073 0.5713 ... 0.5286 0.5562 z (chain, draw) float64 16kB 0.5004 0.5106 ... 0.5281 0.4792 v_condition (chain, draw, v_condition_dim) float64 16kB 0.3881 ... 0... t (chain, draw) float64 16kB 0.1444 0.1364 ... 0.1289 0.1399 a (chain, draw) float64 16kB 1.478 1.499 ... 1.493 1.441 Attributes: created_at: 2025-02-18T22:11:14.073518+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 15.156437873840332 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 16MB Dimensions: (chain: 4, draw: 500, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 16MB -0.9346 -4.734 ... -0.6063 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 248kB Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 16kB 0.9046 0.9813 ... 0.9289 diverging (chain, draw) bool 2kB False False ... False False energy (chain, draw) float64 16kB 1.646e+03 ... 1.647e+03 energy_error (chain, draw) float64 16kB 0.08286 -0.3905 ... 0.187 index_in_trajectory (chain, draw) int64 16kB 1 -1 2 3 3 ... -3 4 -3 -3 6 largest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan ... ... process_time_diff (chain, draw) float64 16kB 0.007986 ... 0.007488 reached_max_treedepth (chain, draw) bool 2kB False False ... False False smallest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan step_size (chain, draw) float64 16kB 0.4256 0.4256 ... 0.3861 step_size_bar (chain, draw) float64 16kB 0.5357 0.5357 ... 0.4755 tree_depth (chain, draw) int64 16kB 3 3 4 4 4 3 ... 4 3 3 3 3 3 Attributes: created_at: 2025-02-18T22:11:14.084242+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 15.156437873840332 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-02-18T22:11:14.103497+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
Note that none of the parameters attached to the posterior have the observation coordinate (__obs__). However, looking at the model graph, we clearly see that for parameters that are regression targets, we in fact compute intermediate deterministics which do have the __obs__ dimension.
model.graph()
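Conceptually, that deterministic node is just the regression formula evaluated per trial: with treatment coding (C1 as baseline), each observation's drift rate is the intercept plus the condition effect for C2 trials. A minimal NumPy sketch, with hypothetical parameter draws (the names mirror the posterior variables above):

```python
import numpy as np

# Hypothetical values for a single posterior draw
v_Intercept = 0.5
v_condition = 0.4  # effect of condition C2 relative to the C1 baseline

# Dummy-coded condition column for six example trials
condition = np.array(["C1", "C1", "C2", "C2", "C1", "C2"])
dummy = (condition == "C2").astype(float)

# The trial-wise deterministic: one drift rate per observation
v_trialwise = v_Intercept + v_condition * dummy
print(v_trialwise)  # [0.5 0.5 0.9 0.9 0.5 0.9]
```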
Computing Trial-Wise Parameters¶
We can use the add_likelihood_parameters_to_idata
function to recompute the trial-wise deterministics that are part of our model graph and include them in our idata
.
Attach Trial-Wise Parameters¶
idata_trialwise = model.add_likelihood_parameters_to_idata(idata)
Checking the idata object¶
idata_trialwise
-
<xarray.Dataset> Size: 16MB Dimensions: (chain: 4, draw: 500, v_condition_dim: 1, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * v_condition_dim (v_condition_dim) <U2 8B 'C2' * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 ... 994 995 996 997 998 999 Data variables: v_Intercept (chain, draw) float64 16kB 0.6073 0.5713 ... 0.5286 0.5562 z (chain, draw) float64 16kB 0.5004 0.5106 ... 0.5281 0.4792 v_condition (chain, draw, v_condition_dim) float64 16kB 0.3881 ... 0... t (chain, draw) float64 16kB 0.1444 0.1364 ... 0.1289 0.1399 a (chain, draw) float64 16kB 1.478 1.499 ... 1.493 1.441 v (chain, draw, __obs__) float64 16MB 0.6073 0.6073 ... 1.055 Attributes: created_at: 2025-02-18T22:11:14.073518+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 15.156437873840332 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 16MB Dimensions: (chain: 4, draw: 500, __obs__: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 * __obs__ (__obs__) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: rt,response (chain, draw, __obs__) float64 16MB -0.9346 -4.734 ... -0.6063 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 248kB Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499 Data variables: (12/17) acceptance_rate (chain, draw) float64 16kB 0.9046 0.9813 ... 0.9289 diverging (chain, draw) bool 2kB False False ... False False energy (chain, draw) float64 16kB 1.646e+03 ... 1.647e+03 energy_error (chain, draw) float64 16kB 0.08286 -0.3905 ... 0.187 index_in_trajectory (chain, draw) int64 16kB 1 -1 2 3 3 ... -3 4 -3 -3 6 largest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan ... ... process_time_diff (chain, draw) float64 16kB 0.007986 ... 0.007488 reached_max_treedepth (chain, draw) bool 2kB False False ... False False smallest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan step_size (chain, draw) float64 16kB 0.4256 0.4256 ... 0.3861 step_size_bar (chain, draw) float64 16kB 0.5357 0.5357 ... 0.4755 tree_depth (chain, draw) int64 16kB 3 3 4 4 4 3 ... 4 3 3 3 3 3 Attributes: created_at: 2025-02-18T22:11:14.084242+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 sampling_time: 15.156437873840332 tuning_steps: 500 modeling_interface: bambi modeling_interface_version: 0.15.0
-
<xarray.Dataset> Size: 24kB Dimensions: (__obs__: 1000, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 8kB 0 1 2 3 4 ... 996 997 998 999 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float64 16kB ... Attributes: created_at: 2025-02-18T22:11:14.103497+00:00 arviz_version: 0.19.0 inference_library: pymc inference_library_version: 5.19.1 modeling_interface: bambi modeling_interface_version: 0.15.0
Note how the idata now includes the v
parameter, which has the __obs__
coordinate. Computing these trial-wise parameters
serves us well for plotting and other post-hoc analyses, where we would otherwise have to recompute them by hand (which would be error-prone). See e.g. the tutorial on hierarchical Variational Inference for an example.
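For instance, once the trial-wise v is in the posterior, summarizing it by condition is a simple reduction over chains and draws. A sketch with synthetic draws standing in for `idata_trialwise.posterior["v"]` (the array shapes and names mirror the output above; the values are made up):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical trial-wise posterior for v, shaped (chain, draw, __obs__)
chains, draws, n_obs = 4, 500, 1000
condition = np.repeat(["C1", "C2"], n_obs // 2)
true_v = np.where(condition == "C2", 1.0, 0.5)
v = true_v + rng.normal(0.0, 0.05, size=(chains, draws, n_obs))

# Posterior mean of v per trial, then averaged within each condition
v_mean = v.mean(axis=(0, 1))
for c in ("C1", "C2"):
    print(c, round(v_mean[condition == c].mean(), 2))
```

With the real idata object, the same reduction can be done with xarray's `.mean(dim=["chain", "draw"])` on the posterior group.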