"""JAX functions for image processing operations."""
from functools import partial
import jax.numpy as jnp
from hwoutils.fft import fft_shift_x, fft_shift_y
from jax import lax
[docs]
def create_shift_mask(psf, shift_x, shift_y, x_grid, y_grid, fill_val=1):
    """Build a weight mask that excludes pixels invalidated by a shift.

    A Fourier shift wraps content around the image edges, so the band of
    pixels that the shift moved in from outside the original frame carries no
    real data and must not enter any subsequent average. The mask is built
    from elementwise comparisons only (no ``lax.cond``), which keeps it
    branch-free and cheap under JIT.

    Args:
        psf (jax.numpy.ndarray):
            The 2D PSF image being shifted; only its shape and dtype are used.
        shift_x (float):
            Shift along x (columns), in pixels.
        shift_y (float):
            Shift along y (rows), in pixels.
        x_grid (jax.numpy.ndarray):
            Column-index grid compared against the valid band.
        y_grid (jax.numpy.ndarray):
            Row-index grid compared against the valid band.
        fill_val (float, optional):
            Value given to valid pixels. Defaults to 1.

    Returns:
        jax.numpy.ndarray:
            Array of ``psf.dtype`` equal to ``fill_val`` on valid pixels and
            0 elsewhere.
    """
    n_rows, n_cols = psf.shape

    def _in_band(grid, shift, size):
        # A positive shift invalidates the first ceil(shift) indices; a
        # negative shift invalidates the last ceil(-shift) indices.
        lead = jnp.ceil(jnp.maximum(0, shift)).astype(jnp.int32)
        trail = jnp.ceil(jnp.maximum(0, -shift)).astype(jnp.int32)
        return (grid >= lead) & (grid < (size - trail))

    valid = _in_band(x_grid, shift_x, n_cols) & _in_band(y_grid, shift_y, n_rows)
    return valid.astype(psf.dtype) * fill_val
[docs]
def get_near_inds_offsets_1D(
    x_offsets: jnp.ndarray, y_offsets: jnp.ndarray, _x: float, _y: float
):
    """Find the two grid nodes that bracket ``_x`` on a 1D offset grid.

    Args:
        x_offsets (jnp.ndarray): 1D array of x offsets, sorted ascending.
        y_offsets (jnp.ndarray): 1D array of y offsets. Only ``y_offsets[0]``
            is used, because a 1D grid lies along a single y value.
        _x (float): Query x-coordinate.
        _y (float): Unused. It is accepted so the signature matches
            ``get_near_inds_offsets_2D``.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - near_inds: integer array of shape (2, 2). Each row is an
              ``(x_index, y_index)`` pair, lower neighbor first. The y index
              is always 0.
            - near_offsets: array of shape (2, 2). Each row holds the
              ``(x_offset, y_offset)`` values at the matching indices.
            If ``_x`` is at or below the first node, or above the last node,
            both rows refer to that end node.
    """
    # side="left" gives the first index whose offset is >= _x; the bracketing
    # pair is (that index - 1, that index), clipped into the grid.
    _x_ind = jnp.searchsorted(x_offsets, _x, side="left")
    x_ind_low = jnp.clip(_x_ind - 1, 0, x_offsets.size - 1)
    x_ind_high = jnp.clip(_x_ind, 0, x_offsets.size - 1)
    y_ind = 0
    near_inds = jnp.array([[x_ind_low, y_ind], [x_ind_high, y_ind]])
    y_val = y_offsets[y_ind]
    near_offsets = jnp.array(
        [[x_offsets[x_ind_low], y_val], [x_offsets[x_ind_high], y_val]]
    )
    return near_inds, near_offsets
