"""Base class for off-axis PSF data loaded from offax_psfs.fits files.

The pure-Python PSF interpolation methods in this module are deprecated;
use OffJAX instead.
"""
import warnings
from multiprocessing import Pool
from pathlib import Path
import astropy.io.fits as pyfits
import astropy.units as u
import numpy as np
from astropy.units import Quantity
from hwoutils.fft import fft_shift
from hwoutils.transforms import downsample_psfs
from lod_unit import lod
from numpy.typing import NDArray
from yippy.util import convert_to_lod, create_shift_mask
from .logger import logger
class OffAx:
    """Base class for off-axis PSF handling.

    Handles YIP data loading, offset parsing, and symmetry detection.
    ``OffJAX`` inherits from this class and overrides ``create_psf`` /
    ``create_psfs`` with JAX-accelerated implementations.

    The pure-Python ``create_psf`` and ``create_psfs`` methods on this class
    are retained as a reference implementation but are not used in production.
    All ``Coronagraph`` instances now use ``OffJAX`` exclusively.

    This class loads and processes PSF data from the yield input package (YIP).
    Offset grids are classified as radially symmetric (``"1d"``, offsets along
    x only), quarter symmetric (``"2dq"``, all offsets >= 0) or full 2D
    (``"2df"``). The primary use is to interpolate the PSF data to a given x/y
    position by calling the OffAx object with that position; the call converts
    the position to lambda/D and then dispatches to ``create_psf`` (scalars)
    or ``create_psfs`` (arrays).

    Attributes:
        pixel_scale_arcsec (Quantity):
            Pixel scale of the PSF data in lambda/D (despite the name).
        center_x (Quantity):
            Central x position, in pixels, of the PSF images.
        center_y (Quantity):
            Central y position, in pixels, of the PSF images.
        type (str):
            Offset-grid classification: ``"1d"``, ``"2dq"`` or ``"2df"``.
        flat_psfs (NDArray):
            PSF stack indexed as (psf, y, x), including any PSFs mirrored
            across the axes for ``"2dq"`` grids.
        flat_offsets (NDArray):
            (N, 2) array with the (x, y) offset of each PSF in ``flat_psfs``.
        x_offsets, y_offsets (NDArray):
            Sorted unique x and y offsets, in lambda/D.
        offset_to_flat_idx (NDArray):
            Map from (x index, y index) to a ``flat_psfs`` index, or -1 when
            no PSF exists at that grid combination.
        max_offset_in_image (Quantity):
            Largest offset whose PSF center still falls inside the image.

    Constructor arguments are documented on ``__init__``.
    """

    def __init__(
        self,
        yip_dir: Path,
        offax_data_file: str,
        offax_offsets_file: str,
        pixel_scale_arcsec: Quantity,
        x_symmetric: bool,
        y_symmetric: bool,
        downsample_shape: tuple[int, int] | None = None,
    ) -> None:
        """Initializes the OffAx class by loading PSF and offset data from YIP.

        Args:
            yip_dir:
                Path to the directory containing PSF and offset data.
            offax_data_file:
                Name of the file containing the PSF data.
            offax_offsets_file:
                Name of the file containing the offsets data.
            pixel_scale_arcsec:
                Pixel scale of the PSF data in lambda/D.
            x_symmetric:
                Whether the PSFs are symmetric in x.
            y_symmetric:
                Whether the PSFs are symmetric in y.
            downsample_shape:
                Optional target shape (ny, nx) to downsample PSFs to.
                If provided, all PSFs will be resampled to this shape
                immediately after loading, conserving total flux.
                The pixel_scale_arcsec will be updated accordingly.

        Raises:
            NotImplementedError:
                If the offsets only vary along y (unsupported 1D layout).
        """
        # Pixel scale in lambda/D
        self.pixel_scale_arcsec = pixel_scale_arcsec
        self.x_symmetric = x_symmetric
        self.y_symmetric = y_symmetric
        yip_name = Path(yip_dir).stem

        # Off-axis PSF data (e.g. the planet), unitless intensity maps. The
        # stack is treated as (psf, y, x): axis 2 holds x (columns) and axis 1
        # holds y (rows), matching the flips and downsampling below.
        psfs = pyfits.getdata(Path(yip_dir, offax_data_file), 0)

        # Center of the pixel array, used for converting to lambda/D when
        # x/y positions are given in pixels.
        self.center_x = psfs.shape[2] / 2 * u.pix
        self.center_y = psfs.shape[1] / 2 * u.pix

        # Offset list, in units of lambda/D
        offsets = pyfits.getdata(Path(yip_dir, offax_offsets_file), 0)
        if offsets.ndim == 1:
            # 1D offsets are separations along x; pair each with y = 0.
            # column_stack gives (N, 2) directly, even when N == 2.
            offsets = np.column_stack((offsets, np.zeros_like(offsets)))
        elif offsets.shape[1] != 2 and offsets.shape[0] == 2:
            # Offsets stored transposed from the expected (N, 2) layout
            offsets = offsets.T
        assert len(offsets) == psfs.shape[0], (
            "Offsets and PSFs do not have the same number of elements"
        )

        ########################################################################
        # Determine the format of the input coronagraph files so we can handle #
        # the coronagraph correctly (e.g. radially symmetric in x direction)   #
        ########################################################################
        # Unique (sorted) offsets along each axis define the interpolation grid
        offsets_x = np.unique(offsets[:, 0])
        offsets_y = np.unique(offsets[:, 1])
        if len(offsets_x) == 1:
            logger.info(f"{yip_name} is radially symmetric")
            self.type = "1d"
            # Offsets along y only. Supporting this would require swapping
            # the axes (and probably rotating the PSFs by 90 degrees).
            raise NotImplementedError(
                "Verify that the PSFs are correct for this case!"
                " I don't have a test file for this case yet but I think they"
                " probably need to be rotated by 90 degrees."
            )
        elif len(offsets_y) == 1:
            logger.info(f"{yip_name} is radially symmetric")
            self.type = "1d"
        elif np.min(offsets) >= 0:
            logger.info(f"{yip_name} is quarterly symmetric")
            self.type = "2dq"
            if 0 not in offsets_x:
                # No PSFs at x = 0: mirror the PSFs at the smallest x offset
                # across x = 0 (flip columns) so interpolation can bracket
                # x values between -min_x and min_x.
                min_x = np.min(offsets_x)
                at_min_x = offsets[:, 0] == min_x
                offsets_x = np.insert(offsets_x, 0, -min_x)
                # Each mirrored PSF gets its own source offset with x negated,
                # keeping psfs and offsets aligned row-for-row even when not
                # every y offset is present at min_x.
                psfs = np.concatenate(
                    (np.flip(psfs[at_min_x], axis=2), psfs), axis=0
                )
                offsets = np.concatenate(
                    (offsets[at_min_x] * np.array([-1, 1]), offsets), axis=0
                )
            if 0 not in offsets_y:
                # Same for y (flip rows). This runs after the x mirroring, so
                # the (-min_x, -min_y) corner PSF is created as well.
                min_y = np.min(offsets_y)
                at_min_y = offsets[:, 1] == min_y
                offsets_y = np.insert(offsets_y, 0, -min_y)
                psfs = np.concatenate(
                    (np.flip(psfs[at_min_y], axis=1), psfs), axis=0
                )
                offsets = np.concatenate(
                    (offsets[at_min_y] * np.array([1, -1]), offsets), axis=0
                )
        else:
            logger.info(f"{yip_name} response is full 2D")
            self.type = "2df"

        # A lambda/D offset that represents the greatest separation where the
        # PSF's center is within the image. Applies to all coronagraph types.
        self.max_offset_in_image = psfs.shape[1] / 2 * u.pix * self.pixel_scale_arcsec

        if downsample_shape is not None:
            original_shape = psfs.shape[1:]
            logger.info(
                f"Downsampling PSFs from {original_shape} to {downsample_shape}"
            )
            psfs, new_pixscale = downsample_psfs(
                psfs, self.pixel_scale_arcsec.value, downsample_shape
            )
            # Update pixel scale with the new value (preserving units)
            self.pixel_scale_arcsec = new_pixscale * self.pixel_scale_arcsec.unit
            logger.info(f"New pixel scale: {self.pixel_scale_arcsec}")
            # Recompute quantities that depend on the PSF shape/pixel scale
            self.center_x = psfs.shape[2] / 2 * u.pix
            self.center_y = psfs.shape[1] / 2 * u.pix
            self.max_offset_in_image = (
                psfs.shape[1] / 2 * u.pix * self.pixel_scale_arcsec
            )

        self.flat_psfs = psfs
        self.flat_offsets = offsets
        self.n_psfs = len(offsets)

        # Map (x_idx, y_idx) -> flat_psfs index instead of building a sparse
        # 4D reshaped PSF array. -1 marks combinations without a PSF.
        self.x_inds = np.searchsorted(offsets_x, offsets[:, 0])
        self.y_inds = np.searchsorted(offsets_y, offsets[:, 1])
        self.offset_to_flat_idx = np.full(
            (len(offsets_x), len(offsets_y)), -1, dtype=np.int32
        )
        for flat_idx, (x_idx, y_idx) in enumerate(zip(self.x_inds, self.y_inds)):
            self.offset_to_flat_idx[x_idx, y_idx] = flat_idx

        self.x_offsets = offsets_x
        self.y_offsets = offsets_y
        self.x_range = np.array([self.x_offsets[0], self.x_offsets[-1]])
        self.y_range = np.array([self.y_offsets[0], self.y_offsets[-1]])
        # Store PSF shape for convenience
        self.psf_shape = self.flat_psfs.shape[1:]

    def get_psf_by_offset_idx(self, x_idx: int, y_idx: int):
        """Get PSF at the given offset indices.

        Args:
            x_idx (int):
                Index into x_offsets array.
            y_idx (int):
                Index into y_offsets array.

        Returns:
            np.ndarray:
                The PSF at the given offset indices (a view into
                ``flat_psfs``), or None if no PSF exists at that combination.
        """
        flat_idx = self.offset_to_flat_idx[x_idx, y_idx]
        if flat_idx >= 0:
            return self.flat_psfs[flat_idx]
        return None

    @staticmethod
    def _bracket_indices(grid: NDArray, value: float) -> NDArray:
        """Return the indices of the sorted ``grid`` points bracketing ``value``.

        An exact match yields a single index. Values outside the grid are
        clamped to the nearest edge point, rather than indexing past the end
        (IndexError) or wrapping around through index -1.
        """
        search = np.searchsorted(grid, value)
        if value in grid:
            return np.array([search])
        return np.unique(np.clip([search - 1, search], 0, len(grid) - 1))

    @staticmethod
    def _blend_psfs(
        near_psfs: NDArray, near_shifts: NDArray, sigma: float = 0.25
    ) -> NDArray:
        """Shift neighbouring PSFs onto a common position and blend them.

        Args:
            near_psfs:
                Stack of neighbouring PSFs, one per row of ``near_shifts``.
            near_shifts:
                (k, 2) array of per-PSF (dx, dy) shifts, in pixels.
            sigma:
                Width, in pixels, of the Gaussian distance weighting.

        Returns:
            NDArray:
                The blended PSF. Pixels whose accumulated mask weight is zero
                are set to 0 instead of 0/0 = NaN.
        """
        sq_dists = np.sum(near_shifts**2, axis=1)
        # Subtract the smallest squared distance before exponentiating: the
        # normalized weights are unchanged, but they can no longer all
        # underflow to 0 (which would make every weight NaN).
        weights = np.exp(-(sq_dists - sq_dists.min()) / (2 * sigma**2))
        weights /= weights.sum()

        psf = np.zeros_like(near_psfs[0])
        # Shifting a PSF right by one pixel leaves a blank column on the
        # left; the weight array tracks which PSFs contribute to each pixel.
        weight_array = np.zeros_like(psf)
        for near_psf, (dx, dy), weight in zip(near_psfs, near_shifts, weights):
            shifted_psf = fft_shift(near_psf, dx, dy)
            weight_mask = create_shift_mask(near_psf, dx, dy, weight)
            psf += weight_mask * shifted_psf
            weight_array += weight_mask
        return np.divide(
            psf, weight_array, out=np.zeros_like(psf), where=weight_array > 0
        )

    def create_psf(self, x: float, y: float):
        """Create and return the PSF at the specified off-axis position.

        .. deprecated::
            The pure-Python implementation is deprecated. Use
            ``Coronagraph`` (which uses ``OffJAX``) instead.

        If the exact (x, y) position matches one of the PSFs in the YIP, that
        PSF is returned directly (a view into ``flat_psfs``, not a copy).
        Otherwise the PSFs at the bracketing grid positions are shifted onto
        the target position and combined with Gaussian distance weighting.

        Args:
            x (float):
                The x-coordinate of the off-axis position, in lambda/D.
            y (float):
                The y-coordinate of the off-axis position, in lambda/D.

        Returns:
            np.ndarray:
                The interpolated PSF corresponding to the input (x, y) position.

        Raises:
            ValueError:
                If none of the bracketing grid combinations has a PSF.

        Notes:
            - If `self.type` is "1d", the position is converted to a radial
              separation along +x for interpolation, and the resulting PSF is
              then flipped (if symmetric) and translated to (x, y).
            - If `self.type` is "2dq", the position is mirrored into the first
              quadrant, and the PSF is flipped accordingly after interpolation.
            - Positions beyond the offset grid use the nearest edge PSFs,
              shifted to the requested position.
        """
        warnings.warn(
            "OffAx.create_psf is deprecated; use Coronagraph (OffJAX) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        flip_lr, flip_ud = False, False

        # Exact match: return the stored PSF directly
        if x in self.x_offsets and y in self.y_offsets:
            flat_idx = self.offset_to_flat_idx[
                np.searchsorted(self.x_offsets, x),
                np.searchsorted(self.y_offsets, y),
            ]
            if flat_idx >= 0:
                return self.flat_psfs[flat_idx]

        # Translate the position into the frame covered by the offset grid
        if self.type == "1d":
            flip_lr, flip_ud = x < 0, y < 0
            _x, _y = np.sqrt(x**2 + y**2), 0
        elif self.type == "2dq":
            flip_lr, flip_ud = x < 0, y < 0
            _x, _y = abs(x), abs(y)
        else:
            _x, _y = x, y

        # Grid neighbours of (_x, _y) along each axis
        x_inds = self._bracket_indices(self.x_offsets, _x)
        y_inds = self._bracket_indices(self.y_offsets, _y)
        x_vals, y_vals = self.x_offsets[x_inds], self.y_offsets[y_inds]
        near_inds = np.array(np.meshgrid(x_inds, y_inds)).T.reshape(-1, 2)
        near_offsets = np.array(np.meshgrid(x_vals, y_vals)).T.reshape(-1, 2)
        flat_indices = self.offset_to_flat_idx[near_inds[:, 0], near_inds[:, 1]]

        # -1 marks grid combinations without a PSF; indexing with it would
        # silently select the last PSF, so drop those neighbours.
        has_psf = flat_indices >= 0
        if not np.any(has_psf):
            raise ValueError(
                f"No PSFs available near ({x}, {y}) to interpolate from"
            )
        near_offsets = near_offsets[has_psf]
        near_psfs = self.flat_psfs[flat_indices[has_psf]]

        # Shift (in pixels) that moves each neighbour onto (_x, _y)
        near_shifts = (
            np.array([_x, _y]) - near_offsets
        ) / self.pixel_scale_arcsec.value
        if len(near_psfs) == 1 and not np.any(near_shifts):
            # The only neighbour already sits exactly at (_x, _y)
            psf = near_psfs[0]
        else:
            psf = self._blend_psfs(near_psfs, near_shifts)

        # Undo the frame translation: flip if the PSFs are symmetric, then
        # shift the remaining distance to (x, y).
        if self.x_symmetric and flip_lr:
            psf = np.fliplr(psf)
            remaining_x_shift = x + _x
        else:
            remaining_x_shift = x - _x
        if self.y_symmetric and flip_ud:
            psf = np.flipud(psf)
            remaining_y_shift = y + _y
        else:
            remaining_y_shift = y - _y
        if remaining_x_shift != 0 or remaining_y_shift != 0:
            psf = fft_shift(
                psf,
                remaining_x_shift / self.pixel_scale_arcsec.value,
                remaining_y_shift / self.pixel_scale_arcsec.value,
            )
        return psf

    def create_psfs(self, x: NDArray, y: NDArray) -> NDArray:
        """Create and return the PSFs at the specified off-axis positions.

        .. deprecated::
            The pure-Python implementation is deprecated. Use
            ``Coronagraph`` (which uses ``OffJAX``) instead.

        Args:
            x (NDArray):
                x positions in lambda/D.
            y (NDArray):
                y positions in lambda/D; must have the same length as ``x``.

        Returns:
            NDArray:
                Stack of PSFs with shape (len(x), *psf_shape).

        Raises:
            ValueError:
                If ``x`` and ``y`` have different lengths.
        """
        warnings.warn(
            "OffAx.create_psfs is deprecated; use Coronagraph (OffJAX) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        psfs = np.empty((len(x), *self.flat_psfs.shape[1:]))
        for i, (xi, yi) in enumerate(zip(x, y, strict=True)):
            psfs[i] = self.create_psf(xi, yi)
        return psfs

    def _to_lod_value(self, pos, center, lam=None, D=None, dist=None):
        """Convert a position to a plain (unitless) lambda/D value.

        Quantities not already in lambda/D are converted with
        ``convert_to_lod`` using ``center`` and the PSF pixel scale; the
        units are then stripped. Non-Quantity inputs are returned unchanged.
        """
        if isinstance(pos, Quantity):
            if pos.unit != lod:
                pos = convert_to_lod(
                    pos, center, self.pixel_scale_arcsec, lam, D, dist
                )
            pos = pos.value
        return pos

    def create_psfs_parallel(
        self,
        x: np.ndarray,
        y: np.ndarray,
        lam=None,
        D=None,
        dist=None,
        workers: int = 4,
    ) -> np.ndarray:
        """Compute the PSF datacube for every (x[i], y[j]) pair with multiprocessing.

        .. deprecated::
            The pure-Python implementation is deprecated. Use
            ``Coronagraph`` (which uses ``OffJAX``) instead.

        Args:
            x (np.ndarray):
                Array of x positions.
            y (np.ndarray):
                Array of y positions.
            lam (astropy.units.Quantity):
                Wavelength of the observation
            D (astropy.units.Quantity):
                Diameter of the telescope
            dist (astropy.units.Quantity):
                Distance to the system
            workers (int):
                Number of parallel processes to use.

        Returns:
            np.ndarray:
                PSF datacube with shape (len(x), len(y), height, width).
        """
        warnings.warn(
            "OffAx.create_psfs_parallel is deprecated; use Coronagraph (OffJAX)"
            " instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Strip units exactly like __call__ does, so workers receive plain
        # lambda/D values rather than Quantities.
        x = self._to_lod_value(x, self.center_x, lam, D, dist)
        y = self._to_lod_value(y, self.center_y, lam, D, dist)

        # One task per x[i]: a column of identical x values paired with the
        # full y array. An explicit float dtype keeps fractional x values
        # from being truncated when y is an integer array.
        args = [(np.full(np.shape(y), xi, dtype=float), y) for xi in x]

        # Each task returns a (len(y), height, width) stack from create_psfs
        with Pool(processes=workers) as pool:
            psf_rows = pool.starmap(self.create_psfs, args)
        # Stack the rows into (len(x), len(y), height, width)
        return np.stack(psf_rows, axis=0)

    def __call__(
        self, x: Quantity, y: Quantity, lam=None, D=None, dist=None
    ) -> NDArray:
        """Return the PSF at the given x/y position.

        This function (via util.convert_to_lod) has the following assumptions
        on the x/y values provided:
            - If units are pixels, they follow the 00LL convention. As in the
              (0,0) point is the lower left corner of the image.
            - If the x/y values are in lambda/D, angular, or length units the
              (0,0) point is the center of the image, where the star is
              (hopefully) located.

        Args:
            x (astropy.units.Quantity):
                x position. Can be either units of pixel, lod, an angular
                unit (e.g. arcsec), or a length unit (e.g. AU)
            y (astropy.units.Quantity):
                y position. Can be either units of pixel, lod, an angular
                unit (e.g. arcsec), or a length unit (e.g. AU)
            lam (astropy.units.Quantity):
                Wavelength of the observation
            D (astropy.units.Quantity):
                Diameter of the telescope
            dist (astropy.units.Quantity):
                Distance to the system

        Returns:
            NDArray:
                The PSF at the given x/y position (a stack of PSFs when x/y
                are arrays).
        """
        x = self._to_lod_value(x, self.center_x, lam, D, dist)
        y = self._to_lod_value(y, self.center_y, lam, D, dist)
        if np.isscalar(x) and np.isscalar(y):
            return self.create_psf(x, y)
        return self.create_psfs(x, y)