yippy.eqx_coronagraph

Pure JAX/Equinox coronagraph module.

This module provides EqxCoronagraph, a first-class eqx.Module that wraps the data loaded by yippy.Coronagraph into a form that is fully compatible with jax.jit, jax.vmap, and other JAX transformations.

Usage:

```python
from yippy import EqxCoronagraph

# One-liner: pass a YIP path directly
coro = EqxCoronagraph("/path/to/yip")

# Or from an existing yippy Coronagraph
from yippy import Coronagraph
yippy_coro = Coronagraph("/path/to/yip")
coro = EqxCoronagraph(yippy_coro=yippy_coro)
```

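All methods on EqxCoronagraph are JIT-traceable. Downstream code should use eqx.filter_jit (not jax.jit) when JIT-compiling functions that accept an EqxCoronagraph as input, because the module also holds non-array fields (such as the create_psf callables) that jax.jit cannot accept as arguments: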

```python
import equinox as eqx

@eqx.filter_jit
def simulate(coro, x, y):
    psf = coro.create_psf(x, y)
    stellar = coro.stellar_intens(0.01)
    return psf + stellar
```

Classes

EqxCoronagraph

Pure JAX/Equinox coronagraph -- no astropy, no scipy, no I/O at runtime.

Functions

_scipy_to_interpax(scipy_spline)

Convert a scipy.interpolate.BSpline / make_interp_spline to interpax.

Module Contents

class yippy.eqx_coronagraph.EqxCoronagraph(yip_path=None, *, yippy_coro=None, ensure_psf_datacube=False, downsample_shape=None, aperture_radius_lod=0.7, contrast_floor=None, use_inscribed_diameter=False, x_symmetric=True, y_symmetric=True, **kwargs)

Bases: equinox.Module

Pure JAX/Equinox coronagraph – no astropy, no scipy, no I/O at runtime.

This module stores all coronagraph data as JAX arrays and interpax interpolators. It is a valid pytree and can be passed through any JAX transformation.

Fields fall into two categories when processed by eqx.filter_jit:

Dynamic (JAX arrays / eqx.Module leaves – values can change without recompiling, provided shapes and dtypes stay the same):

  • sky_trans, psf_datacube

  • All interpax.CubicSpline interpolators (they are eqx.Module instances whose leaves are JAX arrays)

Static (non-array Python objects – changing triggers recompilation, but filter_jit handles this automatically):

  • create_psf, create_psfs (callables / closures)

  • Scalar metadata (pixel_scale_lod, IWA, OWA, etc.)

  • psf_shape (tuple)

Calling a filter_jit-compiled function with a different EqxCoronagraph instance will cause recompilation, because the instances have different callable closures and likely different interpolator shapes. This is expected and unavoidable.

Attributes:
pixel_scale_lod: float
psf_shape: tuple[int, int]
center_x: float
center_y: float
IWA: float
OWA: float
frac_obscured: float
contrast_floor: float | None
create_psf: callable
create_psfs: callable
_stellar_ln_interp: interpax.CubicSpline
_throughput_interp: interpax.CubicSpline
_log_contrast_interp: interpax.CubicSpline
_occ_trans_interp: interpax.CubicSpline
_core_area_interp: interpax.CubicSpline
_core_mean_intensity_interp: interpax.CubicSpline
_core_mean_intensity_interp_2d: interpax.Interpolator2D | None
_has_2d_core_intensity: bool
sky_trans: jaxtyping.Array
psf_datacube: jaxtyping.Array | None

stellar_intens(stellar_diam_lod)

Interpolate the stellar intensity map for a given stellar diameter.

Args:

stellar_diam_lod: Stellar diameter in lam/D (unitless float).

Returns:

2-D JAX array containing the stellar intensity map.

Parameters:

stellar_diam_lod (float)

Return type:

jaxtyping.Array

throughput(separation_lod)

Evaluate coronagraph throughput at the given separation.

Args:

separation_lod: Separation from the star in lam/D.

Returns:

Scalar throughput value.

Parameters:

separation_lod (float)

Return type:

jaxtyping.Array
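Because each of these evaluators maps one separation to one scalar JAX array, standard JAX transformations compose with it. A hedged sketch: throughput below is a hypothetical stand-in built on jnp.interp over made-up tabulated values, not the real method (which presumably evaluates the _throughput_interp spline). It sweeps many separations with jax.vmap and takes a derivative with jax.grad.

```python
import jax
import jax.numpy as jnp

# Hypothetical tabulated throughput curve (illustrative numbers, not from a YIP).
SEP_GRID = jnp.array([0.0, 2.0, 4.0, 8.0, 16.0])
THRU_GRID = jnp.array([0.0, 0.10, 0.30, 0.35, 0.35])


def throughput(separation_lod):
    """Scalar-in, scalar-out stand-in (piecewise linear, not a cubic spline)."""
    return jnp.interp(separation_lod, SEP_GRID, THRU_GRID)


seps = jnp.linspace(0.0, 16.0, 9)      # 0, 2, 4, ..., 16 lam/D
curve = jax.vmap(throughput)(seps)     # one throughput value per separation
slope = jax.grad(throughput)(3.0)      # d(throughput)/d(separation) at 3 lam/D
```

With a real instance the same pattern would read jax.vmap(coro.throughput)(seps).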

raw_contrast(separation_lod)

Evaluate raw contrast at the given separation (log-space interpolation).

Args:

separation_lod: Separation from the star in lam/D.

Returns:

Scalar raw contrast value.

Parameters:

separation_lod (float)

Return type:

jaxtyping.Array
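Log-space interpolation means interpolating the logarithm of the tabulated contrast and exponentiating the result, which keeps the output positive and treats each decade of contrast evenly. A minimal sketch, assuming base-10 logs, made-up tabulated values, and piecewise-linear jnp.interp in place of the cubic spline the class stores in _log_contrast_interp:

```python
import jax.numpy as jnp

# Hypothetical tabulated raw-contrast curve (illustrative numbers, not from a YIP).
SEP_LOD = jnp.array([2.0, 4.0, 6.0])
LOG10_CONTRAST = jnp.log10(jnp.array([1e-8, 1e-10, 1e-10]))


def raw_contrast(separation_lod):
    # Interpolate log10(contrast), then map back to linear contrast.
    # jnp.interp (piecewise linear) stands in for the cubic spline.
    return 10.0 ** jnp.interp(separation_lod, SEP_LOD, LOG10_CONTRAST)


c_mid = raw_contrast(3.0)  # halfway between 1e-8 and 1e-10 in log space: about 1e-9
# Interpolating in linear space would instead give (1e-8 + 1e-10) / 2 = 5.05e-9.
```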

noise_floor_exosims(separation_lod, contrast_floor=1e-10, ppf=30.0)

Noise floor in EXOSIMS contrast convention.

Computed as max(|raw_contrast|, contrast_floor) / ppf.

Args:

separation_lod: Separation from the star in lambda/D.

contrast_floor: Minimum contrast value.

ppf: Post-processing noise suppression factor.

Returns:

Scalar noise floor value (EXOSIMS convention).

Return type:

jaxtyping.Array
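The EXOSIMS formula clamps the magnitude of the raw contrast from below at contrast_floor and then divides by the post-processing factor. The hypothetical helper noise_floor_from_contrast below (not part of yippy) applies that formula to raw-contrast values that have already been evaluated, rather than to separations, so the arithmetic can be checked by hand:

```python
import jax.numpy as jnp


def noise_floor_from_contrast(raw_contrast, contrast_floor=1e-10, ppf=30.0):
    """Hypothetical helper: EXOSIMS-style floor from precomputed raw contrast."""
    # Clamp |raw contrast| from below at contrast_floor, then divide by ppf.
    return jnp.maximum(jnp.abs(raw_contrast), contrast_floor) / ppf


floors = noise_floor_from_contrast(jnp.array([2e-9, 5e-11]))
# 2e-9 is above the floor:       2e-9 / 30  is about 6.67e-11
# 5e-11 is clamped up to 1e-10:  1e-10 / 30 is about 3.33e-12
```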

noise_floor_ayo(separation_lod, ppf=30.0)

Noise floor in AYO/pyEDITH per-pixel convention.

Computed as core_mean_intensity(sep) / ppf.

Args:

separation_lod: Separation from the star in lambda/D.

ppf: Post-processing noise suppression factor.

Returns:

Scalar noise floor value (AYO/pyEDITH convention).

Return type:

jaxtyping.Array

occulter_transmission(separation_lod)

Evaluate occulter transmission at the given separation.

Args:

separation_lod: Separation from the star in lam/D.

Returns:

Scalar occulter transmission value.

Parameters:

separation_lod (float)

Return type:

jaxtyping.Array

core_area(separation_lod)

Evaluate core area at the given separation.

Args:

separation_lod: Separation from the star in lam/D.

Returns:

Scalar core area value in (lam/D)**2.

Parameters:

separation_lod (float)

Return type:

jaxtyping.Array

core_mean_intensity(separation_lod, stellar_diam_lod=0.0)

Evaluate core mean intensity at the given separation.

Uses the 1D spline for the default diameter (point source) and, when a 2D interpolant is available, uses that interpolant for non-default stellar diameters.

Args:

separation_lod: Separation from the star in lambda/D.

stellar_diam_lod: Stellar angular diameter in lambda/D. Default is 0.0 (point source).

Returns:

Scalar core mean intensity value.

Return type:

jaxtyping.Array
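Inside a JIT-compiled function, the two kinds of choice in this method have to be made differently. A flag fixed per instance, like the static bool _has_2d_core_intensity, can drive an ordinary Python if, which is resolved once at trace time. A stellar diameter may be a traced value, and a Python if on a tracer raises an error, so that choice has to be made with jnp.where. The sketch below shows the pattern under stated assumptions: cmi_1d and cmi_2d are hypothetical toy functions standing in for the 1D spline and the 2D interpolant, and falling back to the 1D spline when no 2D interpolant exists is an assumption, not documented behavior.

```python
import functools

import jax
import jax.numpy as jnp


def cmi_1d(sep):
    """Hypothetical stand-in for the 1D (point-source) spline."""
    return 1e-10 * jnp.exp(-sep)


def cmi_2d(sep, diam):
    """Hypothetical stand-in for the 2D (separation, diameter) interpolant."""
    return 1e-10 * jnp.exp(-sep) * (1.0 + diam)


@functools.partial(jax.jit, static_argnames="has_2d")
def core_mean_intensity(sep, diam=0.0, *, has_2d=True):
    # has_2d is static: this Python `if` is resolved once, at trace time.
    if not has_2d:
        return cmi_1d(sep)  # assumed fallback when no 2D interpolant exists
    # diam may be a tracer, so select with jnp.where rather than a Python `if`.
    return jnp.where(diam == 0.0, cmi_1d(sep), cmi_2d(sep, diam))
```

Note that jnp.where evaluates both branches and then selects, which is the usual price for branching on traced values.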

yippy.eqx_coronagraph._scipy_to_interpax(scipy_spline)

Convert a scipy.interpolate.BSpline / make_interp_spline to interpax.

The scipy spline stores knots (t) and coefficients (c). We re-evaluate it on its interior knots, which for make_interp_spline output are drawn from the original data x-values, and build a fresh interpax interpolator from those (x, y) pairs.

For linear splines (k=1) we use interpax.Interpolator1D(method='linear'). For cubic splines (k=3) we use interpax.CubicSpline.

Args:

scipy_spline: A scipy BSpline or result of make_interp_spline.

Returns:

An interpax interpolator that approximates the same function.
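The knot-resampling idea can be sketched without interpax. In this hypothetical resample_spline (not the yippy implementation), scipy.interpolate.CubicSpline and np.interp stand in for interpax.CubicSpline and interpax.Interpolator1D(method='linear'), so the example runs with numpy and scipy alone. For make_interp_spline output, the distinct knots of a linear spline are exactly the data x-values, while a cubic spline with the default not-a-knot condition omits the second and second-to-last data points. The linear rebuild is therefore exact, and the cubic rebuild matches at the knots and only approximates in between.

```python
import numpy as np
from scipy.interpolate import CubicSpline, make_interp_spline


def resample_spline(scipy_spline):
    """Hypothetical sketch: rebuild an interpolator from a BSpline's knot values."""
    # Distinct knot locations; for make_interp_spline these lie on the data x-values.
    x = np.unique(scipy_spline.t)
    y = scipy_spline(x)
    if scipy_spline.k == 1:
        # Linear: the distinct knots are all the data points, so this is exact.
        return lambda xq: np.interp(xq, x, y)
    if scipy_spline.k == 3:
        # Cubic: refit through the knot values (exact at knots, approximate between).
        return CubicSpline(x, y)
    raise ValueError(f"unsupported spline degree k={scipy_spline.k}")


x_data = np.linspace(0.0, 2.0 * np.pi, 20)
xq = np.linspace(0.0, 2.0 * np.pi, 200)
cubic = make_interp_spline(x_data, np.sin(x_data), k=3)
linear = make_interp_spline(x_data, np.sin(x_data), k=1)
cubic_err = np.max(np.abs(resample_spline(cubic)(xq) - cubic(xq)))
linear_err = np.max(np.abs(resample_spline(linear)(xq) - linear(xq)))
```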