[docs]
def create_avg_psf_1D(
    x: float,
    y: float,
    pixel_scale_arcsec: float,
    x_offsets: jnp.ndarray,
    y_offsets: jnp.ndarray,
    x_grid: jnp.ndarray,
    y_grid: jnp.ndarray,
    x_phasor: jnp.ndarray,
    y_phasor: jnp.ndarray,
    reshaped_psfs: jnp.ndarray,
):
    """Create the PSF at an off-axis position from a 1D (radial) PSF grid.

    The query is collapsed to ``(r, 0)`` with ``r = sqrt(x**2 + y**2)``. The
    two grid PSFs bracketing ``r`` are Fourier-shifted onto ``(r, 0)`` and
    weighted by a Gaussian in their shift length (sigma = 0.25 pixels). They
    are then averaged pixel by pixel, using only the pixels each shift left
    valid. The result is positioned at ``(r, 0)``, not at the signed
    ``(x, y)``.

    Args:
        x (float): Query x position, in the same units as the offsets.
        y (float): Query y position, in the same units as the offsets.
        pixel_scale_arcsec (float): Offset units per pixel. It converts
            offset differences into pixel shifts.
        x_offsets (jnp.ndarray): Sorted 1D grid of x offsets.
        y_offsets (jnp.ndarray): y offsets (only the first one is used).
        x_grid (jnp.ndarray): Column-index grid for the validity masks.
        y_grid (jnp.ndarray): Row-index grid for the validity masks.
        x_phasor (jnp.ndarray): Precomputed Fourier-shift terms along x.
        y_phasor (jnp.ndarray): Precomputed Fourier-shift terms along y.
        reshaped_psfs (jnp.ndarray): PSFs indexed as ``[x_idx, y_idx]``.

    Returns:
        jnp.ndarray: The averaged PSF. Pixels that no shifted PSF covers are 0.
    """
    _x, _y = convert_xy_1D(x, y)
    near_inds, near_offsets = get_near_inds_offsets_1D(x_offsets, y_offsets, _x, _y)
    near_psfs = reshaped_psfs[near_inds[:, 0], near_inds[:, 1]]
    near_shifts = (jnp.array([_x, _y]) - near_offsets) / pixel_scale_arcsec
    near_diffs = jnp.linalg.norm(near_shifts, axis=1)
    sigma = 0.25
    # The 1e-16 floor keeps the normalization finite when every exponential
    # underflows to zero (all neighbors many pixels away).
    weights = jnp.exp(-(near_diffs**2) / (2 * sigma**2)) + 1e-16
    weights /= weights.sum()

    def _shift_neighbor(i):
        return shift_and_mask(
            near_psfs[i],
            near_shifts[i, 0],
            near_shifts[i, 1],
            weights[i],
            x_grid,
            y_grid,
            x_phasor,
            y_phasor,
        )

    # Accumulate the weighted PSFs and their weight masks (left to right).
    psf, weight_array = _shift_neighbor(0)
    for i in range(1, near_psfs.shape[0]):
        shifted_psf, mask = _shift_neighbor(i)
        psf = psf + shifted_psf
        weight_array = weight_array + mask

    # Double-where: substitute 1 for zero weights *before* dividing, so the
    # discarded branch never holds inf. With a single where, jax.grad would
    # produce NaN gradients (0 * inf) at uncovered pixels.
    covered = weight_array != 0
    safe_weights = jnp.where(covered, weight_array, 1.0)
    safe_reciprocal = jnp.where(covered, 1.0 / safe_weights, 0.0)
    return psf * safe_reciprocal
[docs]
def create_avg_psf_2DQ(
    x: float,
    y: float,
    pixel_scale_arcsec: float,
    x_offsets: jnp.ndarray,
    y_offsets: jnp.ndarray,
    x_grid: jnp.ndarray,
    y_grid: jnp.ndarray,
    x_phasor: jnp.ndarray,
    y_phasor: jnp.ndarray,
    reshaped_psfs: jnp.ndarray,
):
    """Create the PSF at an off-axis position from a quarter-symmetric grid.

    The query is folded to ``(|x|, |y|)``. The four grid PSFs at the corners
    of the enclosing grid cell are Fourier-shifted onto that folded position
    and weighted by a Gaussian in their shift length (sigma = 0.25 pixels).
    They are then averaged pixel by pixel, using only the pixels each shift
    left valid. No mirroring back to the original quadrant is done here.

    Args:
        x (float): Query x position, in the same units as the offsets.
        y (float): Query y position, in the same units as the offsets.
        pixel_scale_arcsec (float): Offset units per pixel. It converts
            offset differences into pixel shifts.
        x_offsets (jnp.ndarray): Sorted 1D grid of x offsets.
        y_offsets (jnp.ndarray): Sorted 1D grid of y offsets.
        x_grid (jnp.ndarray): Column-index grid for the validity masks.
        y_grid (jnp.ndarray): Row-index grid for the validity masks.
        x_phasor (jnp.ndarray): Precomputed Fourier-shift terms along x.
        y_phasor (jnp.ndarray): Precomputed Fourier-shift terms along y.
        reshaped_psfs (jnp.ndarray): PSFs indexed as ``[x_idx, y_idx]``.

    Returns:
        jnp.ndarray: The averaged PSF. Pixels that no shifted PSF covers are 0.
    """
    _x, _y = convert_xy_2DQ(x, y)
    near_inds, near_offsets = get_near_inds_offsets_2D(x_offsets, y_offsets, _x, _y)
    near_psfs = reshaped_psfs[near_inds[:, 0], near_inds[:, 1]]
    near_shifts = (jnp.array([_x, _y]) - near_offsets) / pixel_scale_arcsec
    near_diffs = jnp.linalg.norm(near_shifts, axis=1)
    sigma = 0.25
    # The 1e-16 floor keeps the normalization finite when every exponential
    # underflows to zero (all neighbors many pixels away).
    weights = jnp.exp(-(near_diffs**2) / (2 * sigma**2)) + 1e-16
    weights /= weights.sum()

    def _shift_neighbor(i):
        return shift_and_mask(
            near_psfs[i],
            near_shifts[i, 0],
            near_shifts[i, 1],
            weights[i],
            x_grid,
            y_grid,
            x_phasor,
            y_phasor,
        )

    # Accumulate the weighted PSFs and their weight masks (left to right).
    psf, weight_array = _shift_neighbor(0)
    for i in range(1, near_psfs.shape[0]):
        shifted_psf, mask = _shift_neighbor(i)
        psf = psf + shifted_psf
        weight_array = weight_array + mask

    # Double-where: substitute 1 for zero weights *before* dividing, so the
    # discarded branch never holds inf. With a single where, jax.grad would
    # produce NaN gradients (0 * inf) at uncovered pixels.
    covered = weight_array != 0
    safe_weights = jnp.where(covered, weight_array, 1.0)
    safe_reciprocal = jnp.where(covered, 1.0 / safe_weights, 0.0)
    return psf * safe_reciprocal
