# Source code for yippy.eqx_coronagraph

"""Pure JAX/Equinox coronagraph module.

This module provides ``EqxCoronagraph``, a first-class ``eqx.Module`` that
wraps the data loaded by :class:`yippy.Coronagraph` into a form that is fully
compatible with ``jax.jit``, ``jax.vmap``, and other JAX transformations.

Usage::

    from yippy import EqxCoronagraph

    # One-liner: pass a YIP path directly
    coro = EqxCoronagraph("/path/to/yip")

    # Or from an existing yippy Coronagraph
    from yippy import Coronagraph
    yippy_coro = Coronagraph("/path/to/yip")
    coro = EqxCoronagraph(yippy_coro=yippy_coro)

All methods on ``EqxCoronagraph`` are JIT-traceable.  Downstream code should
use ``eqx.filter_jit`` (not ``jax.jit``) when JIT-compiling functions that
accept an ``EqxCoronagraph`` as input::

    import equinox as eqx

    @eqx.filter_jit
    def simulate(coro, x, y):
        psf = coro.create_psf(x, y)
        stellar = coro.stellar_intens(0.01)
        return psf + stellar
"""

from __future__ import annotations

from pathlib import Path

import equinox as eqx
import interpax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array

from ._precision import float_dtype
from .coronagraph import Coronagraph


