Saving and loading models
In this short how-to tutorial, we show how to save an HSSM model instance and its inference results to disk, and then how to re-instantiate the model from the saved files.
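Re-instantiating a model essentially requires two things: the data and the keyword arguments that were passed to the constructor. The sketch below illustrates that idea with plain JSON. It is not HSSM's own saving mechanism; the helper names `save_config`/`load_config` and the file name `model_config.json` are hypothetical, and the data itself would have to be stored separately.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical example: the constructor arguments used in this tutorial,
# without the data (which would be stored separately, e.g. as CSV).
model_kwargs = {
    "model": "angle",
    "process_initvals": True,
    "link_settings": "log_logit",
    "include": [{"name": "v", "formula": "v ~ 1 + C(stim)"}],
}


def save_config(kwargs: dict, path: Path) -> None:
    """Write constructor keyword arguments to a JSON file."""
    path.write_text(json.dumps(kwargs, indent=2))


def load_config(path: Path) -> dict:
    """Read constructor keyword arguments back from a JSON file."""
    return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "model_config.json"  # hypothetical file name
    save_config(model_kwargs, config_path)
    restored = load_config(config_path)

# The restored dict could then be unpacked into the constructor,
# e.g. hssm.HSSM(data=cav_data, **restored).
```

This only works because every argument above is JSON-serializable. Objects such as custom prior specifications or likelihood functions would likely need a different storage format, which is part of what a dedicated save/load facility has to handle.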
Load data and instantiate HSSM model
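```python
import hssm

# Load the example "cavanagh_theta" dataset through HSSM's data loader
cav_data = hssm.load_data("cavanagh_theta")

basic_hssm_model = hssm.HSSM(
    data=cav_data,
    process_initvals=True,
    link_settings="log_logit",
    model="angle",
    # Regress the drift rate v on an intercept plus a categorical effect of stim
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + C(stim)",
        }
    ],
)
```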
Model initialized successfully.
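The `link_settings="log_logit"` argument chooses link functions for the model parameters. Our reading of the HSSM documentation (worth checking against the API reference) is that it applies a log link to positive parameters and a generalized logit link to parameters with explicit lower and upper bounds. The NumPy sketch below shows such a bound-rescaled logit and its inverse; the bounds `(-3, 3)` are illustrative only and not taken from HSSM.

```python
import numpy as np


def scaled_logit(x, lower, upper):
    """Map values in (lower, upper) onto the whole real line."""
    p = (x - lower) / (upper - lower)  # rescale to (0, 1)
    return np.log(p / (1.0 - p))


def scaled_inv_logit(eta, lower, upper):
    """Map real-valued linear predictors back into (lower, upper)."""
    return lower + (upper - lower) / (1.0 + np.exp(-eta))


lower, upper = -3.0, 3.0  # illustrative bounds, not HSSM's actual values
x = np.array([-2.5, 0.0, 1.2])
eta = scaled_logit(x, lower, upper)          # unconstrained scale
x_back = scaled_inv_logit(eta, lower, upper)  # back inside (lower, upper)
```

Because the inverse link always returns values strictly inside the bounds, a regression can operate on an unconstrained linear predictor while the parameter itself stays in its valid range.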
```python
# Short run: 100 tuning steps and 100 draws in each of 2 chains, using NumPyro's NUTS
basic_hssm_model.sample(sampler="nuts_numpyro", tune=100, draws=100, chains=2)
```
Using default initvals.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
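These warnings are expected: 2 chains of 100 draws are too few for reliable R̂ and effective sample size estimates, and the short run here presumably just serves to produce something to save. To make the R̂ check concrete, the sketch below implements the classic split-R̂ statistic in NumPy. ArviZ's default is a rank-normalized variant (described in the paper linked in the warning), so its values will differ somewhat from this simpler version; the function name `split_rhat` is our own.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for an array of shape (chains, draws)."""
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into a first and second half (dropping the middle
    # draw if the count is odd), which doubles the number of chains.
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled variance estimate
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                        # chains agree
stuck = mixed + np.array([0.0, 0.0, 1.0, 1.0])[:, None]   # two chains offset
print(split_rhat(mixed), split_rhat(stuck))
```

For the well-mixed chains the statistic sits very close to 1, while offsetting two of the chains pushes it well above the 1.01 threshold mentioned in the warning.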
Out[2]:
arviz.InferenceData
-
- chain: 2
- draw: 100
- v_C(stim)_dim: 2
- chain(chain)int640 1
array([0, 1])
- draw(draw)int640 1 2 3 4 5 6 ... 94 95 96 97 98 99
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
- v_C(stim)_dim(v_C(stim)_dim)<U2'WL' 'WW'
array(['WL', 'WW'], dtype='<U2')
- t(chain, draw)float640.2711 0.2703 ... 0.2638 0.2892
array([[0.27107433, 0.27029438, 0.2832146 , 0.2901967 , 0.29418918, 0.28743502, 0.27996954, 0.28626217, 0.28420574, 0.30303111, 0.26125505, 0.26017182, 0.28960593, 0.26838368, 0.27773932, 0.29665549, 0.2885844 , 0.28934031, 0.30054777, 0.28753019, 0.28376193, 0.27811781, 0.26809868, 0.27351286, 0.28327434, 0.2755852 , 0.27914245, 0.2664366 , 0.29023629, 0.27812718, 0.27872301, 0.28018269, 0.279741 , 0.2812447 , 0.2668713 , 0.28133441, 0.29327071, 0.2922234 , 0.27536024, 0.27812307, 0.28392304, 0.27945348, 0.27505756, 0.28117276, 0.29305754, 0.27569163, 0.28009211, 0.28663539, 0.27545264, 0.28025096, 0.27549457, 0.26515561, 0.27995804, 0.27094257, 0.27707954, 0.2792755 , 0.27866942, 0.28226554, 0.2896466 , 0.2454507 , 0.28831758, 0.25358915, 0.30412865, 0.29021417, 0.28321212, 0.27679912, 0.27005335, 0.28576492, 0.27835835, 0.28973552, 0.30071773, 0.28164819, 0.26972696, 0.26942888, 0.27865543, 0.28191744, 0.28200271, 0.27821402, 0.28718703, 0.29133902, 0.28137662, 0.27129236, 0.28176172, 0.28001904, 0.29531748, 0.26946591, 0.27228032, 0.28423554, 0.27822863, 0.27578613, 0.26845696, 0.2828685 , 0.27273301, 0.28024895, 0.2743769 , 0.28102404, 0.28602787, 0.27135638, 0.2916716 , 0.28336642], [0.27385423, 0.26145553, 0.27030703, 0.29484523, 0.27979606, 0.26385052, 0.28278135, 0.28036129, 0.28483753, 0.27928952, 0.27870878, 0.278842 , 0.29497034, 0.26642618, 0.28000718, 0.2873669 , 0.28179292, 0.27373098, 0.27069809, 0.28286612, 0.27151674, 0.283891 , 0.28823164, 0.28770927, 0.26778838, 0.29664465, 0.2963411 , 0.27123195, 0.27065687, 0.28563783, 0.26154178, 0.28687948, 0.27522787, 0.27930688, 0.29583159, 0.25996636, 0.27459415, 0.28668246, 0.27080566, 0.27795999, 0.27884568, 0.28213364, 0.28711017, 0.28080601, 0.28612666, 0.27489832, 0.27636647, 0.29316558, 0.28631953, 0.27890935, 0.27162946, 0.28793933, 0.26968605, 0.27569139, 0.28106481, 0.28899423, 0.29393693, 0.28483958, 0.27751075, 0.28435202, 0.27822199, 0.28435569, 0.27938488, 0.28922568, 0.28600568, 0.27350628, 
0.28289071, 0.29567248, 0.27543813, 0.28268814, 0.28659235, 0.27137242, 0.27575287, 0.29430378, 0.26058572, 0.28243789, 0.27372694, 0.29772807, 0.26357498, 0.28969494, 0.28302545, 0.29062946, 0.299872 , 0.28396631, 0.26711829, 0.26073585, 0.29229349, 0.27397549, 0.28019158, 0.28144419, 0.27577008, 0.29895981, 0.2671145 , 0.26860885, 0.28525725, 0.2759356 , 0.2775591 , 0.28914294, 0.26375489, 0.2892129 ]])
- theta(chain, draw)float640.2433 0.2511 ... 0.2527 0.2255
array([[0.24326231, 0.25108288, 0.24520421, 0.22226852, 0.22409109, 0.24102553, 0.23827643, 0.2139873 , 0.21471743, 0.21274798, 0.26078772, 0.26222369, 0.20862929, 0.25866512, 0.24747942, 0.21352012, 0.23757933, 0.23271411, 0.21129133, 0.21324598, 0.22210079, 0.22581467, 0.2401381 , 0.23631462, 0.22712897, 0.22930927, 0.24828476, 0.25696909, 0.21040534, 0.23130551, 0.23920285, 0.23511095, 0.21682949, 0.23354963, 0.23104936, 0.24097189, 0.20921452, 0.22137949, 0.24111417, 0.24067178, 0.22698348, 0.24005611, 0.23898533, 0.21921629, 0.22619821, 0.23050565, 0.23843234, 0.22887811, 0.23112954, 0.23570556, 0.23144632, 0.23262298, 0.2113889 , 0.26356316, 0.25225683, 0.21455276, 0.2486544 , 0.23665616, 0.18989878, 0.27443814, 0.18400961, 0.27442673, 0.19097128, 0.22692891, 0.22156131, 0.23103024, 0.24444693, 0.23024027, 0.22305685, 0.2102469 , 0.21451411, 0.20829359, 0.26251649, 0.26299788, 0.23145355, 0.23078978, 0.23891088, 0.24055281, 0.23063526, 0.2130672 , 0.2178522 , 0.24769354, 0.23418385, 0.22328483, 0.22120176, 0.24040198, 0.24108069, 0.22441637, 0.24207556, 0.23113153, 0.23886472, 0.23299047, 0.23887358, 0.23411533, 0.23964761, 0.24387568, 0.22690838, 0.22792578, 0.21587436, 0.22217075], [0.25083832, 0.24051618, 0.24282962, 0.22800486, 0.21340639, 0.25361613, 0.2186827 , 0.24276368, 0.23036224, 0.2320159 , 0.23305896, 0.25100281, 0.22968466, 0.2364107 , 0.21104515, 0.23604528, 0.24611337, 0.22386109, 0.2381144 , 0.21850567, 0.24523405, 0.23487622, 0.20689155, 0.20395982, 0.23663132, 0.20503275, 0.19745004, 0.2326277 , 0.25170947, 0.2068546 , 0.23830682, 0.23401045, 0.22222293, 0.24606036, 0.207173 , 0.2514214 , 0.23695367, 0.23020819, 0.23888104, 0.22347542, 0.23591359, 0.22951583, 0.22934348, 0.22446312, 0.22410901, 0.24963469, 0.23942655, 0.21075319, 0.23976424, 0.24207611, 0.23289941, 0.23441103, 0.22758847, 0.25593879, 0.21461821, 0.22213855, 0.22124086, 0.21492664, 0.23979015, 0.22472325, 0.2405338 , 0.23378941, 0.22639936, 0.22871079, 0.22321269, 
0.23779912, 0.22275692, 0.23151993, 0.21840806, 0.24998311, 0.23645753, 0.23456594, 0.24455924, 0.21879733, 0.24354819, 0.24448543, 0.22325138, 0.21654661, 0.24614132, 0.21374432, 0.23753798, 0.20230619, 0.20428586, 0.22046114, 0.23676915, 0.23168919, 0.23436279, 0.22197593, 0.21215309, 0.22568696, 0.23314946, 0.22457078, 0.24186484, 0.24022983, 0.23755283, 0.23249121, 0.22738086, 0.21879369, 0.252673 , 0.22548315]])
- v_C(stim)(chain, draw, v_C(stim)_dim)float640.2256 -0.05172 ... 0.2822 0.001279
array([[[ 2.25621624e-01, -5.17180488e-02], [ 2.77226353e-01, -5.66300511e-03], [ 2.80631077e-01, -8.76033221e-03], [ 2.67181089e-01, -7.28834229e-03], [ 2.90328961e-01, 2.14947897e-02], [ 2.68153626e-01, -2.79809255e-02], [ 2.67289268e-01, -3.02241037e-02], [ 2.44819700e-01, -3.91618488e-02], [ 2.43083017e-01, -3.47153787e-02], [ 2.42472096e-01, -3.42822766e-02], [ 2.37460788e-01, -2.80529602e-02], [ 3.29892368e-01, 2.40139294e-02], [ 3.25541890e-01, 3.99577254e-02], [ 3.10731145e-01, 6.13059688e-02], [ 2.66210865e-01, -4.90242746e-02], [ 2.69714562e-01, -4.18405337e-02], [ 2.76959417e-01, -1.49384802e-02], [ 3.02032885e-01, -2.39235013e-02], [ 2.87567541e-01, -1.19668983e-02], [ 2.80171363e-01, -4.10256644e-03], ... [ 3.09097881e-01, 1.35430337e-02], [ 3.11690470e-01, -9.18425762e-03], [ 3.24113534e-01, -3.90581158e-03], [ 3.11766829e-01, -4.77989390e-03], [ 2.23778430e-01, -2.38967181e-02], [ 2.21122437e-01, -2.13574908e-02], [ 2.24842992e-01, -2.33652957e-02], [ 2.38847242e-01, -2.70190978e-02], [ 2.80191290e-01, -1.85032185e-02], [ 2.78506733e-01, 1.38331991e-02], [ 2.83188824e-01, -7.44509541e-03], [ 2.80086615e-01, -3.97430880e-03], [ 2.75295176e-01, -7.15222595e-03], [ 2.97367672e-01, -2.14219806e-02], [ 3.02201094e-01, -1.47887158e-02], [ 2.52715997e-01, 6.96394272e-04], [ 2.48452039e-01, -7.76535408e-02], [ 2.73372337e-01, -6.27180534e-02], [ 2.70914545e-01, -1.39730092e-02], [ 2.82234712e-01, 1.27924595e-03]]])
- v_Intercept(chain, draw)float640.1445 0.115 ... 0.1336 0.1132
array([[0.14445889, 0.11495637, 0.10390673, 0.10645612, 0.10319398, 0.12976907, 0.12887687, 0.13838062, 0.13736157, 0.1378738 , 0.1287564 , 0.0813565 , 0.06852001, 0.09064687, 0.136256 , 0.12328137, 0.11143886, 0.10820369, 0.11458064, 0.10668375, 0.11290865, 0.11545381, 0.10893473, 0.10941432, 0.12873778, 0.12007283, 0.07763399, 0.082075 , 0.09633114, 0.08152645, 0.13122036, 0.13635973, 0.13133235, 0.13397593, 0.13176219, 0.12910837, 0.12222339, 0.12553843, 0.12181786, 0.12251668, 0.11937316, 0.15172698, 0.13880723, 0.12965287, 0.13321058, 0.10918676, 0.10644825, 0.09998687, 0.13053876, 0.13679364, 0.09338833, 0.0788464 , 0.07537973, 0.07791955, 0.07541671, 0.12194531, 0.09957608, 0.1100147 , 0.10715917, 0.14345168, 0.1457646 , 0.14614234, 0.13755922, 0.14169752, 0.10244016, 0.15717165, 0.16069633, 0.15157075, 0.14176241, 0.14228769, 0.09762434, 0.09640278, 0.12453993, 0.13166668, 0.08526617, 0.08977704, 0.08839306, 0.0850433 , 0.0950155 , 0.10228936, 0.09089444, 0.08213184, 0.10906814, 0.07478675, 0.12502478, 0.10960928, 0.09187282, 0.09667762, 0.13845287, 0.1083912 , 0.09512186, 0.10881827, 0.12718301, 0.12477 , 0.12111805, 0.10578294, 0.13578205, 0.15495656, 0.12421762, 0.08442427], [0.14291239, 0.14199344, 0.16571233, 0.12632287, 0.11194768, 0.14044154, 0.12246514, 0.11084241, 0.15832311, 0.11761963, 0.14084894, 0.13976803, 0.09506751, 0.12204623, 0.11717043, 0.11770169, 0.1283094 , 0.12199761, 0.13445646, 0.08646085, 0.13614914, 0.10949302, 0.13730921, 0.13997038, 0.15269386, 0.14947184, 0.12727879, 0.14302009, 0.13658027, 0.15809282, 0.09333731, 0.08397798, 0.10534866, 0.13528635, 0.14889791, 0.14647377, 0.1452529 , 0.09218862, 0.11259906, 0.12892639, 0.10298518, 0.13104238, 0.13284734, 0.11092865, 0.11664429, 0.09715926, 0.08318035, 0.11955805, 0.12223321, 0.07710802, 0.09407859, 0.08444269, 0.11851171, 0.10920673, 0.12743697, 0.12642685, 0.16179613, 0.16242633, 0.14191151, 0.14521505, 0.10722575, 0.14009479, 0.13887257, 0.13928871, 0.13342271, 0.12281351, 
0.09605767, 0.08999236, 0.10469688, 0.12027525, 0.12093373, 0.0929129 , 0.09036851, 0.10332017, 0.09690416, 0.12694735, 0.1429124 , 0.14029999, 0.10407244, 0.12107459, 0.08369347, 0.09902108, 0.09148154, 0.10233275, 0.1489298 , 0.1530843 , 0.15014374, 0.13339342, 0.12130766, 0.13637212, 0.09791948, 0.08940045, 0.12617092, 0.11308577, 0.10523354, 0.11775793, 0.18327342, 0.13677999, 0.13356756, 0.11317643]])
- z(chain, draw)float640.5107 0.4994 ... 0.5009 0.4965
array([[0.51065735, 0.4993735 , 0.50102366, 0.51150646, 0.50021651, 0.50190475, 0.50230881, 0.50559029, 0.5112306 , 0.51121477, 0.50573208, 0.50336444, 0.49650004, 0.50701308, 0.51145828, 0.49659364, 0.51109127, 0.50501488, 0.50914701, 0.50159071, 0.50269476, 0.50251191, 0.49946842, 0.49988202, 0.50576136, 0.49133072, 0.50655277, 0.50989417, 0.50308462, 0.5083536 , 0.5015402 , 0.50434583, 0.50401537, 0.50080944, 0.4950588 , 0.50829142, 0.5065353 , 0.5070418 , 0.49631281, 0.4952776 , 0.49619004, 0.5026105 , 0.50069855, 0.49894737, 0.50179394, 0.50110943, 0.50548398, 0.50350737, 0.50770273, 0.49916883, 0.50896302, 0.50202419, 0.50228245, 0.50308678, 0.50061701, 0.50729071, 0.50459195, 0.4955013 , 0.49447706, 0.50393839, 0.49987632, 0.50195598, 0.50056829, 0.50152323, 0.50514492, 0.49940621, 0.50074894, 0.49868778, 0.50861785, 0.50382654, 0.50870257, 0.50391551, 0.50091736, 0.49747485, 0.50784746, 0.51624101, 0.51272022, 0.51389624, 0.50446087, 0.50362234, 0.5052527 , 0.50116323, 0.49849024, 0.49832123, 0.49748354, 0.4918572 , 0.50578943, 0.49387511, 0.50031346, 0.4974091 , 0.50966194, 0.49628973, 0.50384248, 0.49740569, 0.50589278, 0.50496312, 0.49955055, 0.50382603, 0.48953624, 0.51068176], [0.50472822, 0.49743556, 0.50435059, 0.50593525, 0.5077958 , 0.50383477, 0.50550959, 0.50501108, 0.49367094, 0.50663071, 0.49530401, 0.49454011, 0.51154719, 0.50596864, 0.50347164, 0.50126582, 0.50161416, 0.49474621, 0.50591167, 0.51062761, 0.50619918, 0.5019218 , 0.49277697, 0.49245733, 0.48998766, 0.49752168, 0.49850105, 0.5002423 , 0.49807924, 0.48856586, 0.5011659 , 0.50010295, 0.505047 , 0.4887176 , 0.48942586, 0.4879506 , 0.4896561 , 0.50383715, 0.50281374, 0.49595321, 0.50363288, 0.50642851, 0.50646652, 0.50641454, 0.4983593 , 0.51003936, 0.50774964, 0.507557 , 0.50708948, 0.50332925, 0.50719276, 0.50924078, 0.4965204 , 0.50020322, 0.50365238, 0.5043114 , 0.51000765, 0.50585036, 0.49988897, 0.50311012, 0.50410023, 0.50350011, 0.50065713, 0.50356668, 0.50348161, 0.49794366, 
0.50884017, 0.50599871, 0.50815505, 0.49832054, 0.50557891, 0.49950395, 0.50085926, 0.50661819, 0.49908784, 0.50110731, 0.49472319, 0.49701242, 0.5027631 , 0.49271289, 0.5059427 , 0.50283134, 0.50678268, 0.50439838, 0.50033102, 0.5013627 , 0.50454771, 0.50256232, 0.49979031, 0.50232137, 0.51512999, 0.51709763, 0.50067297, 0.50474126, 0.50329987, 0.50186357, 0.4913956 , 0.4955777 , 0.50086147, 0.49652004]])
- a(chain, draw)float641.34 1.358 1.332 ... 1.361 1.302
array([[1.34020558, 1.3578405 , 1.33174794, 1.30921126, 1.29016221, 1.31554376, 1.33749882, 1.29898855, 1.29205485, 1.29853132, 1.36241376, 1.37204679, 1.29346221, 1.35763441, 1.35855461, 1.28917712, 1.31345467, 1.31079681, 1.27818375, 1.29826756, 1.31181578, 1.31293082, 1.32858935, 1.33211557, 1.31954561, 1.32629464, 1.34082685, 1.36857799, 1.29265281, 1.32511263, 1.32809752, 1.32188858, 1.31118073, 1.3310306 , 1.33209448, 1.33157276, 1.28793252, 1.30389595, 1.33943003, 1.33368758, 1.30773812, 1.32894184, 1.33802936, 1.31284697, 1.29952435, 1.32529523, 1.34186526, 1.31215315, 1.32375454, 1.3136819 , 1.33239339, 1.34217624, 1.29749341, 1.36157549, 1.34030452, 1.31669137, 1.33409432, 1.3228084 , 1.27856956, 1.41532111, 1.27027738, 1.39885315, 1.25953528, 1.30819959, 1.30764156, 1.33030521, 1.35278741, 1.30594649, 1.32336035, 1.29512239, 1.2848085 , 1.29075963, 1.3666829 , 1.3631942 , 1.32615375, 1.33055398, 1.32967378, 1.33848426, 1.31784465, 1.28957518, 1.30642639, 1.35431796, 1.32076859, 1.31971671, 1.295198 , 1.33702673, 1.33556214, 1.31634636, 1.33289613, 1.30773235, 1.33226868, 1.33019979, 1.33604558, 1.32309552, 1.33353876, 1.33928614, 1.31465945, 1.31855593, 1.29889572, 1.31296836], [1.34305363, 1.34169396, 1.35024912, 1.29764871, 1.30180641, 1.36745314, 1.30225779, 1.329338 , 1.31919497, 1.31894255, 1.33640241, 1.33709815, 1.30838295, 1.33653608, 1.29035083, 1.29921625, 1.34740091, 1.32357704, 1.35270266, 1.30378155, 1.35033058, 1.33653647, 1.28288725, 1.28150751, 1.33626383, 1.28090419, 1.2731372 , 1.33513503, 1.35618979, 1.29361696, 1.34146123, 1.31665137, 1.31825678, 1.33281703, 1.27719811, 1.35397198, 1.3263708 , 1.31514612, 1.33744946, 1.30795035, 1.34118151, 1.31377068, 1.33073071, 1.31866493, 1.29910158, 1.34680737, 1.34033487, 1.29174211, 1.33445132, 1.32909289, 1.33928188, 1.32641507, 1.33716359, 1.3540698 , 1.30265748, 1.30763608, 1.28880116, 1.30218966, 1.32586582, 1.30653833, 1.33304809, 1.33142593, 1.33238903, 1.30066092, 1.30085109, 1.32628459, 
1.31083683, 1.29346632, 1.30748869, 1.3396784 , 1.31911918, 1.32489296, 1.3358631 , 1.29607741, 1.35451063, 1.32626218, 1.32024982, 1.30474344, 1.36875277, 1.2966414 , 1.31947489, 1.29109576, 1.27123677, 1.3042758 , 1.33778923, 1.33658104, 1.30819315, 1.30544666, 1.30732805, 1.31357099, 1.33524215, 1.30089052, 1.3415367 , 1.35300139, 1.33055653, 1.33274716, 1.31769501, 1.29857765, 1.36053203, 1.30210376]])
- chainPandasIndex
PandasIndex(Index([0, 1], dtype='int64', name='chain'))
- drawPandasIndex
PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype='int64', name='draw'))
- v_C(stim)_dimPandasIndex
PandasIndex(Index(['WL', 'WW'], dtype='object', name='v_C(stim)_dim'))
- created_at :
- 2025-07-13T13:18:09.812288+00:00
- arviz_version :
- 0.19.0
- inference_library :
- numpyro
- inference_library_version :
- 0.16.1
- sampling_time :
- 83.794966
- tuning_steps :
- 100
- modeling_interface :
- bambi
- modeling_interface_version :
- 0.15.0
<xarray.Dataset> Size: 12kB Dimensions: (chain: 2, draw: 100, v_C(stim)_dim: 2) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 800B 0 1 2 3 4 5 6 7 ... 92 93 94 95 96 97 98 99 * v_C(stim)_dim (v_C(stim)_dim) <U2 16B 'WL' 'WW' Data variables: t (chain, draw) float64 2kB 0.2711 0.2703 ... 0.2638 0.2892 theta (chain, draw) float64 2kB 0.2433 0.2511 ... 0.2527 0.2255 v_C(stim) (chain, draw, v_C(stim)_dim) float64 3kB 0.2256 ... 0.001279 v_Intercept (chain, draw) float64 2kB 0.1445 0.115 ... 0.1336 0.1132 z (chain, draw) float64 2kB 0.5107 0.4994 ... 0.5009 0.4965 a (chain, draw) float64 2kB 1.34 1.358 1.332 ... 1.361 1.302 Attributes: created_at: 2025-07-13T13:18:09.812288+00:00 arviz_version: 0.19.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 83.794966 tuning_steps: 100 modeling_interface: bambi modeling_interface_version: 0.15.0
xarray.Dataset -
- chain: 2
- draw: 100
- __obs__: 3988
- chain(chain)int640 1
array([0, 1])
- draw(draw)int640 1 2 3 4 5 6 ... 94 95 96 97 98 99
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
- __obs__(__obs__)int640 1 2 3 4 ... 3984 3985 3986 3987
array([ 0, 1, 2, ..., 3985, 3986, 3987])
- rt,response(chain, draw, __obs__)float64-1.004 -1.192 ... -2.48 -1.082
array([[[-1.00383554, -1.19237947, -0.97109316, ..., -0.49733245, -2.51344968, -1.10439779], [-1.04921556, -1.15770605, -0.96816302, ..., -0.51056633, -2.42796943, -1.08421762], [-1.05462605, -1.17061534, -0.97124517, ..., -0.49728737, -2.43903022, -1.09556964], ..., [-1.00532895, -1.1965098 , -0.90624487, ..., -0.48583125, -2.55267388, -1.04807329], [-1.05124196, -1.1692354 , -0.95839909, ..., -0.49818311, -2.4874525 , -1.08425979], [-1.09353168, -1.20491518, -0.97937904, ..., -0.50161596, -2.41720084, -1.11931524]], [[-1.00092735, -1.18192965, -0.99418865, ..., -0.51464644, -2.50740027, -1.11852317], [-1.02580215, -1.18885879, -1.02455323, ..., -0.54277579, -2.50579358, -1.14457865], [-0.98213414, -1.16868668, -0.99081498, ..., -0.49320113, -2.53095569, -1.11263361], ..., [-1.02751257, -1.17754695, -1.01911607, ..., -0.46967638, -2.52054876, -1.14740975], [-1.02496471, -1.16046655, -0.95203682, ..., -0.48938838, -2.46735067, -1.07205292], [-1.05403939, -1.18073235, -0.94929562, ..., -0.48380274, -2.4800749 , -1.08206233]]])
- chainPandasIndex
PandasIndex(Index([0, 1], dtype='int64', name='chain'))
- drawPandasIndex
PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype='int64', name='draw'))
- __obs__PandasIndex
PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 3978, 3979, 3980, 3981, 3982, 3983, 3984, 3985, 3986, 3987], dtype='int64', name='__obs__', length=3988))
- modeling_interface :
- bambi
- modeling_interface_version :
- 0.15.0
<xarray.Dataset> Size: 6MB Dimensions: (chain: 2, draw: 100, __obs__: 3988) Coordinates: * chain (chain) int64 16B 0 1 * draw (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 * __obs__ (__obs__) int64 32kB 0 1 2 3 4 5 ... 3983 3984 3985 3986 3987 Data variables: rt,response (chain, draw, __obs__) float64 6MB -1.004 -1.192 ... -1.082 Attributes: modeling_interface: bambi modeling_interface_version: 0.15.0
xarray.Dataset -
- chain: 2
- draw: 100
- chain(chain)int640 1
array([0, 1])
- draw(draw)int640 1 2 3 4 5 6 ... 94 95 96 97 98 99
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
- acceptance_rate(chain, draw)float640.9996 0.998 ... 0.9305 0.9841
array([[0.99959147, 0.99795927, 0.98128387, 0.99990002, 0.89772796, 0.99820462, 0.69096141, 0.98252748, 0.82411439, 0.84435923, 1. , 0.99496124, 0.99713283, 0.99883306, 0.95808028, 0.86545516, 0.99273571, 0.99588769, 0.99248056, 0.99954616, 0.88297682, 0.87376695, 0.87125679, 0.98185738, 0.99151073, 0.97886829, 0.99972904, 0.99490728, 0.98300868, 0.99268945, 0.97974793, 0.94736616, 0.9820964 , 0.97874003, 0.98696342, 0.6718646 , 0.82681039, 0.96042944, 1. , 0.99916626, 0.96916378, 0.95031962, 0.9767269 , 0.99706192, 0.95787124, 0.98757417, 0.80047925, 0.99950497, 0.95937149, 0.92286035, 0.89578117, 0.95451686, 0.99023633, 0.99912501, 0.9958187 , 0.99041939, 0.99745063, 0.99066778, 0.95486248, 0.99985097, 0.97104634, 0.99646582, 0.96656612, 0.99951417, 1. , 0.99840882, 0.9858132 , 0.99407668, 0.99830878, 0.87461291, 0.90403639, 0.55900948, 0.99932487, 0.98088514, 0.99155187, 0.98262407, 0.98775487, 0.9710508 , 0.86503979, 0.99807946, 0.85914877, 0.95289078, 0.97450287, 0.99798464, 0.9748548 , 0.90479035, 0.87323156, 0.99847522, 0.85307651, 0.665544 , 0.9908473 , 0.9900854 , 0.99277815, 0.98314964, 0.79287441, 0.91812067, 0.90782326, 0.83585227, 0.99807589, 0.99192691], [0.94613476, 0.84364401, 0.98449648, 0.88101649, 0.92174276, 0.97481159, 0.93914003, 0.99755211, 0.92536462, 0.87938716, 0.53937867, 0.93100132, 0.99755173, 0.99487279, 0.62409033, 0.84006324, 0.99791154, 0.9971409 , 0.83685624, 0.99976752, 0.99934781, 0.72477762, 0.76509426, 0.7866193 , 0.97696516, 0.9599949 , 0.96704023, 0.3276338 , 0.98513049, 0.98718041, 0.81764468, 0.97982925, 0.97547751, 0.98366708, 0.99867072, 0.92089713, 0.87504963, 0.85477488, 0.6405157 , 0.9665819 , 0.76501764, 0.99991442, 0.53255452, 0.99715789, 0.86139972, 0.91035728, 0.95838916, 0.98185482, 0.84757498, 0.99694334, 0.96593018, 0.79849835, 0.83519781, 0.99920369, 0.99416523, 0.96792893, 0.76235043, 0.92844523, 0.86907129, 0.98368454, 0.98959092, 0.52130253, 0.29058248, 0.97918514, 0.75978766, 0.93505179, 0.84131432, 
0.68285193, 1. , 0.97858671, 1. , 0.96114213, 0.9457233 , 0.98730004, 0.92809701, 0.9879835 , 0.9342146 , 0.63490146, 1. , 0.99550177, 0.98318352, 0.84240336, 0.99407547, 0.89670003, 0.87319545, 0.74596724, 0.87654001, 0.86590994, 0.98441648, 0.97419152, 0.67424353, 0.97801081, 0.71494229, 0.96177558, 0.99950984, 0.99733621, 0.69160179, 0.999468 , 0.93050476, 0.98408588]])
- diverging(chain, draw)boolFalse False False ... False False
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
- energy(chain, draw)float645.935e+03 5.936e+03 ... 5.935e+03
array([[5935.25856435, 5935.92292166, 5935.69835759, 5935.9691483 , 5937.59212981, 5936.04822511, 5937.89087963, 5936.3201261 , 5937.16824666, 5938.52058111, 5938.78767679, 5938.28886335, 5940.57457151, 5942.35346951, 5942.49231745, 5940.18725553, 5939.55718039, 5936.51390978, 5937.83654695, 5939.78066127, 5934.25755985, 5933.25360226, 5936.6656297 , 5934.2172741 , 5933.35054323, 5934.83524693, 5939.40090939, 5938.07391918, 5935.88925559, 5938.33750835, 5935.29851102, 5934.76948955, 5940.08222261, 5935.1381491 , 5933.92982423, 5941.69823863, 5937.40761084, 5935.51500364, 5933.86448202, 5935.97107339, 5934.74884566, 5937.70361736, 5934.05574828, 5936.53777065, 5935.23510307, 5936.28630986, 5939.27457279, 5933.49529997, 5934.60556921, 5937.1229986 , 5936.18773966, 5941.43288047, 5939.15250266, 5938.73800861, 5938.33214627, 5939.81413655, 5938.91112096, 5935.79746029, 5941.78946937, 5943.52276352, 5943.86125451, 5942.36283444, 5941.81287125, 5938.85747673, 5933.59248448, 5936.59608654, 5935.1428543 , 5935.08500416, 5935.38143353, 5936.90175278, 5936.02421442, 5945.159906 , 5936.75088162, 5938.41598605, 5937.30254143, 5936.33452302, 5935.12561544, 5935.75037464, 5935.1110861 , 5934.99494974, ... 
5936.86485103, 5939.95103174, 5944.29082844, 5937.01764774, 5937.86714691, 5938.4381522 , 5937.33863128, 5941.92121956, 5934.6381705 , 5937.72634813, 5941.5567955 , 5937.10062559, 5936.99677644, 5938.57974465, 5938.82414438, 5939.96748402, 5937.41734836, 5936.96286775, 5937.90349674, 5935.30922818, 5936.83292163, 5934.44285533, 5936.46762981, 5935.88983108, 5934.48602826, 5934.96882439, 5935.28822373, 5934.90140244, 5936.18055728, 5936.67048695, 5939.11585784, 5937.42932503, 5936.99673371, 5935.37437307, 5935.87642071, 5934.53097601, 5938.56016261, 5938.66719555, 5936.33169249, 5934.99381638, 5933.67253135, 5937.00171853, 5942.13779525, 5935.74020637, 5935.36517549, 5935.30469823, 5936.58449994, 5941.39906058, 5940.01262821, 5937.27100134, 5936.39671027, 5936.9934274 , 5936.07623062, 5936.22709849, 5937.44773906, 5935.96771456, 5935.45142079, 5938.38691703, 5938.73058747, 5939.5354658 , 5941.54716438, 5937.59784285, 5939.54970971, 5939.11411889, 5937.71481976, 5938.04114627, 5941.33132692, 5937.34584464, 5936.07411087, 5937.83907416, 5940.45865323, 5936.61408557, 5941.02833152, 5938.87233773, 5934.11984983, 5935.21491063, 5940.22071295, 5938.9656867 , 5936.25998233, 5935.35757334]])
- lp(chain, draw)float645.934e+03 5.933e+03 ... 5.933e+03
array([[5934.21720784, 5932.56866252, 5933.19084989, 5932.73010385, 5934.36734052, 5933.35219194, 5931.7007748 , 5932.56503329, 5935.40733521, 5937.82588707, 5934.97758222, 5936.54928766, 5938.86255999, 5938.60342444, 5936.09173025, 5935.01095516, 5933.50897914, 5933.25199782, 5934.85380119, 5932.54828456, 5931.50175048, 5931.4987701 , 5933.61106728, 5931.53899029, 5932.02722549, 5933.58336084, 5934.08876634, 5935.24387384, 5933.83133976, 5933.67950273, 5931.77860378, 5933.16716452, 5933.82935887, 5932.22308641, 5933.08588937, 5932.68883376, 5932.86272656, 5932.37262959, 5932.83987668, 5933.73974507, 5933.14164798, 5933.17426443, 5932.08365296, 5933.44380276, 5933.56103747, 5932.63124785, 5932.75829711, 5931.9918964 , 5932.57630268, 5932.95635275, 5932.92379723, 5935.88097608, 5936.45642635, 5936.37482983, 5936.44037966, 5937.62882319, 5933.27654077, 5933.52694306, 5939.21768168, 5939.98664159, 5940.05557132, 5937.16011852, 5936.43461823, 5932.86015069, 5931.71317629, 5932.90943849, 5934.16062273, 5932.83567364, 5933.72623112, 5933.19630714, 5934.33142618, 5935.22460927, 5934.21266953, 5935.13883569, 5932.59732408, 5934.63060388, 5932.85293 , 5933.32158322, 5932.68027631, 5933.05286418, ... 
5934.17862913, 5937.81236487, 5935.44912876, 5936.07510774, 5933.80305903, 5934.9140576 , 5934.62144065, 5932.45285181, 5932.94815782, 5935.92308498, 5934.65460408, 5934.11349762, 5934.09593622, 5935.52897317, 5936.70205241, 5936.31130764, 5933.73100381, 5933.03076526, 5932.6702857 , 5933.0492608 , 5933.37827411, 5932.09611455, 5934.86648151, 5931.69400197, 5932.48688487, 5932.71699594, 5933.08264484, 5932.93063218, 5932.87569295, 5933.87180567, 5934.55619992, 5934.09728661, 5933.89411944, 5934.05905446, 5932.5246454 , 5932.00427438, 5936.8567976 , 5934.62286435, 5933.53668737, 5932.93044981, 5932.21927579, 5933.33302547, 5934.62271319, 5932.96797382, 5931.76219991, 5933.02710753, 5934.08124053, 5938.75189757, 5936.03655369, 5935.63510183, 5933.11047608, 5935.0654157 , 5934.21530096, 5933.90136117, 5934.74321868, 5932.76904494, 5932.90407935, 5937.0093989 , 5935.0343393 , 5935.88392571, 5933.12540712, 5935.80179671, 5936.85164572, 5933.25740237, 5934.18923798, 5937.28848118, 5935.39256079, 5934.24260432, 5933.254081 , 5935.07394892, 5934.04653936, 5935.42368502, 5932.28989553, 5933.54556031, 5932.85648178, 5932.82490775, 5936.60518201, 5934.19285865, 5933.03086192, 5932.81506027]])
- n_steps(chain, draw)int6415 191 127 127 255 ... 31 31 31 15
sample_stats:
<xarray.Dataset> Size: 11kB
Dimensions:          (chain: 2, draw: 100)
Coordinates:
  * chain            (chain) int64 16B 0 1
  * draw             (draw) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
Data variables:
    acceptance_rate  (chain, draw) float64 2kB 0.9996 0.998 ... 0.9305 0.9841
    diverging        (chain, draw) bool 200B False False False ... False False
    energy           (chain, draw) float64 2kB 5.935e+03 5.936e+03 ... 5.935e+03
    lp               (chain, draw) float64 2kB 5.934e+03 5.933e+03 ... 5.933e+03
    n_steps          (chain, draw) int64 2kB 15 191 127 127 255 ... 31 31 31 15
    step_size        (chain, draw) float64 2kB 0.02087 0.02087 ... 0.1425 0.1425
    tree_depth       (chain, draw) int64 2kB 4 8 7 7 8 7 6 7 ... 5 4 5 5 5 5 5 4
Attributes:
    created_at:                  2025-07-13T13:18:09.820843+00:00
    arviz_version:               0.19.0
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
observed_data:
<xarray.Dataset> Size: 96kB
Dimensions:                  (__obs__: 3988, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__                  (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response              (__obs__, rt,response_extra_dim_0) float64 64kB ...
Attributes:
    created_at:                  2025-07-13T13:18:09.821885+00:00
    arviz_version:               0.19.0
    inference_library:           numpyro
    inference_library_version:   0.16.1
    sampling_time:               83.794966
    tuning_steps:                100
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
Variational inference (VI)
In [3]:
basic_hssm_model.vi(method="advi", niter=5000)
Using MCMC starting point defaults.
Finished [100%]: Average Loss = 7,657.4
Out[3]:
arviz.InferenceData
posterior:
<xarray.Dataset> Size: 64kB
Dimensions:        (chain: 1, draw: 1000, v_C(stim)_dim: 2)
Coordinates:
  * chain          (chain) int64 8B 0
  * draw           (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
  * v_C(stim)_dim  (v_C(stim)_dim) <U2 16B 'WL' 'WW'
Data variables:
    t              (chain, draw) float64 8kB 0.03159 0.01737 ... 0.04878 0.06127
    theta          (chain, draw) float64 8kB 0.4613 0.321 ... 0.5744 0.4231
    v_C(stim)      (chain, draw, v_C(stim)_dim) float64 16kB 0.4996 ... -0.6441
    v_Intercept    (chain, draw) float64 8kB -0.1138 0.4594 ... 0.5505 0.57
    z              (chain, draw) float64 8kB 0.5089 0.5361 ... 0.6234 0.4803
    a              (chain, draw) float64 8kB 1.958 1.357 2.204 ... 1.752 1.873
Attributes:
    created_at:                 2025-07-13T13:19:12.711403+00:00
    arviz_version:              0.19.0
    inference_library:          pymc
    inference_library_version:  5.19.1
observed_data:
<xarray.Dataset> Size: 96kB
Dimensions:                  (__obs__: 3988, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__                  (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response              (__obs__, rt,response_extra_dim_0) float64 64kB ...
Attributes:
    created_at:                 2025-07-13T13:19:12.727276+00:00
    arviz_version:              0.19.0
    inference_library:          pymc
    inference_library_version:  5.19.1
Saving and loading the model
In [4]:
basic_hssm_model.save_model(model_name="test_model")
Here we rely on the default settings, which save the model and its inference results to the hssm_models/test_model/ directory (named after model_name) inside your current working directory.
Up to three files are saved in the model directory:
- model.pkl: the model instance.
- traces.nc: the MCMC traces.
- vi_traces.nc: the VI traces.
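To check what was actually written to disk, you can inspect the model directory from Python. The following is a minimal sketch using only the standard library. It assumes the default hssm_models/<model_name>/ layout and the three file names listed above; saved_artifacts is a hypothetical helper, not part of the HSSM API.

```python
from __future__ import annotations

from pathlib import Path

# File names as listed above. Which ones exist depends on which
# inference methods were run before save_model was called.
EXPECTED_FILES = {
    "model.pkl": "model instance",
    "traces.nc": "MCMC traces",
    "vi_traces.nc": "VI traces",
}


def saved_artifacts(model_name: str, base_dir: str | Path = "hssm_models") -> dict[str, bool]:
    """Report which expected files exist for a saved model (hypothetical helper).

    Assumes the default layout <base_dir>/<model_name>/, resolved relative
    to the current working directory when base_dir is relative.
    """
    model_dir = Path(base_dir) / model_name
    return {fname: (model_dir / fname).is_file() for fname in EXPECTED_FILES}


if __name__ == "__main__":
    # Print one status line per expected file for the model saved above.
    for fname, present in saved_artifacts("test_model").items():
        status = "found" if present else "missing"
        print(f"{fname:<14} {EXPECTED_FILES[fname]:<15} {status}")
```

Run it from the same working directory you called save_model from, since the default save path is relative to it.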
We can now load the model from the directory we just created using the HSSM class method load_model.
In [5]:
loaded_model = hssm.HSSM.load_model(path="hssm_models/test_model")
Model initialized successfully.
With this workflow, a model and its inference results can be carried across sessions and machines: keep (or copy) the model directory and pass its path to hssm.HSSM.load_model.
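Because everything lives in a single directory, one way to move a saved model to another machine is to zip that directory and extract it on the other side. The sketch below uses only Python's shutil module; pack_model_dir and unpack_model_bundle are hypothetical helpers, not part of HSSM. It assumes the default directory layout described above and that the target machine has a compatible HSSM installation, so that model.pkl can be unpickled there.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def pack_model_dir(model_dir: str | Path, archive_base: str | Path) -> Path:
    """Zip a saved model directory into <archive_base>.zip (hypothetical helper)."""
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        raise FileNotFoundError(f"No saved model directory at {model_dir}")
    # root_dir/base_dir make the model folder (e.g. "test_model") the
    # top-level entry inside the archive.
    archive = shutil.make_archive(
        str(archive_base),
        "zip",
        root_dir=str(model_dir.parent),
        base_dir=model_dir.name,
    )
    return Path(archive)


def unpack_model_bundle(archive: str | Path, dest: str | Path) -> Path:
    """Extract an archive made by pack_model_dir under dest (hypothetical helper).

    The restored model directory is dest/<model_name>.
    """
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    shutil.unpack_archive(str(archive), str(dest), format="zip")
    return dest
```

For example, pack_model_dir("hssm_models/test_model", "test_model_bundle") writes test_model_bundle.zip. On the other machine, unpack_model_bundle("test_model_bundle.zip", "hssm_models") restores hssm_models/test_model, which you can then load with hssm.HSSM.load_model(path="hssm_models/test_model").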