yippy.eqx_coronagraph
=====================

.. py:module:: yippy.eqx_coronagraph

.. autoapi-nested-parse::

   Pure JAX/Equinox coronagraph module.

   This module provides ``EqxCoronagraph``, a first-class ``eqx.Module`` that
   wraps the data loaded by :class:`yippy.Coronagraph` into a form that is fully
   compatible with ``jax.jit``, ``jax.vmap``, and other JAX transformations.

   Usage::

       from yippy import EqxCoronagraph

       # One-liner: pass a YIP path directly
       coro = EqxCoronagraph("/path/to/yip")

       # Or from an existing yippy Coronagraph
       from yippy import Coronagraph
       yippy_coro = Coronagraph("/path/to/yip")
       coro = EqxCoronagraph(yippy_coro=yippy_coro)

   All methods on ``EqxCoronagraph`` are JIT-traceable.  Because the module
   also holds non-array fields (callables and Python scalars), downstream code
   should use ``eqx.filter_jit`` rather than ``jax.jit`` when JIT-compiling
   functions that accept an ``EqxCoronagraph`` as input::

       import equinox as eqx

       @eqx.filter_jit
       def simulate(coro, x, y):
           psf = coro.create_psf(x, y)
           stellar = coro.stellar_intens(0.01)
           return psf + stellar



Classes
-------

.. autoapisummary::

   yippy.eqx_coronagraph.EqxCoronagraph


Functions
---------

.. autoapisummary::

   yippy.eqx_coronagraph._scipy_to_interpax


Module Contents
---------------

.. py:class:: EqxCoronagraph(yip_path = None, *, yippy_coro = None, ensure_psf_datacube = False, downsample_shape = None, aperture_radius_lod = 0.7, contrast_floor = None, use_inscribed_diameter = False, x_symmetric = True, y_symmetric = True, **kwargs)

   Bases: :py:obj:`equinox.Module`


   Pure JAX/Equinox coronagraph: no astropy, no scipy, and no I/O at runtime.

   This module stores all coronagraph data as JAX arrays and interpax
   interpolators.  It is a valid pytree and can be passed through any JAX
   transformation.

   Fields fall into two categories when processed by ``eqx.filter_jit``:

   **Dynamic** (JAX-array leaves, including those nested inside ``eqx.Module``
   fields; their values can change without recompiling, *provided shapes and
   dtypes stay the same*):

   - ``sky_trans``, ``psf_datacube``
   - All interpax interpolators (e.g. ``interpax.CubicSpline`` and, when
     present, ``interpax.Interpolator2D``), which are ``eqx.Module`` instances
     whose leaves are JAX arrays

   **Static** (non-array Python objects; changing one triggers recompilation,
   which ``filter_jit`` handles automatically):

   - ``create_psf``, ``create_psfs`` (callables / closures)
   - Scalar metadata (``pixel_scale_lod``, ``center_x``, ``center_y``, ``IWA``,
     ``OWA``, ``frac_obscured``, ``contrast_floor``)
   - ``_has_2d_core_intensity`` (bool)
   - ``psf_shape`` (tuple)

   Passing a different ``EqxCoronagraph`` instance to a ``filter_jit``-compiled
   function **will** trigger recompilation, because its callable closures differ
   and its interpolator shapes most likely differ too.  This is expected and
   unavoidable.
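
   The caching rule above can be illustrated with a toy, pure-Python model.
   This is a hypothetical sketch, not Equinox's implementation:
   ``toy_filter_jit`` and ``ToyCoro`` are illustrative names, and
   ``array.array`` stands in for a JAX array.  Array fields contribute only
   their shape to the cache key, while every other field contributes its
   value, so only a change to a static field or to an array's shape forces a
   "recompile".

   .. code-block:: python

      from array import array
      from dataclasses import dataclass, fields


      def toy_filter_jit(fn):
          """Toy model of filter_jit caching (hypothetical, for illustration)."""
          cache = {}

          def wrapper(module):
              key = []
              for f in fields(module):
                  value = getattr(module, f.name)
                  if isinstance(value, array):
                      # "Dynamic": only the shape is part of the cache key.
                      key.append((f.name, "shape", len(value)))
                  else:
                      # "Static": the value itself is part of the cache key.
                      key.append((f.name, "value", value))
              key = tuple(key)
              if key not in cache:
                  wrapper.compiles += 1  # a real JIT would re-trace here
                  cache[key] = fn
              return cache[key](module)

          wrapper.compiles = 0
          return wrapper


      @dataclass
      class ToyCoro:
          pixel_scale_lod: float  # static metadata
          sky_trans: array        # dynamic array leaf


      @toy_filter_jit
      def total_sky(coro):
          return sum(coro.sky_trans) * coro.pixel_scale_lod

   Calling ``total_sky`` with new ``sky_trans`` values of the same length
   reuses the cached entry; changing ``pixel_scale_lod`` or the array length
   adds a new one.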


   .. py:attribute:: pixel_scale_lod
      :type:  float


   .. py:attribute:: psf_shape
      :type:  tuple[int, int]


   .. py:attribute:: center_x
      :type:  float


   .. py:attribute:: center_y
      :type:  float


   .. py:attribute:: IWA
      :type:  float


   .. py:attribute:: OWA
      :type:  float


   .. py:attribute:: frac_obscured
      :type:  float


   .. py:attribute:: contrast_floor
      :type:  float | None


   .. py:attribute:: create_psf
      :type:  callable


   .. py:attribute:: create_psfs
      :type:  callable


   .. py:attribute:: _stellar_ln_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _throughput_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _log_contrast_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _occ_trans_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _core_area_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _core_mean_intensity_interp
      :type:  interpax.CubicSpline


   .. py:attribute:: _core_mean_intensity_interp_2d
      :type:  interpax.Interpolator2D | None


   .. py:attribute:: _has_2d_core_intensity
      :type:  bool


   .. py:attribute:: sky_trans
      :type:  jaxtyping.Array


   .. py:attribute:: psf_datacube
      :type:  jaxtyping.Array | None


   .. py:method:: stellar_intens(stellar_diam_lod)

      Interpolate the stellar intensity map for a given stellar diameter.

      Args:
          stellar_diam_lod: Stellar angular diameter in lambda/D.

      Returns:
          2-D JAX array containing the stellar intensity map.



   .. py:method:: throughput(separation_lod)

      Evaluate coronagraph throughput at the given separation.

      Args:
          separation_lod: Separation from the star in lambda/D.

      Returns:
          Scalar throughput value.



   .. py:method:: raw_contrast(separation_lod)

      Evaluate raw contrast at the given separation (log-space interpolation).

      Args:
          separation_lod: Separation from the star in lambda/D.

      Returns:
          Scalar raw contrast value.
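
      In log-space interpolation the spline is fit to the logarithm of the
      contrast and the result is exponentiated, so values between samples are
      interpolated geometrically rather than arithmetically.  A minimal
      pure-Python sketch, assuming base-10 logarithms and piecewise-linear
      interpolation for brevity (the module itself stores an
      ``interpax.CubicSpline`` in ``_log_contrast_interp``; ``log_interp`` is a
      hypothetical name):

      .. code-block:: python

         import math
         from bisect import bisect_right


         def log_interp(x, xs, ys):
             """Interpolate log10(ys) linearly in x, then map back with 10**."""
             logs = [math.log10(y) for y in ys]
             # Pick the segment containing x (end segments are used outside).
             i = min(max(bisect_right(xs, x) - 1, 0), len(xs) - 2)
             w = (x - xs[i]) / (xs[i + 1] - xs[i])
             return 10 ** (logs[i] + w * (logs[i + 1] - logs[i]))


         seps = [2.0, 4.0]          # separations in lambda/D
         contrasts = [1e-10, 1e-8]  # raw contrast samples
         mid = log_interp(3.0, seps, contrasts)  # about 1e-9 (geometric mean)
         # Linear interpolation would instead give (1e-10 + 1e-8) / 2 = 5.05e-9.

      Interpolating in log space keeps the result positive and avoids the
      strong pull toward the brighter sample that linear interpolation shows
      across orders of magnitude.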



   .. py:method:: noise_floor_exosims(separation_lod, contrast_floor = 1e-10, ppf = 30.0)

      Noise floor in the EXOSIMS contrast convention.

      Computed as ``max(|raw_contrast|, contrast_floor) / ppf``.

      Args:
          separation_lod: Separation from the star in lambda/D.
          contrast_floor: Lower bound applied to ``|raw_contrast|`` before
              dividing by ``ppf``.
          ppf: Post-processing noise suppression factor.

      Returns:
          Scalar noise floor value (EXOSIMS convention).
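
      The formula can be written out directly.  This is a plain-Python sketch
      that takes an already-evaluated raw contrast value; the function name is
      hypothetical, and a traceable JAX version would use ``jnp.maximum`` and
      ``jnp.abs`` instead of the Python built-ins:

      .. code-block:: python

         def noise_floor_exosims_sketch(raw_contrast, contrast_floor=1e-10, ppf=30.0):
             """max(|raw_contrast|, contrast_floor) / ppf, as documented above."""
             return max(abs(raw_contrast), contrast_floor) / ppf


         # Above the floor, the raw contrast is simply divided by ppf.
         deep = noise_floor_exosims_sketch(3e-9)      # about 1e-10
         # Below the floor, contrast_floor takes over.
         floored = noise_floor_exosims_sketch(1e-12)  # 1e-10 / 30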



   .. py:method:: noise_floor_ayo(separation_lod, ppf = 30.0)

      Noise floor in the AYO/pyEDITH per-pixel convention.

      Computed as ``core_mean_intensity(sep) / ppf``.

      Args:
          separation_lod: Separation from the star in lambda/D.
          ppf: Post-processing noise suppression factor.

      Returns:
          Scalar noise floor value (AYO/pyEDITH convention).



   .. py:method:: occulter_transmission(separation_lod)

      Evaluate occulter transmission at the given separation.

      Args:
          separation_lod: Separation from the star in lambda/D.

      Returns:
          Scalar occulter transmission value.



   .. py:method:: core_area(separation_lod)

      Evaluate core area at the given separation.

      Args:
          separation_lod: Separation from the star in lambda/D.

      Returns:
          Scalar core area in (lambda/D)**2.



   .. py:method:: core_mean_intensity(separation_lod, stellar_diam_lod = 0.0)

      Evaluate core mean intensity at the given separation.

      Uses the 1D spline for the default diameter (0.0, a point source)
      and, when a 2D interpolant is available, the 2D interpolant for
      nonzero stellar diameters.

      Args:
          separation_lod: Separation from the star in lambda/D.
          stellar_diam_lod: Stellar angular diameter in lambda/D.
              Default is 0.0 (point source).

      Returns:
          Scalar core mean intensity value.
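
      The dispatch described above can be sketched in plain Python.  The
      callables ``interp_1d`` and ``interp_2d`` are hypothetical stand-ins for
      the module's private interpolators, and the fallback to the 1D spline
      when no 2D interpolant exists is an assumption.  Under ``jax.jit`` a
      traced ``stellar_diam_lod`` cannot drive a Python ``if``, so a traceable
      version would select between the two results with ``jnp.where`` or
      ``jax.lax.cond``; the availability check (a static ``bool`` such as
      ``_has_2d_core_intensity``) can remain a plain ``if``:

      .. code-block:: python

         def core_mean_intensity_sketch(sep, stellar_diam_lod=0.0, *,
                                        interp_1d, interp_2d=None):
             # Availability check: static, so a Python `if` is fine even under jit.
             if interp_2d is None:
                 return interp_1d(sep)  # assumed fallback: no 2D data available
             # Value check: point source -> 1D spline, resolved star -> 2D table.
             if stellar_diam_lod == 0.0:
                 return interp_1d(sep)
             return interp_2d(sep, stellar_diam_lod)


         point = core_mean_intensity_sketch(
             5.0, interp_1d=lambda s: 1e-10, interp_2d=lambda s, d: 1e-9)
         resolved = core_mean_intensity_sketch(
             5.0, 0.5, interp_1d=lambda s: 1e-10, interp_2d=lambda s, d: 1e-9)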



.. py:function:: _scipy_to_interpax(scipy_spline)

   Convert a ``scipy.interpolate.BSpline`` (e.g. from ``make_interp_spline``) to interpax.

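   The scipy spline stores knots (``t``) and coefficients (``c``).  We
   re-evaluate it on its interior knots and build a fresh interpax
   interpolator from those (x, y) pairs.  For linear splines the interior
   knots are exactly the original data x-values; for cubic splines built with
   scipy's default not-a-knot condition, the second and second-to-last data
   points are not knots.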

   For linear splines (k=1) we use ``interpax.Interpolator1D(method='linear')``.
   For cubic splines (k=3) we use ``interpax.CubicSpline``.

   Args:
       scipy_spline: A ``scipy.interpolate.BSpline``, such as the result of
           ``make_interp_spline``.

   Returns:
       An interpax interpolator that approximates the same function.
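
   The resample-and-rebuild idea can be sketched without scipy or interpax
   for the linear (``k=1``) case.  All names below are hypothetical; the
   "original" spline is a plain Python callable standing in for a scipy
   ``BSpline``, and ``knots`` uses the layout that ``make_interp_spline``
   produces for ``k=1`` (each end value appears twice).  Because a linear
   spline's breakpoints are its data x-values, rebuilding from them is exact:

   .. code-block:: python

      from bisect import bisect_right


      def breakpoints(knots, k):
          """Distinct knot values between the repeated end knots."""
          out = []
          for t in knots[k:len(knots) - k]:
              if not out or t != out[-1]:
                  out.append(t)
          return out


      def linear_interp(xs, ys):
          """Piecewise-linear interpolant; end segments extend past the data."""
          def f(x):
              i = min(max(bisect_right(xs, x) - 1, 0), len(xs) - 2)
              w = (x - xs[i]) / (xs[i + 1] - xs[i])
              return ys[i] + w * (ys[i + 1] - ys[i])
          return f


      def resample_and_rebuild(spline, knots, k):
          """Evaluate `spline` on its breakpoints and build a new interpolant."""
          xs = breakpoints(knots, k)
          return linear_interp(xs, [spline(x) for x in xs])


      original = linear_interp([0.0, 1.0, 2.0, 3.0], [0.0, 2.0, 1.0, 5.0])
      knots = [0.0, 0.0, 1.0, 2.0, 3.0, 3.0]  # k=1 layout for x = [0, 1, 2, 3]
      rebuilt = resample_and_rebuild(original, knots, k=1)

   For cubic splines the same loop would feed the resampled points to a cubic
   constructor instead, and the result is then an approximation rather than
   an exact copy.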


