Source code for yippy._precision
"""Precision policy for yippy.
All float storage follows the global ``jax_enable_x64`` flag: float32 when the
flag is off (the default, memory-friendly), float64 when it is on. This avoids a
segmented float32/float64 pipeline, which JAX handles poorly.
"""
import jax
import numpy as np
[docs]
def float_dtype():
    """Return the float dtype currently in force for yippy storage.

    JAX canonicalizes a float64 request down to float32 unless
    ``jax_enable_x64`` is enabled. This therefore yields ``float64`` with the
    flag on and ``float32`` with it off. The result is a numpy dtype, so it can
    be passed to both ``np.*`` and ``jnp.*`` allocations.
    """
    # Request the widest float and let JAX's canonicalization narrow it to
    # whatever the global x64 flag permits.
    widest = np.float64
    active = jax.dtypes.canonicalize_dtype(widest)
    return active
[docs]
def dtype_tag():
    """Return a short cache-key tag for the active float dtype.

    The tag is ``"f64"`` when ``float_dtype()`` is 8 bytes wide, and
    ``"f32"`` otherwise.
    """
    width_in_bytes = np.dtype(float_dtype()).itemsize
    if width_in_bytes == 8:
        return "f64"
    return "f32"