[docs]
def get_near_inds_offsets_2D(
    x_offsets: jnp.ndarray, y_offsets: jnp.ndarray, _x: float, _y: float
):
    """Find the four grid nodes at the corners of the cell containing (_x, _y).

    Args:
        x_offsets (jnp.ndarray): 1D array of x offsets, sorted ascending.
        y_offsets (jnp.ndarray): 1D array of y offsets, sorted ascending.
        _x (float): Query x-coordinate.
        _y (float): Query y-coordinate.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - near_inds: integer array of shape (4, 2) of ``(x_index,
              y_index)`` pairs, ordered (lo, lo), (lo, hi), (hi, lo),
              (hi, hi), with x varying slowest.
            - near_offsets: array of shape (4, 2) holding the matching
              ``(x_offset, y_offset)`` values.
            Along an axis where the query is outside the grid, the low and
            high indices coincide at the nearest end node.
    """
    last_x = x_offsets.size - 1
    last_y = y_offsets.size - 1
    # side="left" returns the first index whose offset is >= the query; the
    # bracketing pair is (that index - 1, that index), clipped into the grid.
    pos_x = jnp.searchsorted(x_offsets, _x, side="left")
    pos_y = jnp.searchsorted(y_offsets, _y, side="left")
    x_lo, x_hi = jnp.clip(pos_x - 1, 0, last_x), jnp.clip(pos_x, 0, last_x)
    y_lo, y_hi = jnp.clip(pos_y - 1, 0, last_y), jnp.clip(pos_y, 0, last_y)

    corner_x = jnp.array([x_lo, x_lo, x_hi, x_hi])
    corner_y = jnp.array([y_lo, y_hi, y_lo, y_hi])
    near_inds = jnp.stack([corner_x, corner_y], axis=1)
    near_offsets = jnp.stack([x_offsets[corner_x], y_offsets[corner_y]], axis=1)
    return near_inds, near_offsets
[docs]
def shift_and_mask(
    near_psf, shift_x, shift_y, weight, x_grid, y_grid, x_phasor, y_phasor
):
    """Fourier-shift one PSF and weight it by its validity mask.

    Args:
        near_psf (jax.numpy.ndarray): PSF image to shift.
        shift_x (float): Shift along x, in pixels.
        shift_y (float): Shift along y, in pixels.
        weight (float): Weight given to this PSF's valid pixels.
        x_grid (jax.numpy.ndarray): Column-index grid for the mask.
        y_grid (jax.numpy.ndarray): Row-index grid for the mask.
        x_phasor (jax.numpy.ndarray): Precomputed Fourier-shift terms along x.
        y_phasor (jax.numpy.ndarray): Precomputed Fourier-shift terms along y.

    Returns:
        Tuple[jax.numpy.ndarray, jax.numpy.ndarray]:
            - The shifted PSF multiplied by its weight mask.
            - The weight mask itself (``weight`` on valid pixels, else 0).
    """
    # The mask depends only on the PSF's shape/dtype and the shift, so it can
    # be built from the unshifted input.
    weight_mask = create_shift_mask(
        near_psf, shift_x, shift_y, x_grid, y_grid, fill_val=weight
    )
    moved = fft_shift_y(fft_shift_x(near_psf, shift_x, x_phasor), shift_y, y_phasor)
    return weight_mask * moved, weight_mask
[docs]
def convert_xy_1D(x, y):
    """Collapse ``(x, y)`` onto the radial axis used by 1D grids.

    Returns:
        tuple: ``(sqrt(x**2 + y**2), 0)``.
    """
    radius = jnp.sqrt(x**2 + y**2)
    return radius, 0
[docs]
def convert_xy_2DQ(x, y):
    """Fold ``(x, y)`` into the first quadrant for quarter-symmetric grids.

    Returns:
        tuple: ``(|x|, |y|)``.
    """
    abs_x = jnp.abs(x)
    abs_y = jnp.abs(y)
    return abs_x, abs_y
[docs]
def convert_xy_2D(x, y):
    """Return ``(x, y)`` unchanged; full-plane grids need no folding."""
    converted = (x, y)
    return converted
[docs]
def basic_shift_val(input_val, converted_val, pixel_scale_arcsec):
    """Return the pixel shift that moves ``converted_val`` to ``input_val``."""
    offset = input_val - converted_val
    return offset / pixel_scale_arcsec
[docs]
def sym_shift_val(input_val, converted_val, pixel_scale_arcsec):
    """Return the pixel shift for a PSF mirrored onto the side of ``input_val``.

    ``converted_val`` is taken with the sign of ``input_val`` (0 counts as
    positive). The result is the shift from that mirrored position to
    ``input_val``, in pixels.
    """
    mirror = jnp.where(input_val >= 0, 1.0, -1.0)
    source_pos = mirror * converted_val
    return (input_val - source_pos) / pixel_scale_arcsec
[docs]
def x_basic_shift(input_val, converted_val, PSF, pixel_scale_arcsec, x_phasor):
    """Fourier-shift ``PSF`` along x from ``converted_val`` to ``input_val``."""
    return fft_shift_x(
        PSF, basic_shift_val(input_val, converted_val, pixel_scale_arcsec), x_phasor
    )