class EqxCoronagraph(eqx.Module):
    """Coronagraph model held entirely in JAX arrays and interpax interpolators.

    At runtime nothing here touches astropy, scipy, or the filesystem.  The
    instance is a valid pytree, so it can be handed to any JAX transformation.

    Under ``eqx.filter_jit`` the fields split into two groups:

    **Dynamic** (JAX arrays / ``eqx.Module`` leaves -- their values may change
    without recompiling, *as long as shapes are unchanged*):

    - ``sky_trans``, ``psf_datacube``
    - every ``interpax.CubicSpline`` interpolator (these are ``eqx.Module``
      instances whose leaves are JAX arrays)

    **Static** (plain Python objects -- a change triggers recompilation, which
    ``filter_jit`` takes care of automatically):

    - ``create_psf``, ``create_psfs`` (callables / closures)
    - scalar metadata (``pixel_scale_lod``, ``IWA``, ``OWA``, etc.)
    - ``psf_shape`` (tuple)

    Passing a different ``EqxCoronagraph`` instance into the same
    ``filter_jit``-compiled function **will** recompile (the callable closures
    differ, and the interpolator shapes most likely do too).  This is expected
    and cannot be avoided.
    """

    # -- Scalar metadata (auto-static in filter_jit) -----------------------
    pixel_scale_lod: float
    psf_shape: tuple[int, int]
    center_x: float
    center_y: float
    IWA: float  # inner working angle in lam/D
    OWA: float  # outer working angle in lam/D
    frac_obscured: float
    contrast_floor: float | None

    # -- Off-axis PSF synthesis (auto-static: callables) ------------------
    create_psf: callable
    create_psfs: callable

    # -- Stellar intensity interpolation (dynamic: eqx.Module) -----------
    _stellar_ln_interp: interpax.CubicSpline

    # -- Performance curves (dynamic: eqx.Module) ------------------------
    _throughput_interp: interpax.CubicSpline
    _log_contrast_interp: interpax.CubicSpline
    _occ_trans_interp: interpax.CubicSpline
    _core_area_interp: interpax.CubicSpline
    _core_mean_intensity_interp: interpax.CubicSpline
    _core_mean_intensity_interp_2d: interpax.Interpolator2D | None
    _has_2d_core_intensity: bool

    # -- Fixed data arrays (dynamic pytree leaves) -----------------------
    sky_trans: Array
    psf_datacube: Array | None

    # -- Construction -----------------------------------------------------
    def __init__(
        self,
        yip_path: str | Path | None = None,
        *,
        yippy_coro: Coronagraph | None = None,
        ensure_psf_datacube: bool = False,
        # Forwarded to yippy.Coronagraph when building from yip_path
        downsample_shape: tuple[int, int] | None = None,
        aperture_radius_lod: float = 0.7,
        contrast_floor: float | None = None,
        use_inscribed_diameter: bool = False,
        # Extra Coronagraph kwargs
        x_symmetric: bool = True,
        y_symmetric: bool = True,
        **kwargs,
    ):
        """Build the pure-JAX coronagraph from a YIP path or a ``Coronagraph``.

        Args:
            yip_path:
                Path to a Yield Input Package directory.  Only used when
                ``yippy_coro`` is not supplied, in which case a
                ``yippy.Coronagraph`` is constructed internally from it.
            yippy_coro:
                An already-initialised ``yippy.Coronagraph``.  When both this
                and ``yip_path`` are given, this one wins.
            ensure_psf_datacube:
                If ``True``, make sure the PSF datacube exists (creating it on
                the source coronagraph if needed), store it on this object and
                then set ``yippy_coro.psf_datacube`` to ``None``.  The cube can
                be very large; default is ``False``.
            downsample_shape:
                Optional ``(ny, nx)`` to downsample PSFs (forwarded).
            aperture_radius_lod:
                Aperture radius in lam/D for performance curves (forwarded).
            contrast_floor:
                Minimum contrast value for the engineering stability floor
                (forwarded, and also stored for :meth:`raw_contrast`).
            use_inscribed_diameter:
                Whether to use inscribed diameter for lam/D calcs (forwarded).
            x_symmetric:
                Whether off-axis PSFs are symmetric about the x-axis
                (forwarded).
            y_symmetric:
                Whether off-axis PSFs are symmetric about the y-axis
                (forwarded).
            **kwargs:
                Additional keyword arguments passed to ``yippy.Coronagraph``.

        Raises:
            ValueError:
                If neither ``yip_path`` nor ``yippy_coro`` is provided.
        """
        if yippy_coro is None:
            if yip_path is None:
                raise ValueError("Provide either yip_path or yippy_coro")
            yippy_coro = Coronagraph(
                yip_path,
                downsample_shape=downsample_shape,
                aperture_radius_lod=aperture_radius_lod,
                contrast_floor=contrast_floor,
                x_symmetric=x_symmetric,
                y_symmetric=y_symmetric,
                use_inscribed_diameter=use_inscribed_diameter,
                **kwargs,
            )

        def _native(values):
            # FITS-sourced arrays are big-endian (>f8), which jnp cannot
            # ingest; an explicit dtype byte-swaps to native order and
            # follows the x64 flag.
            return jnp.asarray(values, dtype=float_dtype())

        # -- Scalar metadata ---------------------------------------------
        # NOTE(review): the field is named ``*_lod`` but is read from
        # ``pixel_scale_arcsec`` -- confirm the units are what callers expect.
        self.pixel_scale_lod = float(yippy_coro.pixel_scale_arcsec.value)
        self.psf_shape = tuple(int(n) for n in yippy_coro.psf_shape)
        self.center_x = float(yippy_coro.offax.center_x.value)
        self.center_y = float(yippy_coro.offax.center_y.value)
        self.IWA = float(yippy_coro.IWA.value)
        self.OWA = float(yippy_coro.OWA.value)
        self.frac_obscured = float(yippy_coro.frac_obscured)
        # NOTE(review): only the constructor argument is consulted here; a
        # floor configured on a caller-supplied ``yippy_coro`` is not read --
        # confirm that is intended.
        if contrast_floor is None:
            self.contrast_floor = None
        else:
            self.contrast_floor = float(contrast_floor)

        # -- PSF creation callables --------------------------------------
        off_axis = yippy_coro.offax
        self.create_psf = off_axis.create_psf
        self.create_psfs = off_axis.create_psfs

        # -- Stellar intensity interpolation -----------------------------
        # The spline is fitted to ln(intensity); ``stellar_intens`` undoes
        # the log with ``jnp.exp``.
        stellar_model = yippy_coro.stellar_intens
        diam_knots_lod = _native(stellar_model.diams.value)
        stellar_cube = _native(stellar_model.psfs)
        self._stellar_ln_interp = interpax.CubicSpline(
            diam_knots_lod, jnp.log(stellar_cube)
        )

        # -- Performance curve interpolators -----------------------------
        convert = _scipy_to_interpax
        self._throughput_interp = convert(yippy_coro.throughput_interp)
        self._log_contrast_interp = convert(yippy_coro._log_contrast_interp)
        self._occ_trans_interp = convert(yippy_coro.occ_trans_interp)
        self._core_area_interp = convert(yippy_coro.core_area_interp)
        self._core_mean_intensity_interp = convert(
            yippy_coro.core_intensity_interp
        )

        # 2D core mean intensity (separation x stellar_diam) when available
        grid_interp = yippy_coro.core_intensity_interp_2d
        self._has_2d_core_intensity = grid_interp is not None
        if grid_interp is None:
            self._core_mean_intensity_interp_2d = None
        else:
            # The grid interpolator exposes its axes via ``.grid`` and its
            # samples via ``.values``.
            sep_axis = _native(grid_interp.grid[0])
            diam_axis = _native(grid_interp.grid[1])
            samples = _native(grid_interp.values)
            self._core_mean_intensity_interp_2d = interpax.Interpolator2D(
                sep_axis,
                diam_axis,
                samples,
                method="linear",
                extrap=False,  # returns NaN out-of-bounds
            )

        # -- Sky transmission --------------------------------------------
        self.sky_trans = _native(yippy_coro.sky_trans())

        # -- Optional PSF datacube ---------------------------------------
        if not ensure_psf_datacube:
            self.psf_datacube = None
        else:
            if not yippy_coro.has_psf_datacube:
                yippy_coro.create_psf_datacube()
            # No explicit dtype on purpose: jnp.asarray canonicalizes to the
            # active float dtype and is a no-op (no copy) when the cube
            # already matches.
            self.psf_datacube = jnp.asarray(yippy_coro.psf_datacube)
            # Drop yippy's reference so the cube is not stored twice.
            yippy_coro.psf_datacube = None

    # -- Public methods (all JIT-traceable) -------------------------------
    def stellar_intens(self, stellar_diam_lod: float) -> Array:
        """Return the stellar intensity map for a given stellar diameter.

        Args:
            stellar_diam_lod:
                Stellar diameter in lam/D (unitless float).

        Returns:
            2-D JAX array containing the stellar intensity map.
        """
        ln_intensity = self._stellar_ln_interp(stellar_diam_lod)
        return jnp.exp(ln_intensity)

    def throughput(self, separation_lod: float) -> Array:
        """Return the coronagraph throughput at a separation.

        Args:
            separation_lod:
                Separation from the star in lam/D.

        Returns:
            Scalar throughput value.
        """
        interp = self._throughput_interp
        return interp(separation_lod)

    def raw_contrast(self, separation_lod: float) -> Array:
        """Return the raw contrast at a separation.

        The interpolation happens in log10 space; when ``contrast_floor`` is
        set, the result is clamped from below to that value.

        Args:
            separation_lod:
                Separation from the star in lam/D.

        Returns:
            Scalar raw contrast value.
        """
        log10_contrast = self._log_contrast_interp(separation_lod)
        contrast = jnp.power(10.0, log10_contrast)
        if self.contrast_floor is None:
            return contrast
        return jnp.maximum(contrast, self.contrast_floor)

    def noise_floor_exosims(
        self,
        separation_lod: float,
        contrast_floor: float = 1e-10,
        ppf: float = 30.0,
    ) -> Array:
        """Return the noise floor in the EXOSIMS contrast convention.

        Evaluates ``max(|raw_contrast|, contrast_floor) / ppf``.

        Args:
            separation_lod:
                Separation from the star in lambda/D.
            contrast_floor:
                Minimum contrast value.
            ppf:
                Post-processing noise suppression factor.

        Returns:
            Scalar noise floor value (EXOSIMS convention).
        """
        magnitude = jnp.abs(self.raw_contrast(separation_lod))
        floored = jnp.maximum(magnitude, contrast_floor)
        return floored / ppf

    def noise_floor_ayo(
        self,
        separation_lod: float,
        ppf: float = 30.0,
    ) -> Array:
        """Return the noise floor in the AYO/pyEDITH per-pixel convention.

        Evaluates ``core_mean_intensity(sep) / ppf``.

        Args:
            separation_lod:
                Separation from the star in lambda/D.
            ppf:
                Post-processing noise suppression factor.

        Returns:
            Scalar noise floor value (AYO/pyEDITH convention).
        """
        intensity = self.core_mean_intensity(separation_lod)
        return intensity / ppf

    def occulter_transmission(self, separation_lod: float) -> Array:
        """Return the occulter transmission at a separation.

        Args:
            separation_lod:
                Separation from the star in lam/D.

        Returns:
            Scalar occulter transmission value.
        """
        interp = self._occ_trans_interp
        return interp(separation_lod)

    def core_area(self, separation_lod: float) -> Array:
        """Return the core area at a separation.

        Args:
            separation_lod:
                Separation from the star in lam/D.

        Returns:
            Scalar core area value in (lam/D)**2.
        """
        interp = self._core_area_interp
        return interp(separation_lod)

    def core_mean_intensity(
        self, separation_lod: float, stellar_diam_lod: float = 0.0
    ) -> Array:
        """Return the core mean intensity at a separation.

        The 1D spline serves the default (point-source) diameter; a non-zero
        diameter is routed to the 2D interpolant when one was built.

        Args:
            separation_lod:
                Separation from the star in lambda/D.
            stellar_diam_lod:
                Stellar angular diameter in lambda/D.  Default is 0.0
                (point source).

        Returns:
            Scalar core mean intensity value.
        """
        # NOTE(review): this is a Python-level branch on ``stellar_diam_lod``,
        # so it presumably needs a concrete value rather than a traced
        # array -- confirm how callers pass it under ``filter_jit``.
        wants_2d = stellar_diam_lod != 0.0 and self._has_2d_core_intensity
        if wants_2d:
            return self._core_mean_intensity_interp_2d(
                separation_lod, stellar_diam_lod
            )
        return self._core_mean_intensity_interp(separation_lod)
