hssm.set_floatX
Set float types for PyTensor and JAX.
Often we wish to work with a specific type of float in both PyTensor and JAX. This function helps set float types in both packages.
Parameters:
- dtype (Literal['float32', 'float64']) – Either float32 or float64. Float type for PyTensor (and for JAX if update_jax=True).
- update_jax (optional, default: True) – Whether this function also sets the float type for JAX by changing the jax_enable_x64 setting in the JAX config. Defaults to True.
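The parameter descriptions above can be illustrated with a small, library-free sketch. It is not HSSM's implementation: float_settings is a hypothetical helper, the dictionary keys are illustrative labels, and the mapping of float64 to jax_enable_x64=True is an assumption based on how JAX's 64-bit flag is generally understood to work.

```python
from typing import Literal


def float_settings(
    dtype: Literal["float32", "float64"], update_jax: bool = True
) -> dict:
    """Hypothetical helper: settings that a call like set_floatX(dtype) implies."""
    if dtype not in ("float32", "float64"):
        raise ValueError("dtype must be 'float32' or 'float64'")
    # PyTensor keeps its default float type in the floatX config option.
    settings = {"pytensor.floatX": dtype}
    if update_jax:
        # Assumption: JAX's x64 mode is enabled only when float64 is requested.
        settings["jax_enable_x64"] = dtype == "float64"
    return settings


print(float_settings("float32"))
print(float_settings("float64", update_jax=False))
```

With HSSM installed, the corresponding call would presumably be hssm.set_floatX("float32"), or hssm.set_floatX("float64", update_jax=False) to leave the JAX config untouched.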