[docs]
def y_basic_shift(input_val, converted_val, PSF, pixel_scale_arcsec, y_phasor):
    """Fourier-shift ``PSF`` along y from ``converted_val`` to ``input_val``."""
    return fft_shift_y(
        PSF, basic_shift_val(input_val, converted_val, pixel_scale_arcsec), y_phasor
    )
[docs]
def x_symmetric_shift(input_val, converted_val, PSF, pixel_scale_arcsec, x_phasor):
    """Place ``PSF`` at x = ``input_val`` using mirror symmetry about x = 0.

    For negative ``input_val``, the PSF is first mirrored left-right
    (``fliplr``). It is then Fourier-shifted by the amount returned from
    ``sym_shift_val``.
    """
    mirrored = jnp.where(input_val < 0, jnp.fliplr(PSF), PSF)
    pixels = sym_shift_val(input_val, converted_val, pixel_scale_arcsec)
    return fft_shift_x(mirrored, pixels, x_phasor)
[docs]
def y_symmetric_shift(input_val, converted_val, PSF, pixel_scale_arcsec, y_phasor):
    """Place ``PSF`` at y = ``input_val`` using mirror symmetry about y = 0.

    For negative ``input_val``, the PSF is first mirrored up-down
    (``flipud``). It is then Fourier-shifted by the amount returned from
    ``sym_shift_val``.
    """
    mirrored = jnp.where(input_val < 0, jnp.flipud(PSF), PSF)
    pixels = sym_shift_val(input_val, converted_val, pixel_scale_arcsec)
    return fft_shift_y(mirrored, pixels, y_phasor)
[docs]
def fft_rotate_jax(image, rot_deg):
    """Rotate an image using exact 90-degree turns plus Fourier shears.

    The angle is split into a whole number of 90-degree turns, applied
    exactly with ``rot90``, and a residual in (-45, 45] degrees. The residual
    is applied with three Fourier-domain shears (Larkin et al. 1997), and the
    shears are skipped when the residual is zero.

    Args:
        image (jax.numpy.ndarray):
            The input image. The ``lax.switch`` over ``rot90`` branches needs
            every branch to return the same shape, so in practice the image
            must be square.
        rot_deg (float):
            Rotation angle in degrees. Positive values rotate
            counterclockwise, negative values clockwise.

    Returns:
        jax.numpy.ndarray:
            The rotated image.
    """
    # With the origin in the lower-left corner, a counterclockwise rotation
    # corresponds to the negated angle in array-index space.
    residual_deg, quarter_turns = decompose_angle_jax(-rot_deg)
    turned = rot90_traceable(image, k=quarter_turns)
    return lax.cond(
        residual_deg != 0.0,
        lambda img: rotate_with_shear(img, residual_deg),
        lambda img: img,
        turned,
    )
[docs]
def rotate_with_shear(image, rot_deg):
    """Rotate ``image`` by ``rot_deg`` degrees as three Fourier-domain shears.

    This helper keeps the ``lax.cond`` branch in ``fft_rotate_jax`` a plain
    call. It applies an x-shear with factor ``tan(theta / 2)``, a y-shear with
    factor ``-sin(theta)``, and the same x-shear again.

    Args:
        image (jax.numpy.ndarray):
            The input image to be rotated.
        rot_deg (float):
            The rotation angle in degrees.

    Returns:
        jax.numpy.ndarray:
            The rotated image.
    """
    theta = jnp.deg2rad(rot_deg)
    x_factor = jnp.tan(theta / 2)
    y_factor = -jnp.sin(theta)
    x_freqs, x_dists, y_freqs, y_dists = fft_shear_setup(image)
    passes = (
        (fft_shear_x, x_factor, x_freqs, x_dists),
        (fft_shear_y, y_factor, y_freqs, y_dists),
        (fft_shear_x, x_factor, x_freqs, x_dists),
    )
    for shear, factor, freqs, dists in passes:
        image = shear(image, factor, freqs, dists)
    return image
