
hssm.set_floatX

set_floatX(dtype: Literal['float32', 'float64'], update_jax: bool = True)

Set float types for PyTensor and JAX.

Often we want to work with a specific float type (float32 or float64) in both PyTensor and JAX. This function sets the float type in both packages with a single call.

Parameters:

  • dtype (Literal['float32', 'float64']) –

    Either "float32" or "float64". The float type used by PyTensor (and by JAX if update_jax=True).

  • update_jax (bool, default: True) –

    Whether this function also sets the float type for JAX by changing the jax_enable_x64 setting in the JAX config (enabled for float64, disabled for float32). Defaults to True.
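Typical usage is a single call such as hssm.set_floatX("float32"), preferably early in a session, since JAX recommends choosing jax_enable_x64 before any arrays are created. The sketch below is not the HSSM source. It is a minimal, dependency-free illustration of the settings described above, assuming that dtype maps directly onto PyTensor's floatX and that JAX's x64 mode is switched on for float64 and off for float32 (JAX uses 32-bit floats unless jax_enable_x64 is True). The helpers resolve_float_settings and apply_float_settings are hypothetical names used only for illustration. The only real library calls are pytensor.config.floatX and jax.config.update("jax_enable_x64", ...).

```python
from typing import Literal, Optional, TypedDict


class FloatSettings(TypedDict):
    floatX: str                     # value intended for pytensor.config.floatX
    jax_enable_x64: Optional[bool]  # None means "leave the JAX config unchanged"


def resolve_float_settings(
    dtype: Literal["float32", "float64"], update_jax: bool = True
) -> FloatSettings:
    """Hypothetical helper: compute the settings described for set_floatX."""
    if dtype not in ("float32", "float64"):
        raise ValueError('dtype must be either "float32" or "float64"')
    # JAX defaults to 32-bit floats, so float64 needs x64 mode turned on
    # and float32 needs it turned off. Skip JAX entirely if update_jax=False.
    x64 = (dtype == "float64") if update_jax else None
    return {"floatX": dtype, "jax_enable_x64": x64}


def apply_float_settings(settings: FloatSettings) -> None:
    """Hypothetical helper: push the resolved settings into PyTensor and JAX."""
    # Imported lazily so resolve_float_settings runs without these packages.
    import jax
    import pytensor

    pytensor.config.floatX = settings["floatX"]
    if settings["jax_enable_x64"] is not None:
        jax.config.update("jax_enable_x64", settings["jax_enable_x64"])


if __name__ == "__main__":
    print(resolve_float_settings("float32"))
    print(resolve_float_settings("float64", update_jax=False))
```

Splitting the pure decision logic from the side-effecting configuration step keeps the mapping easy to check without PyTensor or JAX installed. Calling apply_float_settings(resolve_float_settings("float64")) would then mirror what the documentation describes for set_floatX("float64").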