# -- Helpers ------------------------------------------------------------------
def _scipy_to_interpax(scipy_spline):
    """Rebuild a scipy B-spline as an interpax interpolator.

    The scipy object carries a knot vector (``t``) and a degree (``k``).  The
    spline is re-sampled at its distinct knots once the repeated boundary
    padding is removed, and a fresh interpax interpolator is fitted to those
    ``(x, y)`` pairs.  The result approximates the original function; it is
    not guaranteed to be coefficient-for-coefficient identical.

    Degree handling:

    - ``k <= 1`` -> ``interpax.Interpolator1D(method="linear", extrap=True)``
    - any higher degree -> ``interpax.CubicSpline``

    Args:
        scipy_spline: A ``scipy.interpolate.BSpline`` (for example the result
            of ``make_interp_spline``) exposing ``.k``, ``.t`` and being
            callable on an array of abscissae.

    Returns:
        An interpax interpolator that approximates the same function.
    """
    degree = scipy_spline.k
    knots = scipy_spline.t

    # Drop ``degree`` padded knots from each end; ``np.unique`` then collapses
    # the one remaining duplicate of each boundary value.  The stop index is
    # written out explicitly because ``knots[degree:-degree]`` is an *empty*
    # slice when ``degree == 0`` (``-0`` is just ``0``), which would hand
    # empty arrays to the linear branch below.
    # NOTE(review): for ``make_interp_spline`` these knots presumably match
    # the original sample x-values; verify for the boundary conditions used.
    x_np = np.unique(knots[degree : len(knots) - degree])

    # Sample the scipy spline at those abscissae.
    y_np = scipy_spline(x_np)

    # Explicit dtype so the arrays follow the project's active float precision.
    x_jax = jnp.asarray(x_np, dtype=float_dtype())
    y_jax = jnp.asarray(y_np, dtype=float_dtype())

    if degree <= 1:
        return interpax.Interpolator1D(x_jax, y_jax, method="linear", extrap=True)
    return interpax.CubicSpline(x_jax, y_jax)