[docs]
def fft_shear_setup(image):
    """Precompute the frequency and distance grids used by the Fourier shears.

    The grids are defined on the zero-padded image that ``fft_shear_x`` and
    ``fft_shear_y`` build internally. The padding is ``int(1.5 * n)`` pixels
    per side, with ``n = image.shape[0]``. Each shear multiplies the centred
    spectrum by ``exp(-2j * pi * factor * freqs * dists)``. For that to be a
    pure shear, the frequency must vary along the transformed axis and the
    distance along the other axis:

    - x shear (FFT along axis 1): row ``i`` moves in x by
      ``factor * (i - center_y)``. ``x_freqs`` varies along axis 1 and
      ``x_dists`` along axis 0.
    - y shear (FFT along axis 0): column ``j`` moves in y by
      ``factor * (j - center_x)``. ``y_freqs`` varies along axis 0 and
      ``y_dists`` along axis 1.

    Args:
        image (jax.numpy.ndarray):
            The image that will be sheared; only its shape is used.

    Returns:
        tuple:
            - jax.numpy.ndarray: centred frequencies for the x shear.
            - jax.numpy.ndarray: row offsets from the centre, for the x shear.
            - jax.numpy.ndarray: centred frequencies for the y shear.
            - jax.numpy.ndarray: column offsets from the centre, for the y shear.
            All four arrays have the padded image's shape.
    """
    # Must match the padding in fft_shear_x / fft_shear_y.
    n_pad = int(1.5 * image.shape[0])
    padded_height = image.shape[0] + 2 * n_pad
    padded_width = image.shape[1] + 2 * n_pad
    center_y = (padded_height - 1) / 2
    center_x = (padded_width - 1) / 2
    grid_y, grid_x = jnp.mgrid[0:padded_height, 0:padded_width]

    # Centred frequency axes, in the order produced by fftshift(fft(...)).
    freqs_w = jnp.fft.fftshift(jnp.fft.fftfreq(padded_width))
    freqs_h = jnp.fft.fftshift(jnp.fft.fftfreq(padded_height))

    # Frequency must lie along the transformed axis. The previous transposed
    # layout had the same leading term, but for even sizes it added a
    # per-row/column constant phase (an amplitude loss after jnp.real) and a
    # half-pixel translation.
    x_freqs = jnp.broadcast_to(freqs_w[None, :], (padded_height, padded_width))
    x_dists = grid_y - center_y
    y_freqs = jnp.broadcast_to(freqs_h[:, None], (padded_height, padded_width))
    y_dists = grid_x - center_x
    return x_freqs, x_dists, y_freqs, y_dists
[docs]
def fft_shear_x(image, shear_factor, x_freqs, x_dists):
    """Shear an image along x with a Fourier-domain phase ramp.

    The image is zero-padded by ``int(1.5 * n)`` pixels per side
    (``n = image.shape[0]``) and Fourier-transformed along axis 1. The
    centred spectrum is multiplied by
    ``exp(-2j * pi * shear_factor * x_freqs * x_dists)``, transformed back,
    and cropped to the original ``n x n`` window. The crop uses ``n`` for
    both axes, so the image is assumed square.

    Args:
        image (jax.numpy.ndarray):
            The input image to be sheared.
        shear_factor (float):
            The shear factor.
        x_freqs (jax.numpy.ndarray):
            Frequency grid, broadcastable to the padded shape, as returned by
            ``fft_shear_setup``.
        x_dists (jax.numpy.ndarray):
            Distance grid, broadcastable to the padded shape, as returned by
            ``fft_shear_setup``.

    Returns:
        jax.numpy.ndarray:
            The real part of the sheared image, with the padding removed.
    """
    n_pixels = image.shape[0]
    # Must match the padding used by fft_shear_setup.
    n_pad = int(1.5 * n_pixels)
    img_edge = n_pad + n_pixels
    padded = jnp.pad(image, n_pad, mode="constant")
    # ifftshift moves the array centre to index 0 and fftshift undoes it.
    # For even sizes the two are identical, so even-sized results match the
    # old all-fftshift sequence exactly. For odd sizes (4n - 1 after padding)
    # the old sequence did not cancel and rolled the image.
    padded = jnp.fft.ifftshift(padded)
    padded = jnp.fft.fftshift(jnp.fft.fft(padded, axis=1))
    # Rows are back in their original order; columns are centred frequencies.
    padded = jnp.exp(-2j * jnp.pi * shear_factor * x_freqs * x_dists) * padded
    padded = jnp.fft.ifftshift(padded)
    padded = jnp.fft.ifft(padded, axis=1)
    padded = jnp.fft.fftshift(padded)
    return jnp.real(padded[n_pad:img_edge, n_pad:img_edge])
[docs]
def fft_shear_y(image, shear_factor, y_freqs, y_dists):
    """Shear an image along y with a Fourier-domain phase ramp.

    The image is zero-padded by ``int(1.5 * n)`` pixels per side
    (``n = image.shape[0]``) and Fourier-transformed along axis 0. The
    centred spectrum is multiplied by
    ``exp(-2j * pi * shear_factor * y_freqs * y_dists)``, transformed back,
    and cropped to the original ``n x n`` window. The crop uses ``n`` for
    both axes, so the image is assumed square.

    Args:
        image (jax.numpy.ndarray):
            The input image to be sheared.
        shear_factor (float):
            The shear factor.
        y_freqs (jax.numpy.ndarray):
            Frequency grid, broadcastable to the padded shape, as returned by
            ``fft_shear_setup``.
        y_dists (jax.numpy.ndarray):
            Distance grid, broadcastable to the padded shape, as returned by
            ``fft_shear_setup``.

    Returns:
        jax.numpy.ndarray:
            The real part of the sheared image, with the padding removed.
    """
    n_pixels = image.shape[0]
    # Must match the padding used by fft_shear_setup.
    n_pad = int(1.5 * n_pixels)
    img_edge = n_pad + n_pixels
    padded = jnp.pad(image, n_pad, mode="constant")
    # ifftshift moves the array centre to index 0 and fftshift undoes it.
    # For even sizes the two are identical, so even-sized results match the
    # old all-fftshift sequence exactly. For odd sizes (4n - 1 after padding)
    # the old sequence did not cancel and rolled the image.
    padded = jnp.fft.ifftshift(padded)
    padded = jnp.fft.fftshift(jnp.fft.fft(padded, axis=0))
    # Columns are back in their original order; rows are centred frequencies.
    padded = jnp.exp(-2j * jnp.pi * shear_factor * y_freqs * y_dists) * padded
    padded = jnp.fft.ifftshift(padded)
    padded = jnp.fft.ifft(padded, axis=0)
    padded = jnp.fft.fftshift(padded)
    return jnp.real(padded[n_pad:img_edge, n_pad:img_edge])
