yippy.offjax
============

.. py:module:: yippy.offjax

.. autoapi-nested-parse::

   Module for handling off-axis PSFs using JAX.



Classes
-------

.. autoapisummary::

   yippy.offjax.OffJAX


Module Contents
---------------

.. py:class:: OffJAX(yip_dir, offax_data_file, offax_offsets_file, pixel_scale_arcsec, x_symmetric, y_symmetric, downsample_shape = None)

   Bases: :py:obj:`yippy.offax.OffAx`


   Class for handling off-axis PSFs using JAX.

   This class inherits from ``OffAx`` and uses JAX for optimized PSF
   computation. To save memory, the PSFs are stored in a single flat array,
   together with an index map from offset-grid positions to entries of that
   array.

   Attributes:
       pixel_scale_arcsec (Quantity):
           Pixel scale of the PSF data in lambda/D.
       center_x (Quantity):
           Central x position in the PSF data.
       center_y (Quantity):
           Central y position in the PSF data.
       flat_psfs (jnp.ndarray):
           PSF data stacked along the first axis into a single JAX array of
           shape (N_psfs, H, W).
       offset_to_flat_idx (jnp.ndarray):
           2D index map from an offset-grid index (x_idx, y_idx) to the
           corresponding index along the first axis of ``flat_psfs``.
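
   The flat storage and index map described above can be sketched in NumPy
   as follows. Only the names ``flat_psfs`` and ``offset_to_flat_idx`` come
   from this page; the grid sizes, the ``-1`` "no PSF" sentinel, and the
   ``lookup_psf`` helper are hypothetical. The sketch assumes the memory
   saving comes from storing only the grid points that actually have a PSF.
   A JAX version could build the arrays the same way and convert them with
   ``jnp.asarray``.

   .. code-block:: python

      import numpy as np

      # Hypothetical offset grids; the real ones are read from the YIP offsets file.
      x_offsets = np.array([0.0, 1.0, 2.0, 3.0])
      y_offsets = np.array([0.0, 1.0, 2.0])
      H = W = 8

      # Assumption: PSFs exist only at some (x_idx, y_idx) grid points.
      available = [(0, 0), (1, 0), (2, 0), (3, 0), (1, 1), (2, 2)]

      # Stack only the stored PSFs: shape (N_psfs, H, W).
      rng = np.random.default_rng(0)
      flat_psfs = rng.random((len(available), H, W))

      # 2D index map from (x_idx, y_idx) to a row of flat_psfs; -1 means "no PSF".
      offset_to_flat_idx = np.full((x_offsets.size, y_offsets.size), -1, dtype=np.int32)
      for flat_idx, (ix, iy) in enumerate(available):
          offset_to_flat_idx[ix, iy] = flat_idx


      def lookup_psf(ix, iy):
          """Return the stored PSF at grid index (ix, iy), or None if absent."""
          flat_idx = offset_to_flat_idx[ix, iy]
          return None if flat_idx < 0 else flat_psfs[flat_idx]


      # A dense (n_x, n_y, H, W) cube would hold 12 PSFs; the flat array holds 6.
      dense_count = x_offsets.size * y_offsets.size

   Lookups stay O(1): one read from the small integer map, then one index
   into the PSF stack.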


   .. py:attribute:: flat_psfs


   .. py:attribute:: flat_x_offsets


   .. py:attribute:: flat_y_offsets


   .. py:attribute:: x_offsets


   .. py:attribute:: y_offsets


   .. py:attribute:: offset_to_flat_idx


   .. py:attribute:: kx


   .. py:attribute:: ky


   .. py:attribute:: create_psfs_kernel


   .. py:attribute:: create_psfs_j
      :value: None



   .. py:attribute:: create_psf_kernel_single
      :value: None



   .. py:attribute:: create_psfs

      Create and return the PSFs at the specified off-axis positions.

      .. deprecated::
          The pure-Python implementation is deprecated.  Use
          ``Coronagraph`` (which uses ``OffJAX``) instead.



   .. py:attribute:: create_psf

      Create and return the PSF at the specified off-axis position.

      .. deprecated::
          The pure-Python implementation is deprecated.  Use
          ``Coronagraph`` (which uses ``OffJAX``) instead.

      Interpolates and returns the Point Spread Function (PSF) at the specified
      off-axis position (x, y). If the exact (x, y) position matches one of the
      PSFs in the YIP, that PSF is returned directly. Otherwise, the PSFs
      from the surrounding positions are combined using Gaussian weighting and
      Fourier interpolation to produce an interpolated PSF.
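
      The Fourier-interpolation step relies on the Fourier shift theorem:
      multiplying an image's FFT by a linear phase ramp translates the image,
      including by fractions of a pixel. The NumPy sketch below shows that idea
      together with a Gaussian-weighted blend of neighboring PSFs. It
      illustrates the general technique, not yippy's code: ``fourier_shift``,
      ``gaussian_weighted_psf``, ``gaussian_blob``, the pixel-unit offsets, and
      ``sigma`` are hypothetical, and it normalizes by one scalar weight per
      PSF rather than the per-pixel cumulative weight described in the notes.

      .. code-block:: python

         import numpy as np


         def fourier_shift(image, dx, dy):
             """Shift a 2D image by (dx, dy) pixels with the Fourier shift theorem.

             Columns are x and rows are y. The shift is circular, so content that
             leaves one edge re-enters at the opposite edge.
             """
             ny, nx = image.shape
             kx = np.fft.fftfreq(nx)  # spatial frequency along x, cycles per pixel
             ky = np.fft.fftfreq(ny)  # spatial frequency along y, cycles per pixel
             # A linear phase ramp in frequency space is a translation in image space.
             ramp = np.exp(-2j * np.pi * (kx[None, :] * dx + ky[:, None] * dy))
             return np.fft.ifft2(np.fft.fft2(image) * ramp).real


         def gaussian_weighted_psf(psfs, offsets, target, sigma=1.0):
             """Shift each PSF from its own offset to ``target`` and blend them.

             Simplification: one scalar Gaussian weight per PSF, normalized by
             the total weight, instead of a per-pixel cumulative weight.
             """
             total = np.zeros_like(psfs[0], dtype=float)
             weight_sum = 0.0
             for psf, (ox, oy) in zip(psfs, offsets):
                 dx, dy = target[0] - ox, target[1] - oy
                 w = np.exp(-(dx**2 + dy**2) / (2.0 * sigma**2))
                 total += w * fourier_shift(psf, dx, dy)
                 weight_sum += w
             return total / weight_sum


         def gaussian_blob(x, y, size=32, width=1.5):
             """Toy PSF: a Gaussian centered (x, y) pixels from the array center."""
             coords = np.arange(size) - size // 2
             yy, xx = np.meshgrid(coords, coords, indexing="ij")
             return np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2.0 * width**2))


         # Integer shifts reproduce np.roll (up to rounding).
         delta = np.zeros((16, 16))
         delta[4, 4] = 1.0
         moved = fourier_shift(delta, dx=2, dy=1)

         # Blend PSFs computed at x = 0 and x = 2 to estimate the PSF at x = 1.
         blended = gaussian_weighted_psf(
             [gaussian_blob(0, 0), gaussian_blob(2, 0)],
             offsets=[(0, 0), (2, 0)],
             target=(1, 0),
         )

      Non-integer ``dx``/``dy`` work the same way, which is what makes the
      phase-ramp approach attractive for sub-pixel positioning.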

      Args:
          x (float):
              The x-coordinate of the off-axis position.
          y (float):
              The y-coordinate of the off-axis position.

      Returns:
          np.ndarray:
              The interpolated PSF corresponding to the input (x, y) position.

      Notes:
          - If ``self.type`` is ``"1d"``, the (x, y) position is converted to a
            radial separation and angle for interpolation.
          - If ``self.type`` is ``"2dq"``, the (x, y) position is mirrored to
            the first quadrant, and the PSF is flipped accordingly after
            interpolation.
          - Gaussian weighting is used to combine the nearest PSFs when the
            exact (x, y) position does not match any precomputed PSF. The
            weighting is based on the distance from the input position.
          - The PSFs are shifted to align with the input position before
            combining, and the final PSF is normalized by the cumulative weight
            for each pixel.
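
      The ``"2dq"`` case can be illustrated with a small NumPy sketch that
      evaluates a PSF model only in the first quadrant and flips the result
      back, plus the polar conversion used for ``"1d"`` YIPs. The helpers
      ``psf_2dq``, ``to_polar``, and ``blob`` are hypothetical, and the sketch
      assumes columns are x, rows are y, the optical axis lies on the center
      pixel of an odd-sized array, and the angle is measured counterclockwise
      from +x; yippy's own conventions may differ.

      .. code-block:: python

         import numpy as np


         def psf_2dq(x, y, first_quadrant_psf):
             """Hypothetical: evaluate a PSF model defined only for x, y >= 0.

             The request is mirrored into the first quadrant and the result is
             flipped back. Assumes the optical axis is the center pixel of an
             odd-sized array, so np.flip mirrors exactly about it.
             """
             psf = first_quadrant_psf(abs(x), abs(y))
             if x < 0:
                 psf = np.flip(psf, axis=1)  # mirror left-right (x is columns)
             if y < 0:
                 psf = np.flip(psf, axis=0)  # mirror top-bottom (y is rows)
             return psf


         def to_polar(x, y):
             """Radial separation and angle in degrees, CCW from +x (assumed)."""
             return np.hypot(x, y), np.degrees(np.arctan2(y, x))


         def blob(x, y, size=15, width=1.5):
             """Toy PSF: a Gaussian offset (x, y) pixels from the center pixel."""
             coords = np.arange(size) - size // 2
             yy, xx = np.meshgrid(coords, coords, indexing="ij")
             return np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2.0 * width**2))


         # Request a second-quadrant PSF using only a first-quadrant evaluation.
         psf_neg = psf_2dq(-2, 1, blob)
         r, theta = to_polar(-2.0, 1.0)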




   .. py:method:: create_psfs_parallel(x_vals, y_vals)

      Create off-axis PSFs at multiple positions in parallel using JAX's
      ``shard_map``.

      Requires that ``hwoutils.set_host_device_count(N)`` has been called at
      program startup to expose *N* CPU devices to JAX.
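
      The general ``shard_map`` pattern this method names can be sketched as
      follows: positions are padded to a multiple of the device count, split
      across a 1D device mesh, and each device ``vmap``\ s over its own shard.
      This is not the ``OffJAX`` implementation. ``toy_psf`` and
      ``create_psfs_sharded`` are hypothetical, and setting ``XLA_FLAGS``
      directly is an assumed stand-in for ``hwoutils.set_host_device_count``,
      whose internals are not documented here. The ``shard_map`` import
      location differs between JAX releases, hence the fallback.

      .. code-block:: python

         import os

         # Assumption: hwoutils.set_host_device_count(N) has a similar effect.
         # This flag only takes effect if set before JAX initializes its backend.
         N_DEVICES = 4
         os.environ["XLA_FLAGS"] = (
             os.environ.get("XLA_FLAGS", "")
             + f" --xla_force_host_platform_device_count={N_DEVICES}"
         )

         import jax
         import jax.numpy as jnp
         import numpy as np
         from jax.sharding import Mesh, PartitionSpec as P

         try:  # older JAX releases
             from jax.experimental.shard_map import shard_map
         except ImportError:  # newer JAX exposes it at the top level
             from jax import shard_map


         def toy_psf(x, y, size=16):
             """Hypothetical stand-in for single-PSF creation: a Gaussian at (x, y)."""
             coords = jnp.arange(size) - size // 2
             yy, xx = jnp.meshgrid(coords, coords, indexing="ij")
             return jnp.exp(-((xx - x) ** 2 + (yy - y) ** 2) / 2.0)


         def create_psfs_sharded(x_vals, y_vals):
             """Split positions across CPU devices; each device vmaps over its shard."""
             devices = np.array(jax.devices("cpu"))
             mesh = Mesh(devices, axis_names=("pos",))
             n = x_vals.shape[0]
             pad = (-n) % devices.size  # sharded axis must divide evenly
             xs = jnp.pad(x_vals, (0, pad))
             ys = jnp.pad(y_vals, (0, pad))
             fn = shard_map(
                 jax.vmap(toy_psf),
                 mesh=mesh,
                 in_specs=(P("pos"), P("pos")),
                 out_specs=P("pos"),
             )
             return fn(xs, ys)[:n]  # drop the padded positions


         x_vals = jnp.linspace(-3.0, 3.0, 10)
         y_vals = jnp.zeros(10)
         psfs = create_psfs_sharded(x_vals, y_vals)

      If JAX was already initialized before the flag was set, only one CPU
      device is visible; the sketch still runs, just without parallelism.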



