import matplotlib.pyplot as plt
import numpy as np
import hssm
hssm.set_floatX("float32")
import arviz as az
import pymc as pm
Setting PyTensor floatX type to float32. Setting "jax_enable_x64" to False. If this is not intended, please set `jax` to False.
Load Data and Specify Model¶
cav_data = hssm.load_data("cavanagh_theta")
cav_model = hssm.HSSM(data=cav_data, model="angle")
Model initialized successfully.
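Before running inference, it can help to glance at the loaded data and the model's parameter setup. A minimal sketch, assuming only that cav_data is a pandas DataFrame and that printing an HSSM model displays its parameter summary:
# Quick sanity checks on the loaded data and model (output format may vary across HSSM versions)
print(cav_data.head())  # trial-level reaction times and responses, plus covariates such as theta
print(cav_model)        # parameter specification and priors for the "angle" model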
Inference¶
We will run MCMC and VI here to contrast results.
Run VI¶
vi_idata = cav_model.vi(niter=100000, method="fullrank_advi")
Using MCMC starting point defaults.
Finished [100%]: Average Loss = 6,035.9
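For reference, .vi() wraps PyMC's variational fitting, so other methods can be requested through the same interface. A hedged variation, assuming the niter/method keywords used above also accept PyMC's mean-field "advi" option:
# Hypothetical alternative: mean-field ADVI instead of full-rank ADVI
vi_idata_meanfield = cav_model.vi(niter=100000, method="advi")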
Run MCMC¶
mcmc_idata = cav_model.sample()
Using default initvals.
/Users/afengler/miniconda3/envs/hssm516/lib/python3.11/site-packages/pymc/sampling/jax.py:451: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
[repeated jax UserWarnings omitted: "Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32."]
sample: 100%|██████████| 2000/2000 [00:44<00:00, 44.87it/s, 15 steps of size 2.60e-01. acc. prob=0.93]
sample: 100%|██████████| 2000/2000 [00:41<00:00, 48.28it/s, 15 steps of size 2.69e-01. acc. prob=0.93]
sample: 100%|██████████| 2000/2000 [00:40<00:00, 49.64it/s, 15 steps of size 2.68e-01. acc. prob=0.93]
sample: 100%|██████████| 2000/2000 [00:35<00:00, 55.65it/s, 31 steps of size 3.14e-01. acc. prob=0.89]
100%|██████████| 4000/4000 [00:05<00:00, 753.24it/s]
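For reference, the call above uses HSSM's defaults. Below is a hedged sketch of an equivalent explicit call; the keyword names (sampler, chains, draws, tune) are assumptions about the HSSM sampling API and may differ across versions:
# Hypothetical explicit version of the default sampling call above
mcmc_idata_explicit = cav_model.sample(
    sampler="nuts_numpyro",  # JAX-based NUTS, matching the numpyro progress bars above
    chains=4,
    draws=1000,
    tune=1000,
)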
Inspect Outputs¶
From our variational inference run, we extract two objects:
- An az.InferenceData object stored under cav_model.vi_idata. This holds a slightly cleaned-up posterior sample, constructed by drawing from the variational posterior.
- A pm.Approximation object stored under cav_model.vi_approx that holds the variational posterior itself. This is a rich structure, and it is beyond the purpose of this tutorial to illustrate all its details. Amongst other things, you can inspect the loss history and take samples such as those stored under cav_model.vi_idata (see the sketch after this list).
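A minimal sketch of accessing the two objects described above (attribute names as in the list):
# The cleaned-up posterior draws and the raw approximation object
vi_idata = cav_model.vi_idata    # az.InferenceData built from variational-posterior draws
vi_approx = cav_model.vi_approx  # PyMC approximation: loss history, .sample(), parameters, ...
print(type(vi_idata))
print(type(vi_approx))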
.vi_idata¶
The approximate variational posterior InferenceData.
cav_model.vi_idata
- posterior
<xarray.Dataset> Size: 28kB Dimensions: (chain: 1, draw: 1000) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: t (chain, draw) float32 4kB 0.2749 0.2908 0.2857 ... 0.3 0.2892 0.28 v (chain, draw) float32 4kB 0.3608 0.3657 0.3635 ... 0.3571 0.3893 theta (chain, draw) float32 4kB 0.2255 0.2328 0.2202 ... 0.2273 0.2293 a (chain, draw) float32 4kB 1.308 1.313 1.305 ... 1.265 1.316 1.319 z (chain, draw) float32 4kB 0.5062 0.5074 0.5069 ... 0.5152 0.4947 Attributes: created_at: 2024-12-26T00:13:46.195079+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2
- observed_data
<xarray.Dataset> Size: 64kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 32kB ... Attributes: created_at: 2024-12-26T00:13:46.205859+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2
.vi_approx¶
The approximate variational posterior pm.Approximation object.
We can take draws from the posterior with the .sample() method.
cav_model.vi_approx.sample(draws=1000)
- posterior
<xarray.Dataset> Size: 32MB Dimensions: (chain: 1, draw: 1000, __obs__: 3988) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 32kB 0 1 2 3 4 5 ... 3982 3983 3984 3985 3986 3987 Data variables: a (chain, draw) float32 4kB 1.293 1.354 1.321 ... 1.287 1.32 1.386 t (chain, draw) float32 4kB 0.2799 0.2745 0.274 ... 0.2851 0.2552 theta (chain, draw) float32 4kB 0.2204 0.2521 0.2395 ... 0.2332 0.2821 v (chain, draw) float32 4kB 0.3256 0.3781 0.3329 ... 0.3904 0.3625 v_mean (chain, draw, __obs__) float64 32MB 0.3256 0.3256 ... 0.3625 0.3625 z (chain, draw) float32 4kB 0.5139 0.5094 0.5045 ... 0.5084 0.5067 Attributes: created_at: 2024-12-26T00:16:35.593902+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2
- observed_data
<xarray.Dataset> Size: 64kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 32kB ... Attributes: created_at: 2024-12-26T00:16:35.600164+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.16.2
The .hist attribute stores the loss history. We can plot this to see how the loss function converged.
plt.plot(cav_model.vi_approx.hist)
plt.xlabel("Iteration")
plt.ylabel("Loss")
Text(0, 0.5, 'Loss')
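Since the raw loss trace is noisy, a rolling mean can make it easier to judge whether the fit has plateaued. A minimal sketch; the window size is an arbitrary choice:
# Smooth the loss history with a simple moving average (window size is illustrative)
hist = np.asarray(cav_model.vi_approx.hist)
window = 500
smoothed = np.convolve(hist, np.ones(window) / window, mode="valid")
plt.plot(smoothed)
plt.xlabel("Iteration")
plt.ylabel(f"Loss (rolling mean, window={window})")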
Contrast outputs between MCMC and VI¶
__, axes = plt.subplots(4, 4, figsize=(10, 5))
az.plot_pair(cav_model.traces, ax=axes, scatter_kwargs=dict(alpha=0.01, color="blue"))
az.plot_pair(cav_model.vi_idata, ax=axes, scatter_kwargs=dict(alpha=0.04, color="red"))
array([[<Axes: ylabel='v'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='theta'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='a'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: xlabel='t', ylabel='z'>, <Axes: xlabel='v'>, <Axes: xlabel='theta'>, <Axes: xlabel='a'>]], dtype=object)
cav_model.traces
- posterior
<xarray.Dataset> Size: 88kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: t (chain, draw) float32 16kB 0.2832 0.2857 0.2826 ... 0.2675 0.2981 v (chain, draw) float32 16kB 0.3239 0.3621 0.3563 ... 0.3656 0.3823 theta (chain, draw) float32 16kB 0.23 0.2414 0.2277 ... 0.2419 0.2042 a (chain, draw) float32 16kB 1.309 1.321 1.305 ... 1.326 1.35 1.281 z (chain, draw) float32 16kB 0.5147 0.5066 0.5067 ... 0.5059 0.5055 Attributes: created_at: 2024-12-26T00:16:29.698105+00:00 arviz_version: 0.18.0 inference_library: numpyro inference_library_version: 0.15.2 sampling_time: 162.656462 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.0
- log_likelihood
<xarray.Dataset> Size: 128MB Dimensions: (chain: 4, draw: 1000, __obs__: 3988) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 32kB 0 1 2 3 4 5 ... 3983 3984 3985 3986 3987 Data variables: rt,response (chain, draw, __obs__) float64 128MB -0.9137 -1.297 ... -0.915 Attributes: modeling_interface: bambi modeling_interface_version: 0.14.0
- sample_stats
<xarray.Dataset> Size: 124kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float32 16kB 0.926 0.9808 ... 0.9477 0.9704 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float32 16kB 6.027e+03 ... 6.027e+03 lp (chain, draw) float32 16kB 6.023e+03 ... 6.025e+03 n_steps (chain, draw) int32 16kB 15 15 15 7 7 3 ... 31 7 7 15 15 31 step_size (chain, draw) float32 16kB 0.2595 0.2595 ... 0.314 0.314 tree_depth (chain, draw) int64 32kB 4 4 4 3 3 2 2 4 ... 4 5 3 3 4 4 5 Attributes: created_at: 2024-12-26T00:16:29.702096+00:00 arviz_version: 0.18.0 modeling_interface: bambi modeling_interface_version: 0.14.0
- observed_data
<xarray.Dataset> Size: 64kB Dimensions: (__obs__: 3988, rt,response_extra_dim_0: 2) Coordinates: * __obs__ (__obs__) int64 32kB 0 1 2 3 ... 3985 3986 3987 * rt,response_extra_dim_0 (rt,response_extra_dim_0) int64 16B 0 1 Data variables: rt,response (__obs__, rt,response_extra_dim_0) float32 32kB ... Attributes: created_at: 2024-12-26T00:16:29.702990+00:00 arviz_version: 0.18.0 inference_library: numpyro inference_library_version: 0.15.2 sampling_time: 162.656462 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.0
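Beyond the visual overlay, a quick numeric comparison is possible. A minimal sketch using ArviZ summaries; the variable names follow the posterior groups shown above:
# Posterior summaries for the MCMC and VI runs, side by side
print(az.summary(cav_model.traces, var_names=["v", "a", "z", "t", "theta"]))
print(az.summary(cav_model.vi_idata, var_names=["v", "a", "z", "t", "theta"]))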
Further Reading¶
We suggest checking out the documentation on the VI API in PyMC for the full details of the capabilities we have access to.
Working directly through PyMC¶
Here we illustrate how to use our attached pymc_model to work directly with PyMC's object-oriented API for variational inference. This gives us a few extra affordances.
Let's define a few helper functions first.
import warnings
import xarray as xr
from matplotlib import gridspec
from pymc.blocking import DictToArrayBijection, RaveledVars
def tracker_to_idata(tracker, model):
    """Turn a tracker object into an InferenceData object."""
    # `whatchdict` is the actual attribute name used by pymc.callbacks.Tracker
    tracker_groups = list(tracker.whatchdict.keys())
    # n_steps = len(tracker[tracker_groups[0]])
    stacked_results = {
        tracker_group: {
            key: np.stack([d[key] for d in tracker[tracker_group]])
            for key in tracker[tracker_group][0]
        }
        for tracker_group in tracker_groups
    }
    # coords = {"vi_step": np.arange(n_steps)} | {
    #     k: np.array(v) for k, v in model.coords.items()
    # }
    var_to_dims = {
        var.name: ("vi_step", *(model.named_vars_to_dims.get(var.name, ())))
        for var in model.continuous_value_vars
    }
    datasets = {
        key: xr.Dataset(
            {
                var: (var_to_dims[var], stacked_results[key][var])
                for var in stacked_results[key].keys()
            }
        )
        for key in tracker_groups
    }
    with warnings.catch_warnings(action="ignore"):
        return az.InferenceData(**datasets)
def untransform_params(idata, model):
    """Bring transformed parameters back to their original scale."""
    suffixes = ["_interval__", "_log__"]

    def remove_suffixes(word, suffixes):
        for suffix in suffixes:
            if word.endswith(suffix):
                return word[: -len(suffix)]
        return word

    free_rv_names = [rv_.name for rv_ in model.free_RVs]
    transformed_vars = list(idata.mean.data_vars.keys())
    collect_untransformed_vars = []
    collect_untransformed_xarray_datasets = []
    for var_ in transformed_vars:
        var_untrans = remove_suffixes(var_, suffixes=suffixes)
        if var_untrans in free_rv_names:
            rv = model.free_RVs[free_rv_names.index(var_untrans)]
            if model.rvs_to_transforms[rv] is not None:
                untransformed_var = (
                    model.rvs_to_transforms[rv]
                    .backward(idata.mean[var_].values, *rv.owner.inputs)
                    .eval()
                )
                collect_untransformed_vars.append(var_)
                collect_untransformed_xarray_datasets.append(
                    xr.Dataset(
                        data_vars={var_untrans: (("vi_step"), untransformed_var)}
                    )
                )
    return xr.merge([idata.mean] + collect_untransformed_xarray_datasets).drop_vars(
        collect_untransformed_vars
    )
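For intuition, the backward call above maps an unconstrained value back into a parameter's bounded interval via a logistic (sigmoid) function. A hedged illustration with made-up bounds; the numbers below are not taken from the model:
# Illustration of an interval transform's backward map: unconstrained x -> (lower, upper)
x = 0.0
lower, upper = 0.3, 2.5                               # hypothetical bounds for a parameter like `a`
value = lower + (upper - lower) / (1.0 + np.exp(-x))  # logistic backward map
print(value)                                          # x = 0 maps to the interval midpoint, 1.4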
def plot_vi_traces(idata):
    """Plot parameter history of the optimization algorithm."""
    if not isinstance(idata, az.InferenceData):
        raise ValueError("idata must be an InferenceData object")
    if "loss" not in idata.groups():
        raise ValueError("InferenceData object must contain a 'loss' group")
    if "mean_untransformed" not in idata.groups():
        print(
            "Using transformed variables because 'mean_untransformed' group not found"
        )
        mean_group = "mean"
    else:
        mean_group = "mean_untransformed"
    data_vars = list(idata[mean_group].data_vars.keys())

    fig = plt.figure(figsize=(8, 1.5 * len(data_vars)))
    gs = gridspec.GridSpec(
        len(data_vars) // 2 + 2
        if (len(data_vars) % 2) == 0
        else (len(data_vars) // 2) + 3,
        2,
    )
    for i, var_ in enumerate(data_vars):
        ax_tmp = fig.add_subplot(gs[i // 2, i % 2])
        idata[mean_group][var_].plot(ax=ax_tmp)
        ax_tmp.set_title(var_)
    last_ax = fig.add_subplot(gs[-2:, :])
    idata["loss"].loss.plot(ax=last_ax)
    gs.tight_layout(fig)
    return fig
# Define the ADVI runner
with cav_model.pymc_model:
    advi = pm.ADVI()

    # Set up starting point
    start = cav_model.pymc_model.initial_point()
    vars_dict = {var.name: var for var in cav_model.pymc_model.continuous_value_vars}
    x0 = DictToArrayBijection.map(
        {var_name: value for var_name, value in start.items() if var_name in vars_dict}
    )

    # Define quantities to track
    tracker = pm.callbacks.Tracker(
        mean=lambda: DictToArrayBijection.rmap(
            RaveledVars(advi.approx.mean.eval(), x0.point_map_info), start
        ),  # callable that returns mean
        std=lambda: DictToArrayBijection.rmap(
            RaveledVars(advi.approx.std.eval(), x0.point_map_info), start
        ),  # callable that returns std
    )

    # Run VI fit
    approx = advi.fit(n=30000, callbacks=[tracker])
    vi_posterior_samples = approx.sample(1000)
    vi_posterior_samples.posterior = vi_posterior_samples.posterior.drop_vars("v_mean")
Finished [100%]: Average Loss = 6,037.3
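The object-oriented interface also lets us attach additional callbacks and choose the optimizer explicitly. A hedged sketch using PyMC's CheckParametersConvergence callback and the adam optimizer; the tolerance and learning rate are arbitrary illustrations, and this re-runs the fit from scratch:
# Hypothetical variant: ADVI with an explicit optimizer and an early-stopping convergence check
with cav_model.pymc_model:
    advi_alt = pm.ADVI()
    approx_alt = advi_alt.fit(
        n=30000,
        obj_optimizer=pm.adam(learning_rate=1e-2),
        callbacks=[pm.callbacks.CheckParametersConvergence(diff="absolute", tolerance=1e-3)],
    )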
from copy import deepcopy
# Convert tracked quantities to idata
result = tracker_to_idata(tracker, cav_model.pymc_model)
# Add untransformed parameters
result.add_groups(
{"mean_untransformed": untransform_params(deepcopy(result), cav_model.pymc_model)}
)
# Add loss group
result.add_groups(
{"loss": xr.Dataset(data_vars={"loss": ("vi_step", np.array(approx.hist))})}
)
/Users/afengler/miniconda3/envs/hssm516/lib/python3.11/site-packages/arviz/data/inference_data.py:1538: UserWarning: The group mean_untransformed is not defined in the InferenceData scheme
/Users/afengler/miniconda3/envs/hssm516/lib/python3.11/site-packages/arviz/data/inference_data.py:1538: UserWarning: The group loss is not defined in the InferenceData scheme
A quick look at our result InferenceData object, to understand what happened here.
We now have two additional groups:
- mean_untransformed, which holds parameter values in the original space (instead of the transformed parameters over which the optimization operates, which always live in an unconstrained space)
- loss, which holds our training history
result
- mean
<xarray.Dataset> Size: 600kB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: t_interval__ (vi_step) float32 120kB -0.001 -0.001633 ... -1.798 -1.798 theta_interval__ (vi_step) float32 120kB 0.001 1.338e-05 ... -1.19 -1.189 a_interval__ (vi_step) float32 120kB -0.001 -0.00154 ... -0.5246 -0.525 z_interval__ (vi_step) float32 120kB 0.001 0.002 ... 0.02558 0.02568 v_interval__ (vi_step) float32 120kB -0.001 -7.28e-05 ... 0.2433 0.2435
- std
<xarray.Dataset> Size: 600kB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: t_interval__ (vi_step) float32 120kB 0.6926 0.6931 ... 0.02431 0.02429 theta_interval__ (vi_step) float32 120kB 0.6926 0.6931 ... 0.02644 0.02643 a_interval__ (vi_step) float32 120kB 0.6936 0.6941 ... 0.01393 0.01393 z_interval__ (vi_step) float32 120kB 0.6926 0.6931 ... 0.02442 0.02443 v_interval__ (vi_step) float32 120kB 0.6926 0.6924 ... 0.01195 0.01195
- mean_untransformed
<xarray.Dataset> Size: 600kB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: t (vi_step) float32 120kB 1.0 0.9997 0.9996 ... 0.2852 0.2851 0.285 theta (vi_step) float32 120kB 0.6003 0.6 0.5997 ... 0.2265 0.2266 0.2267 a (vi_step) float32 120kB 1.649 1.649 1.649 ... 1.304 1.304 1.304 z (vi_step) float32 120kB 0.5002 0.5004 0.5005 ... 0.5051 0.5051 v (vi_step) float32 120kB -0.0015 -0.0001094 ... 0.3631 0.3635
- loss
<xarray.Dataset> Size: 240kB Dimensions: (vi_step: 30000) Dimensions without coordinates: vi_step Data variables: loss (vi_step) float64 240kB 1.53e+04 9.374e+03 ... 6.036e+03 6.038e+03
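Beyond the repr above, values can be pulled out of the custom groups programmatically. A minimal sketch using the group and variable names created earlier:
# Parameter means at the final VI step, plus the final loss value
final_means = {
    name: float(result["mean_untransformed"][name][-1])
    for name in result["mean_untransformed"].data_vars
}
print(final_means)
print(float(result["loss"]["loss"][-1]))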
Plot Results¶
We can plot the parameter trajectories (histories) over optimization steps with our little helper function plot_vi_traces().
NOTE:
This is a random run, and we did not thoroughly check whether the number of steps we allowed the optimizer was indeed enough for convergence.
fig = plot_vi_traces(result)
__, axes = plt.subplots(4, 4, figsize=(10, 5))
# Plot MCMC [nuts]
az.plot_pair(cav_model.traces, ax=axes, scatter_kwargs=dict(alpha=0.01, color="blue"))
# Plot VI via .vi() [fullrank_advi]
az.plot_pair(cav_model.vi_idata, ax=axes, scatter_kwargs=dict(alpha=0.04, color="red"))
# Plot VI via pymc interface [advi]
# (We need to make sure the variables are in the correct order)
az.plot_pair(
vi_posterior_samples.posterior[list(cav_model.traces.posterior.data_vars)],
ax=axes,
scatter_kwargs=dict(alpha=0.04, color="green"),
)
array([[<Axes: ylabel='v'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='theta'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: ylabel='a'>, <Axes: >, <Axes: >, <Axes: >], [<Axes: xlabel='t', ylabel='z'>, <Axes: xlabel='v'>, <Axes: xlabel='theta'>, <Axes: xlabel='a'>]], dtype=object)
NOTE:
It is expected that the posterior of our last run looks a little worse. We chose to run the advi algorithm, which implies only a mean-field (diagonal-covariance) Gaussian approximation to the posterior, so we expect to miss the posterior covariances that we pick up via fullrank_advi as well as MCMC.
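To make the note above concrete, we can compare posterior correlations between parameters: the mean-field ADVI run should show near-zero correlations where MCMC (and full-rank ADVI) recover them. A minimal sketch; it assumes the variable names shown in cav_model.traces:
# Posterior correlation matrices: MCMC vs. the mean-field ADVI run from the PyMC interface
params = ["v", "a", "z", "t", "theta"]
mcmc_post = az.extract(cav_model.traces, var_names=params)
advi_post = az.extract(vi_posterior_samples, var_names=params)
mcmc_corr = np.corrcoef(np.stack([mcmc_post[p].values for p in params]))
advi_corr = np.corrcoef(np.stack([advi_post[p].values for p in params]))
print(np.round(mcmc_corr, 2))
print(np.round(advi_corr, 2))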
Caveats¶
Variational Inference is powerful; however, it comes with its own set of sharp edges.
- You will not always be in a position to compare VI with MCMC runs (after all, if you can run full MCMC, there isn't much benefit to using VI at all), and it can be hard to estimate a priori how many steps you need to run the algorithm for.
- The posteriors will be approximate: if the true posterior includes complex parameter trade-offs, VI might produce inaccurate posterior estimates.
- We recommend VI for loglik_kind="approx_differentiable", since the gradients of the analytical log-likelihoods still prove somewhat brittle at this point in time (see the sketch after this list).
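As a concrete version of the last recommendation, the likelihood kind can be pinned when constructing the model. A hedged sketch; for the angle model the approximate-differentiable likelihood is typically the default anyway, so this mainly illustrates the loglik_kind argument:
# Hypothetical: request an approximate-differentiable likelihood explicitly, then run VI
cav_model_approx = hssm.HSSM(
    data=cav_data,
    model="angle",
    loglik_kind="approx_differentiable",
)
vi_idata_approx = cav_model_approx.vi(niter=100000, method="fullrank_advi")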
Read¶
To learn a bit more about the VI API in PyMC, we recommend reading the excellent short tutorial in the main documentation.