[docs]
def decompose_angle_jax(angle):
    """Split an angle into 90-degree turns plus a residual in (-45, 45].

    Any real angle is accepted; it is first wrapped into [0, 360). The result
    satisfies ``angle == 90 * n_rot + residual (mod 360)``.

    Args:
        angle (float or jax.numpy.ndarray):
            The input angle in degrees. Plain Python numbers, JAX scalars
            and tracers are all accepted.

    Returns:
        tuple:
            - jax.numpy.ndarray: residual rotation angle in (-45, 45] degrees.
            - jax.numpy.ndarray: number of 90-degree rotations, in [0, 3].
    """
    # jnp.asarray lets plain Python floats through; without it the `.astype`
    # call below raises AttributeError outside of jit.
    angle = jnp.asarray(angle) % 360
    n_rot = (angle // 90).astype(int)
    adjusted_angle = angle % 90
    # A residual above 45 degrees borrows one more quarter turn.
    over_half = adjusted_angle > 45
    adjusted_angle = jnp.where(over_half, adjusted_angle - 90, adjusted_angle)
    n_rot = jnp.where(over_half, n_rot + 1, n_rot)
    # Four quarter turns (angles in (315, 360), or float rounding up to 360)
    # are the identity.
    n_rot = jnp.where(n_rot == 4, 0, n_rot)
    return adjusted_angle, n_rot
[docs]
def rot90_traceable(m, k=1, axes=(0, 1)):
    """Rotate an array by ``k`` quarter turns in the plane given by ``axes``.

    This is a trace-friendly ``jnp.rot90``: ``k`` may be a traced integer,
    because the four possible rotations are dispatched with ``lax.switch``.
    All branches must return the same shape, so the rotated plane must be
    square.

    Args:
        m (jax.numpy.ndarray):
            The input array to be rotated.
        k (int, optional):
            Number of 90-degree rotations (any integer; reduced modulo 4).
        axes (tuple, optional):
            The two axes that define the rotation plane.

    Returns:
        jax.numpy.ndarray:
            The rotated array.
    """
    quarter_turns = k % 4
    branches = [partial(jnp.rot90, k=turns, axes=axes) for turns in range(4)]
    return lax.switch(quarter_turns, branches, m)
[docs]
def synthesize_psf_separable(
    x,
    y,
    pixel_scale_arcsec,
    flat_psfs,
    x_offsets,
    y_offsets,
    offset_to_flat_idx,
    kx,
    ky,
    n_pad,
    x_symmetric,
    y_symmetric,
    input_type,
    max_offset,
):
    """Synthesize a PSF by (bi)linear interpolation of Fourier-shifted neighbors.

    The grid PSFs that bracket the query are looked up. Each one is
    translated to ``(x, y)`` with a separable Fourier shift (``rfft`` along x,
    ``fft`` along y). The shifted spectra are combined with linear (1D) or
    bilinear (2D) weights, and a single inverse transform follows. Even-sized
    padding keeps the transforms stable. ``flat_psfs`` is indexed through
    ``offset_to_flat_idx`` to save memory.

    Args:
        x (float):
            X coordinate in lambda/D.
        y (float):
            Y coordinate in lambda/D.
        pixel_scale_arcsec (float):
            Pixel scale in lambda/D per pixel.
        flat_psfs (jnp.ndarray):
            Flat array of PSFs with shape (N_psfs, H, W).
        x_offsets (jnp.ndarray):
            Sorted array of unique x offsets.
        y_offsets (jnp.ndarray):
            Sorted array of unique y offsets.
        offset_to_flat_idx (jnp.ndarray):
            2D array mapping (x_idx, y_idx) -> flat_psfs index.
        kx (jnp.ndarray):
            FFT frequencies for the x-axis. They must broadcast against the
            ``rfft`` output, i.e. have length ``(W + 2 * n_pad) // 2 + 1``.
        ky (jnp.ndarray):
            FFT frequencies for the y-axis, of length ``H + 2 * n_pad``.
        n_pad (int):
            Number of padding pixels on each side.
        x_symmetric (bool):
            Whether PSFs are symmetric about x=0 (mirror them for x < 0).
        y_symmetric (bool):
            Whether PSFs are symmetric about y=0 (mirror them for y < 0).
        input_type (str):
            Type of input data. "1d" is a radial grid queried at
            ``(sqrt(x**2 + y**2), 0)``. "2dq" is a quarter grid queried at
            ``(|x|, |y|)``. "2df" is a full-plane grid queried at ``(x, y)``.
        max_offset (float):
            Maximum valid radius for "1d" inputs. Positions beyond it return
            zeros. Use a value <= 0 (e.g. -1) to disable the check.

    Returns:
        jnp.ndarray:
            Synthesized PSF of shape (H, W) at the requested (x, y) position,
            clipped to be non-negative.
    """
    # Neighbor lookup. Full-plane ("2df") grids must not be folded into the
    # first quadrant, otherwise negative positions pick the wrong neighbors.
    if input_type == "1d":
        _x, _y = convert_xy_1D(x, y)
        inds, offsets = get_near_inds_offsets_1D(x_offsets, y_offsets, _x, _y)
    else:
        if input_type == "2df":
            _x, _y = convert_xy_2D(x, y)
        else:
            _x, _y = convert_xy_2DQ(x, y)
        inds, offsets = get_near_inds_offsets_2D(x_offsets, y_offsets, _x, _y)
    # Map (x_idx, y_idx) pairs to rows of flat_psfs.
    flat_indices = offset_to_flat_idx[inds[:, 0], inds[:, 1]]
    neighbors = flat_psfs[flat_indices]

    # Interpolation weights from the query's position inside the neighbor cell.
    if input_type == "1d":
        # Neighbors are [x_low, y0] and [x_high, y0].
        x0 = offsets[0, 0]
        x1 = offsets[1, 0]
        # Avoid division by zero when the query is at/outside a grid end.
        span_x = jnp.where(x1 - x0 == 0, 1.0, x1 - x0)
        u = (_x - x0) / span_x
        weights = jnp.array([1.0 - u, u])
        # Coincident neighbors: put all weight on the first.
        weights = jnp.where(x1 == x0, jnp.array([1.0, 0.0]), weights)
    else:
        # Corner order from get_near_inds_offsets_2D:
        # 0: (x0, y0), 1: (x0, y1), 2: (x1, y0), 3: (x1, y1)
        x0 = offsets[0, 0]
        x1 = offsets[3, 0]
        y0 = offsets[0, 1]
        y1 = offsets[3, 1]
        span_x = jnp.where(x1 - x0 == 0, 1.0, x1 - x0)
        span_y = jnp.where(y1 - y0 == 0, 1.0, y1 - y0)
        u = (_x - x0) / span_x
        v = (_y - y0) / span_y
        # Bilinear weights. When only one axis collapses, the duplicated
        # corners share a PSF and shift, so their weights still sum correctly.
        weights = jnp.array(
            [(1.0 - u) * (1.0 - v), (1.0 - u) * v, u * (1.0 - v), u * v]
        )
        weights = jnp.where(
            (x1 == x0) & (y1 == y0), jnp.array([1.0, 0.0, 0.0, 0.0]), weights
        )

    # Symmetry: mirror the neighbors (and their offsets) into the query's half.
    eff_offsets_x = offsets[:, 0]
    eff_offsets_y = offsets[:, 1]
    if x_symmetric:
        eff_offsets_x = jnp.where(x < 0, -offsets[:, 0], offsets[:, 0])
        neighbors = lax.cond(
            x < 0, lambda p: jnp.flip(p, axis=2), lambda p: p, neighbors
        )
    if y_symmetric:
        eff_offsets_y = jnp.where(y < 0, -offsets[:, 1], offsets[:, 1])
        neighbors = lax.cond(
            y < 0, lambda p: jnp.flip(p, axis=1), lambda p: p, neighbors
        )

    # Pixel shifts carrying each neighbor onto (x, y). For "1d" the neighbors
    # sit at (r_i, y_offsets[0]), so they are translated, not rotated.
    shift_x = (x - eff_offsets_x) / pixel_scale_arcsec
    shift_y = (y - eff_offsets_y) / pixel_scale_arcsec

    # Pad: (K, H, W) -> (K, H_pad, W_pad).
    pad_width = ((0, 0), (n_pad, n_pad), (n_pad, n_pad))
    neighbors = jnp.pad(neighbors, pad_width)
    w_pad = neighbors.shape[-1]

    # Real FFT along x: (K, H_pad, W_pad // 2 + 1), then the x phase ramp.
    spec = jnp.fft.rfft(neighbors, axis=2)
    spec = spec * jnp.exp(-2j * jnp.pi * shift_x[:, None, None] * kx[None, None, :])
    # Complex FFT along y, then the y phase ramp.
    spec = jnp.fft.fft(spec, axis=1)
    spec = spec * jnp.exp(-2j * jnp.pi * shift_y[:, None, None] * ky[None, :, None])

    # Weighted sum over neighbors, then a single inverse transform.
    summed_spec = jnp.sum(spec * weights[:, None, None], axis=0)
    res = jnp.fft.ifft(summed_spec, axis=0)
    # n=w_pad reconstructs the exact padded width (even or odd).
    res = jnp.fft.irfft(res, n=w_pad, axis=1)

    # Unpad each axis by its own size (supports non-square PSFs).
    h_orig, w_orig = flat_psfs.shape[-2:]
    psf = res[n_pad : n_pad + h_orig, n_pad : n_pad + w_orig]
    psf = jnp.maximum(psf, 0.0)

    # Out-of-bounds masking is only implemented for 1D (radial) grids.
    # input_type is a static Python string, so branch in Python.
    if input_type == "1d":
        outside = (max_offset > 0) & (jnp.sqrt(x**2 + y**2) > max_offset)
        psf = jnp.where(outside, 0.0, psf)
    return psf
[docs]
def synthesize_psf_idw(
    x,
    y,
    pixel_scale_arcsec,
    flat_psfs,
    flat_x_offsets,
    flat_y_offsets,
    kx,
    ky,
    n_pad,
    x_symmetric,
    y_symmetric,
    input_type,
    k_neighbors=4,
):
    """Synthesizes PSF using Power-2 Polar IDW for irregular grids.

    Distance is the tangent-plane arc length at the query,
    ``d = sqrt(dr**2 + (r_q * dtheta)**2)``, and weights use ``1 / d**2``.
    The polar metric matches the radial/azimuthal anisotropy of coronagraph
    PSF structure, and the squared exponent acts as a soft nearest-neighbor
    on smooth grids. See yippy-paper Section 4 for the kernel-selection
    rationale.

    Args:
        x (float): Query x position, in the same units as the offsets.
        y (float): Query y position, in the same units as the offsets.
        pixel_scale_arcsec (float): Offset units per pixel.
        flat_psfs (jnp.ndarray): PSFs with shape (N_psfs, H, W).
        flat_x_offsets (jnp.ndarray): x offset of each PSF, shape (N_psfs,).
        flat_y_offsets (jnp.ndarray): y offset of each PSF, shape (N_psfs,).
        kx (jnp.ndarray): FFT frequencies for x. They must broadcast against
            the ``rfft`` output, length ``(W + 2 * n_pad) // 2 + 1``.
        ky (jnp.ndarray): FFT frequencies for y, length ``H + 2 * n_pad``.
        n_pad (int): Padding pixels added on each side before the FFTs.
        x_symmetric (bool): Mirror the neighbors about x=0 when x < 0.
        y_symmetric (bool): Mirror the neighbors about y=0 when y < 0.
        input_type (str): "2dq" folds the query to ``(|x|, |y|)``, "1d"
            collapses it to ``(sqrt(x**2 + y**2), 0)``, and anything else
            uses ``(x, y)`` directly.
        k_neighbors (int, optional): Number of nearest PSFs to blend. It is
            clamped to the number of available PSFs. Defaults to 4.

    Returns:
        jnp.ndarray: Synthesized PSF of shape (H, W), clipped to be
        non-negative.
    """
    # Handle symmetry (map to the first quadrant for 2DQ).
    if input_type == "2dq":
        _x, _y = convert_xy_2DQ(x, y)
    elif input_type == "1d":
        _x, _y = convert_xy_1D(x, y)
    else:
        _x, _y = x, y
    # Polar arc-length distance, linearized at the query.
    r_q = jnp.hypot(_x, _y)
    theta_q = jnp.arctan2(_y, _x)
    r_n = jnp.hypot(flat_x_offsets, flat_y_offsets)
    theta_n = jnp.arctan2(flat_y_offsets, flat_x_offsets)
    dr = r_n - r_q
    # Wrap the angle difference into [-pi, pi).
    dtheta = (theta_n - theta_q + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
    arc = r_q * dtheta
    dists = jnp.sqrt(dr**2 + arc**2)
    # k nearest neighbors via top_k on negated distances. top_k needs a static
    # k that does not exceed the number of candidates.
    k = min(k_neighbors, dists.shape[-1])
    neg_dists, inds = lax.top_k(-dists, k)
    near_dists = -neg_dists
    neighbors = flat_psfs[inds]
    # Power-2 inverse-distance weights; the epsilon avoids division by zero
    # when the query lands exactly on a node.
    weights = 1.0 / (near_dists + 1e-10) ** 2
    weights = weights / jnp.sum(weights)
    # Mirror the neighbors (and their offsets) based on the signed query.
    eff_offsets_x = flat_x_offsets[inds]
    eff_offsets_y = flat_y_offsets[inds]
    if x_symmetric:
        neighbors = lax.cond(
            x < 0, lambda p: jnp.flip(p, axis=2), lambda p: p, neighbors
        )
        eff_offsets_x = jnp.where(x < 0, -eff_offsets_x, eff_offsets_x)
    if y_symmetric:
        neighbors = lax.cond(
            y < 0, lambda p: jnp.flip(p, axis=1), lambda p: p, neighbors
        )
        eff_offsets_y = jnp.where(y < 0, -eff_offsets_y, eff_offsets_y)
    # Pixel shifts carrying each neighbor onto (x, y).
    shift_x = (x - eff_offsets_x) / pixel_scale_arcsec
    shift_y = (y - eff_offsets_y) / pixel_scale_arcsec
    # Pad, then apply separable Fourier shifts.
    pad_width = ((0, 0), (n_pad, n_pad), (n_pad, n_pad))
    neighbors = jnp.pad(neighbors, pad_width)
    spec = jnp.fft.rfft(neighbors, axis=2)
    spec = spec * jnp.exp(-2j * jnp.pi * shift_x[:, None, None] * kx[None, None, :])
    spec = jnp.fft.fft(spec, axis=1)
    spec = spec * jnp.exp(-2j * jnp.pi * shift_y[:, None, None] * ky[None, :, None])
    # Weighted sum over neighbors, then a single inverse transform.
    summed_spec = jnp.sum(spec * weights[:, None, None], axis=0)
    res = jnp.fft.ifft(summed_spec, axis=0)
    res = jnp.fft.irfft(res, n=neighbors.shape[-1], axis=1)
    # Unpad each axis by its own size (supports non-square PSFs).
    h_orig, w_orig = flat_psfs.shape[-2:]
    psf = res[n_pad : n_pad + h_orig, n_pad : n_pad + w_orig]
    return jnp.maximum(psf, 0.0)