Saving and loading models
In this short how-to tutorial, we show how to save an HSSM model instance and its inference results to disk, and then re-instantiate the model from the saved files.
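An HSSM model is largely determined by its data plus the keyword arguments passed to hssm.HSSM, so one library-agnostic way to make it re-creatable is to store those arguments next to the data. The sketch below assumes the arguments are JSON-serializable, which holds for the ones used in this tutorial. save_model_spec and load_model_spec are hypothetical helpers, not part of HSSM's API, and the sampled InferenceData would still have to be saved separately (for example with ArviZ's to_netcdf).

```python
import json
import tempfile
from pathlib import Path


# Hypothetical helpers (not part of HSSM): persist the constructor keyword
# arguments so the model can later be rebuilt with hssm.HSSM(data=..., **spec).
def save_model_spec(spec: dict, path: Path) -> None:
    """Write the model specification to a JSON file."""
    path.write_text(json.dumps(spec, indent=2))


def load_model_spec(path: Path) -> dict:
    """Read a specification previously written by save_model_spec."""
    return json.loads(path.read_text())


# The keyword arguments used to build basic_hssm_model in this tutorial.
spec = {
    "model": "angle",
    "process_initvals": True,
    "link_settings": "log_logit",
    "include": [{"name": "v", "formula": "v ~ 1 + C(stim)"}],
}

with tempfile.TemporaryDirectory() as tmp:
    spec_path = Path(tmp) / "model_spec.json"
    save_model_spec(spec, spec_path)
    restored = load_model_spec(spec_path)

print(restored == spec)  # True: the JSON round trip preserves the specification
# With HSSM installed, the model could then be re-created as:
# model = hssm.HSSM(data=hssm.load_data("cavanagh_theta"), **restored)
```

The data itself is not stored in the JSON file; here it is simply reloaded with hssm.load_data, while custom data would need to be saved on its own (for example as CSV).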
Load data and instantiate an HSSM model
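import hssm

# Load the cavanagh_theta example dataset that ships with HSSM.
cav_data = hssm.load_data("cavanagh_theta")

# Build an "angle" model in which the drift rate v varies with stimulus type.
basic_hssm_model = hssm.HSSM(
    data=cav_data,
    process_initvals=True,
    link_settings="log_logit",
    model="angle",
    include=[
        {
            "name": "v",
            # Intercept plus one offset per non-reference level of stim.
            "formula": "v ~ 1 + C(stim)",
        }
    ],
)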
Model initialized successfully.
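The formula v ~ 1 + C(stim) expands stim with treatment (dummy) coding: one level acts as the reference absorbed into the intercept, and every other level gets its own offset. In the posterior shown later the offsets are labeled 'WL' and 'WW', which indicates that the remaining level of the Cavanagh data, 'LL', is the reference. The pure-Python sketch below mimics that coding to show how the linear predictor for v is assembled per trial. The helper names and coefficient values are illustrative, and if link_settings places a non-identity link on v, the parameter itself is the inverse link of this predictor.

```python
def treatment_code(values, reference, name="stim"):
    """Return (column names, design rows) for `1 + C(name)` with a given reference level."""
    levels = sorted(set(values))
    non_ref = [lvl for lvl in levels if lvl != reference]
    columns = ["Intercept"] + [f"C({name})[{lvl}]" for lvl in non_ref]
    # Each row: 1 for the intercept, then a 0/1 indicator per non-reference level.
    rows = [[1.0] + [1.0 if value == lvl else 0.0 for lvl in non_ref] for value in values]
    return columns, rows


def linear_predictor(rows, coefs):
    """Dot each design row with the coefficient vector."""
    return [sum(x * b for x, b in zip(row, coefs)) for row in rows]


stim = ["LL", "WL", "WW", "WL"]
columns, rows = treatment_code(stim, reference="LL")
print(columns)  # ['Intercept', 'C(stim)[WL]', 'C(stim)[WW]']

# Hypothetical coefficients: [v_Intercept, v_C(stim)[WL], v_C(stim)[WW]]
coefs = [0.1, 0.25, -0.02]
eta = linear_predictor(rows, coefs)
print([round(e, 2) for e in eta])  # [0.1, 0.35, 0.08, 0.35]
```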
# Sample with NumPyro's NUTS sampler: 100 warm-up and 100 retained draws in each of 2 chains.
basic_hssm_model.sample(sampler="nuts_numpyro", tune=100, draws=100, chains=2)
Using default initvals.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
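The rhat warning comes from a convergence diagnostic that compares chains with each other. ArviZ's default rhat is a rank-normalized split-R-hat (the method in the paper linked in the warning). The sketch below implements the simpler classic split-R-hat in pure Python, only to illustrate the idea: split every chain in half, then compare the variance between halves with the variance within them. Values close to 1 mean the halves agree, and the warning fires above 1.01.

```python
import math
import random
from statistics import mean, variance


def split_rhat(chains):
    """Classic (not rank-normalized) split R-hat for a list of equal-length chains."""
    halves = []
    for chain in chains:
        n = len(chain) // 2
        halves.extend([chain[:n], chain[n:2 * n]])  # drops the middle draw if odd
    n = len(halves[0])
    within = mean(variance(h) for h in halves)       # W: mean within-half variance
    between = n * variance([mean(h) for h in halves])  # B: n times variance of half means
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return math.sqrt(var_plus / within)


rng = random.Random(0)
# Two chains sampling the same distribution: R-hat should be close to 1.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(100)] for _ in range(2)]
# Two chains stuck in different regions: R-hat should be well above 1.01.
stuck = [[rng.gauss(0.0, 1.0) for _ in range(100)],
         [rng.gauss(3.0, 1.0) for _ in range(100)]]
print(round(split_rhat(mixed), 3))
print(round(split_rhat(stuck), 3))
```

For real diagnostics, ArviZ's rhat and ess functions (or az.summary) are the tools to use; with only 100 retained draws per chain, as here, some warnings are to be expected.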
The call returns an arviz.InferenceData object with four groups: posterior, log_likelihood, sample_stats, and observed_data.
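With these arguments each chain runs tune warm-up iterations followed by draws retained iterations, so NumPyro performs 200 iterations per chain, but only the retained draws reach the posterior: 2 chains × 100 draws = 200 samples per parameter, matching the chain: 2, draw: 100 dimensions below. The helper below is purely illustrative bookkeeping, not an HSSM function.

```python
def sampling_budget(tune: int, draws: int, chains: int) -> dict:
    """Illustrative bookkeeping for an MCMC run (not part of HSSM)."""
    return {
        "iterations_per_chain": tune + draws,  # warm-up plus retained iterations
        "kept_per_chain": draws,               # warm-up iterations are discarded
        "total_kept": chains * draws,          # (chain, draw) cells in the posterior
    }


budget = sampling_budget(tune=100, draws=100, chains=2)
print(budget)  # {'iterations_per_chain': 200, 'kept_per_chain': 100, 'total_kept': 200}
```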
posterior
<xarray.Dataset> Size: 12kB
Dimensions:        (chain: 2, draw: 100, v_C(stim)_dim: 2)
Coordinates:
  * chain          (chain) int64 16B 0 1
  * draw           (draw) int64 800B 0 1 2 3 4 5 6 7 ... 92 93 94 95 96 97 98 99
  * v_C(stim)_dim  (v_C(stim)_dim) <U2 16B 'WL' 'WW'
Data variables:
    t              (chain, draw) float64 2kB 0.2779 0.2868 ... 0.3006 0.2852
    theta          (chain, draw) float64 2kB 0.2375 0.2246 ... 0.2091 0.2026
    v_Intercept    (chain, draw) float64 2kB 0.115 0.1134 ... 0.09005 0.1168
    a              (chain, draw) float64 2kB 1.318 1.327 1.358 ... 1.266 1.288
    v_C(stim)      (chain, draw, v_C(stim)_dim) float64 3kB 0.2641 ... 0.01448
    z              (chain, draw) float64 2kB 0.5108 0.5077 ... 0.508 0.5043
Attributes:
    created_at:                  2024-12-29T17:36:28.431012+00:00
    arviz_version:               0.19.0
    inference_library:           numpyro
    inference_library_version:   0.16.1
    sampling_time:               74.365569
    tuning_steps:                100
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
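Each posterior variable holds one value per (chain, draw) cell, so pooling the chains gives 200 samples per parameter. ArviZ summarizes such samples with az.summary and az.hdi (94% highest-density intervals by default). The pure-Python sketch below only illustrates what an HDI computation does on a flat list of samples: sort them and keep the narrowest window that contains the requested share of samples. It assumes a unimodal posterior and does not reproduce ArviZ's exact implementation.

```python
import math


def pool_chains(values_by_chain):
    """Flatten a (chain, draw) nested list into a single list of samples."""
    return [v for chain in values_by_chain for v in chain]


def hdi(samples, prob=0.94):
    """Narrowest interval containing at least `prob` of the samples (unimodal assumption)."""
    ordered = sorted(samples)
    n = len(ordered)
    # Number of samples the interval must contain; the small epsilon guards
    # against float products such as 0.94 * 200 landing just above an integer.
    width = max(1, math.ceil(prob * n - 1e-9))
    # Slide a window of `width` sorted samples and keep the narrowest one.
    best = min(range(n - width + 1), key=lambda i: ordered[i + width - 1] - ordered[i])
    return ordered[best], ordered[best + width - 1]


# Toy (chain, draw) values standing in for a posterior variable such as a.
toy = [[1.0, 1.2, 1.1, 1.3, 5.0], [1.15, 1.25, 1.05, 1.2, 1.1]]
samples = pool_chains(toy)
print(hdi(samples, prob=0.9))  # (1.0, 1.3)
```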
log_likelihood
<xarray.Dataset> Size: 6MB
Dimensions:      (chain: 2, draw: 100, __obs__: 3988)
Coordinates:
  * chain        (chain) int64 16B 0 1
  * draw         (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
  * __obs__      (__obs__) int64 32kB 0 1 2 3 4 5 ... 3983 3984 3985 3986 3987
Data variables:
    rt,response  (chain, draw, __obs__) float64 6MB -1.04 -1.202 ... -1.078
Attributes:
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
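The log_likelihood group stores the log-density of every trial under every posterior draw: 2 chains × 100 draws × 3988 trials = 797,600 float64 values, about 6.4 MB, which makes it by far the largest group to save. It is what ArviZ's az.loo and az.waic consume. As an illustration only, the sketch below computes the log pointwise predictive density (lppd), the first ingredient of both criteria, from a (draw, observation) nested list using a numerically stable log-mean-exp.

```python
import math


def log_mean_exp(values):
    """log(mean(exp(values))), computed stably by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def lppd(loglik):
    """Log pointwise predictive density.

    loglik[s][i] is the log-likelihood of observation i under posterior draw s,
    with chains already pooled into draws.
    """
    n_obs = len(loglik[0])
    return sum(log_mean_exp([draw[i] for draw in loglik]) for i in range(n_obs))


# Toy values: 3 pooled draws x 2 observations.
toy_loglik = [[-1.0, -2.0], [-1.2, -1.8], [-0.8, -2.2]]
print(round(lppd(toy_loglik), 4))
```

By Jensen's inequality the lppd is never smaller than the average over draws of the summed log-likelihood, which is a handy sanity check.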
sample_stats
<xarray.Dataset> Size: 11kB
Dimensions:          (chain: 2, draw: 100)
Coordinates:
  * chain            (chain) int64 16B 0 1
  * draw             (draw) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
Data variables:
    acceptance_rate  (chain, draw) float64 2kB 0.9693 0.9136 ... 0.9664 0.9924
    diverging        (chain, draw) bool 200B False False False ... False False
    energy           (chain, draw) float64 2kB 5.937e+03 5.94e+03 ... 5.939e+03
    lp               (chain, draw) float64 2kB 5.933e+03 5.935e+03 ... 5.936e+03
    n_steps          (chain, draw) int64 2kB 35 27 63 43 127 ... 63 31 127 127
    step_size        (chain, draw) float64 2kB 0.04777 0.04777 ... 0.02995
    tree_depth       (chain, draw) int64 2kB 6 5 6 6 7 6 6 6 ... 7 7 5 7 6 5 7 7
Attributes:
    created_at:                  2024-12-29T17:36:28.439006+00:00
    arviz_version:               0.19.0
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
observed_data
<xarray.Dataset> Size: 96kB
Dimensions:                  (__obs__: 3988, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__                  (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response              (__obs__, rt,response_extra_dim_0) float64 64kB ...
Attributes:
    created_at:                  2024-12-29T17:36:28.439825+00:00
    arviz_version:               0.19.0
    inference_library:           numpyro
    inference_library_version:   0.16.1
    sampling_time:               74.365569
    tuning_steps:                100
    modeling_interface:          bambi
    modeling_interface_version:  0.15.0
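Variational inference (VI)
Before saving, we also fit the model with variational inference (ADVI), so that the model holds VI results alongside the MCMC traces.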
In [3]:
basic_hssm_model.vi(method="advi", niter=5000)
Using MCMC starting point defaults.
Finished [100%]: Average Loss = 7,602.3
Out[3]:
arviz.InferenceData
posterior
<xarray.Dataset> Size: 64kB
Dimensions:        (chain: 1, draw: 1000, v_C(stim)_dim: 2)
Coordinates:
  * chain          (chain) int64 8B 0
  * draw           (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
  * v_C(stim)_dim  (v_C(stim)_dim) <U2 16B 'WL' 'WW'
Data variables:
    t              (chain, draw) float64 8kB 0.06786 0.03004 ... 0.06839 0.1304
    theta          (chain, draw) float64 8kB 0.2548 0.3476 ... 0.4254 0.3646
    v_Intercept    (chain, draw) float64 8kB -0.1862 0.7777 ... -0.1479 0.1578
    a              (chain, draw) float64 8kB 1.663 1.959 2.22 ... 2.091 1.56
    v_C(stim)      (chain, draw, v_C(stim)_dim) float64 16kB 0.05082 ... 0.7853
    z              (chain, draw) float64 8kB 0.4489 0.5882 ... 0.6573 0.3789
Attributes:
    created_at:                 2024-12-29T17:37:29.511390+00:00
    arviz_version:              0.19.0
    inference_library:          pymc
    inference_library_version:  5.19.1
observed_data
<xarray.Dataset> Size: 96kB
Dimensions:                  (__obs__: 3988, rt,response_extra_dim_0: 2)
Coordinates:
  * __obs__                  (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987
  * rt,response_extra_dim_0  (rt,response_extra_dim_0) int64 16B 0 1
Data variables:
    rt,response              (__obs__, rt,response_extra_dim_0) float64 64kB ...
Attributes:
    created_at:                 2024-12-29T17:37:29.516235+00:00
    arviz_version:              0.19.0
    inference_library:          pymc
    inference_library_version:  5.19.1
Saving and loading the model
In [4]:
basic_hssm_model.save_model(model_name="test_model")
Apart from model_name, we use the default settings here, which save the model and its inference results to the hssm_models/test_model/ directory inside your current working directory.
Depending on which inference results the model holds, up to three files are saved in the model directory:
model.pkl: the model instance.
traces.nc: the MCMC traces.
vi_traces.nc: the VI traces.
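Before calling load_model, it can be useful to check which of these files a model directory actually contains. The following is a minimal sketch using only the Python standard library. It assumes the default location hssm_models/<model_name>/ under the current working directory and the three file names listed above; the helpers saved_model_dir and describe_saved_files are hypothetical illustrations, not part of HSSM.

```python
from __future__ import annotations

from pathlib import Path

# File names described above; which ones exist depends on the inference that was run.
EXPECTED_FILES = {
    "model.pkl": "model instance",
    "traces.nc": "MCMC traces",
    "vi_traces.nc": "VI traces",
}


def saved_model_dir(model_name: str, base: Path | None = None) -> Path:
    """Hypothetical helper: the assumed default save location hssm_models/<model_name>/.

    `base` defaults to the current working directory, mirroring the defaults
    described in the text.
    """
    root = Path.cwd() if base is None else Path(base)
    return root / "hssm_models" / model_name


def describe_saved_files(model_dir: Path) -> dict[str, bool]:
    """Hypothetical helper: report which of the expected files are present."""
    model_dir = Path(model_dir)
    return {name: (model_dir / name).is_file() for name in EXPECTED_FILES}


if __name__ == "__main__":
    # Print a small table of the expected files and whether each one exists.
    model_dir = saved_model_dir("test_model")
    for name, present in describe_saved_files(model_dir).items():
        status = "found" if present else "missing"
        print(f"{name:<13} {EXPECTED_FILES[name]:<15} {status}")
```

For the model saved above, which was fitted with both MCMC and VI, all three files would be expected to show up as present.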
We can now load the model from the directory we just created using the HSSM class method load_model, passing the path of the model directory.
In [5]:
loaded_model = hssm.HSSM.load_model(path="hssm_models/test_model")
Model initialized successfully.
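With this workflow, a model and its inference results can be restored in a new Python session or on another machine: copy the model directory and pass its path to HSSM.load_model. Since model.pkl is a pickled object, the machine that loads it needs a compatible installation of HSSM and its dependencies.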
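One way to move a saved model to another machine is to archive its directory, copy the archive, and unpack it there before calling HSSM.load_model. The sketch below uses only shutil from the standard library and assumes the directory layout written by save_model above; pack_model_dir and unpack_model_dir are hypothetical helpers, not part of HSSM.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def pack_model_dir(model_dir: str | Path, archive_base: str | Path) -> Path:
    """Hypothetical helper: zip a saved model directory for transfer.

    The archive keeps the model directory name (e.g. test_model/) as its
    top-level folder, so unpacking recreates the same layout.
    """
    model_dir = Path(model_dir).resolve()
    archive = shutil.make_archive(
        base_name=str(archive_base),   # ".zip" is appended automatically
        format="zip",
        root_dir=str(model_dir.parent),  # archive paths are relative to this
        base_dir=model_dir.name,         # only this subdirectory is included
    )
    return Path(archive)


def unpack_model_dir(archive: str | Path, dest_root: str | Path) -> None:
    """Hypothetical helper: extract the archive under dest_root (created if needed)."""
    shutil.unpack_archive(str(archive), extract_dir=str(dest_root))
```

For example, pack_model_dir("hssm_models/test_model", "test_model_export") writes test_model_export.zip; after copying it, unpack_model_dir("test_model_export.zip", "hssm_models") recreates hssm_models/test_model/, whose path can then be passed to hssm.HSSM.load_model.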