hssm.set_floatX
Set float types for PyTensor and JAX.
Often we wish to work with a specific type of float in both PyTensor and JAX. This function helps set float types in both packages.
Parameters:
- dtype (Literal['float32', 'float64']) – Either float32 or float64. Float type for PyTensor (and for JAX if update_jax=True).
- update_jax (optional, default: True) – Whether this function also sets the float type for JAX by changing the jax_enable_x64 setting in the JAX config.
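
The parameter descriptions above can be illustrated with a minimal sketch. It is not HSSM's actual implementation: the helper names `jax_x64_flag` and `set_float_types_sketch` are hypothetical, and the sketch assumes that float64 corresponds to enabling `jax_enable_x64` and float32 to disabling it. It relies only on `pytensor.config.floatX` and `jax.config.update`, both of which are imported lazily so the flag logic can be checked without either package installed.

```python
def jax_x64_flag(dtype: str) -> bool:
    """Map a float type name to the value assumed for JAX's jax_enable_x64."""
    if dtype not in ("float32", "float64"):
        raise ValueError('dtype must be either "float32" or "float64".')
    # Assumption: 64-bit floats require x64 mode, 32-bit floats do not.
    return dtype == "float64"


def set_float_types_sketch(dtype: str, update_jax: bool = True) -> None:
    """Hypothetical stand-in for the behavior documented for hssm.set_floatX."""
    flag = jax_x64_flag(dtype)  # validates dtype before touching any config

    import pytensor  # imported lazily; requires PyTensor to be installed

    pytensor.config.floatX = dtype

    if update_jax:
        import jax  # imported lazily; requires JAX to be installed

        jax.config.update("jax_enable_x64", flag)


# Typical call to the documented function (requires HSSM to be installed):
#
#     import hssm
#     hssm.set_floatX("float32")                     # PyTensor and JAX
#     hssm.set_floatX("float64", update_jax=False)   # PyTensor only
```

Calling the function once, before building any model, keeps the two backends from disagreeing about precision.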