yippy.jax_funcs

JAX functions for image processing operations.

Functions

create_shift_mask(psf, shift_x, shift_y, x_grid, y_grid)

Create a mask to identify valid pixels to average.

get_near_inds_offsets_1D(x_offsets, y_offsets, _x, _y)

Computes the nearest indices and offsets for the given _x position in 1D.

create_avg_psf_1D(x, y, pixel_scale_arcsec, x_offsets, ...)

Creates and returns the PSF at the specified off-axis position using JAX, for 1D input data.

create_avg_psf_2DQ(x, y, pixel_scale_arcsec, ...)

Creates and returns the PSF at the specified off-axis position using JAX, for 2D quarter-symmetric input data.

get_near_inds_offsets_2D(x_offsets, y_offsets, _x, _y)

Computes the nearest indices and offsets for the given (_x, _y) position in 2D.

shift_and_mask(near_psf, shift_x, shift_y, weight, ...)

Shifts the PSF in x and y directions and applies a weight mask.

convert_xy_1D(x, y)

Converts x and y to 1D coordinates.

convert_xy_2DQ(x, y)

Converts x and y to 2D quarter symmetric coordinates.

convert_xy_2D(x, y)

Converts x and y to 2D coordinates.

basic_shift_val(input_val, converted_val, ...)

Calculates the shift in pixels for a basic shift.

sym_shift_val(input_val, converted_val, pixel_scale_arcsec)

Calculates the shift in pixels for a symmetric shift.

x_basic_shift(input_val, converted_val, PSF, ...)

Shifts the PSF to the specified x position.

y_basic_shift(input_val, converted_val, PSF, ...)

Shifts the PSF to the specified y position.

x_symmetric_shift(input_val, converted_val, PSF, ...)

Shifts the PSF to the specified position assuming symmetry about x=0.

y_symmetric_shift(input_val, converted_val, PSF, ...)

Shifts the PSF to the specified position assuming symmetry about y=0.

fft_rotate_jax(image, rot_deg)

Rotate an image by a specified angle using Fourier-based shear operations.

rotate_with_shear(image, rot_deg)

Rotate an image by a specified angle using Fourier-based shear operations.

fft_shear_setup(image)

Compute the frequency and distance arrays used by the Fourier-domain shear operations.

fft_shear_x(image, shear_factor, x_freqs, x_dists)

Perform a shear operation in the Fourier domain along the x-axis.

fft_shear_y(image, shear_factor, y_freqs, y_dists)

Perform a shear operation in the Fourier domain along the y-axis.

decompose_angle_jax(angle)

Decompose an angle in [0, 360) into a residual angle in (-45, 45] and a number of 90-degree rotations.

rot90_traceable(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

synthesize_psf_separable(x, y, pixel_scale_arcsec, ...)

Synthesizes a PSF using separable 1D FFTs.

synthesize_psf_idw(x, y, pixel_scale_arcsec, ...[, ...])

Synthesizes a PSF using power-2 polar inverse distance weighting (IDW) for irregular grids.

Module Contents

yippy.jax_funcs.create_shift_mask(psf, shift_x, shift_y, x_grid, y_grid, fill_val=1)

Create a mask to identify valid pixels to average.

When the PSF is shifted, pixels that move in from outside the original image are empty and should not be included in the final average. This version avoids lax.cond for better performance under JIT.

Args:
psf (jax.numpy.ndarray):

The PSF image to shift.

shift_x (float):

The shift in the x direction.

shift_y (float):

The shift in the y direction.

x_grid (jax.numpy.ndarray):

The x-coordinate grid.

y_grid (jax.numpy.ndarray):

The y-coordinate grid.

fill_val (float, optional):

The value to fill the mask with.

Returns:
jax.numpy.ndarray:

The mask to identify valid pixels to average.
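A minimal sketch of the idea, assuming the mask is built from pixel-index grids and that a pixel counts as valid only when its pre-shift source position lies inside the original image. The helper name `valid_shift_mask` is hypothetical and this is not the yippy implementation. It only shows how `jnp.where` gives a branch-free mask under JIT.

```python
import jax.numpy as jnp


def valid_shift_mask(shape, shift_x, shift_y):
    """Hypothetical helper: 1 where a shifted pixel holds real data, 0 where it is empty."""
    ny, nx = shape
    # Pixel-index grids (assumption: the real x_grid/y_grid may be defined differently).
    y_grid, x_grid = jnp.mgrid[0:ny, 0:nx]
    # After a shift of (shift_x, shift_y), the pixel at (x, y) holds what was at
    # (x - shift_x, y - shift_y) before the shift.
    src_x = x_grid - shift_x
    src_y = y_grid - shift_y
    inside = (src_x >= 0) & (src_x <= nx - 1) & (src_y >= 0) & (src_y <= ny - 1)
    # A vectorized select instead of lax.cond: no per-pixel branching under jit.
    return jnp.where(inside, 1.0, 0.0)
```

For example, `valid_shift_mask((4, 5), 2.0, 0.0)` zeros the first two columns, which is where a two-pixel shift in +x leaves no data.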

yippy.jax_funcs.get_near_inds_offsets_1D(x_offsets, y_offsets, _x, _y)

Computes the nearest indices and offsets for the given _x position in 1D.

Args:
x_offsets (jnp.ndarray):

1D array of x offsets, sorted in ascending order.

y_offsets (jnp.ndarray):

1D array of y offsets.

_x (float):

The x-coordinate for which to find nearby offsets.

Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
  • near_inds: Array of shape (2,) containing the indices of the two surrounding points.

  • near_offsets: Array of shape (2,) containing the corresponding x offsets.

Parameters:
  • x_offsets (jax.numpy.ndarray)

  • y_offsets (jax.numpy.ndarray)

  • _x (float)

  • _y (float)
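One way to find the two bracketing offsets is a binary search with `jnp.searchsorted`, clamping the result so that queries outside the grid fall back to the first or last pair. This is a sketch under that assumption (the helper name `bracket_1d` is hypothetical), not necessarily how `get_near_inds_offsets_1D` is written. It ignores `y_offsets` and `_y`.

```python
import jax.numpy as jnp


def bracket_1d(x_offsets, x):
    """Hypothetical helper: indices and values of the two sorted offsets around x."""
    # Insertion index of x in the sorted offsets, clamped so that both a lower
    # and an upper neighbour exist (queries off the grid use the end pair).
    hi = jnp.clip(jnp.searchsorted(x_offsets, x), 1, x_offsets.shape[0] - 1)
    near_inds = jnp.stack([hi - 1, hi])          # shape (2,)
    return near_inds, x_offsets[near_inds]       # shape (2,)
```

With offsets `[0.0, 1.0, 2.0, 4.0]` and `x = 2.5`, this returns indices `[2, 3]` and offsets `[2.0, 4.0]`.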

yippy.jax_funcs.create_avg_psf_1D(x, y, pixel_scale_arcsec, x_offsets, y_offsets, x_grid, y_grid, x_phasor, y_phasor, reshaped_psfs)

Creates and returns the PSF at the specified off-axis position using JAX, for 1D input data.

Parameters:
  • x (float)

  • y (float)

  • pixel_scale_arcsec (float)

  • x_offsets (jax.numpy.ndarray)

  • y_offsets (jax.numpy.ndarray)

  • x_grid (jax.numpy.ndarray)

  • y_grid (jax.numpy.ndarray)

  • x_phasor (jax.numpy.ndarray)

  • y_phasor (jax.numpy.ndarray)

  • reshaped_psfs (jax.numpy.ndarray)

yippy.jax_funcs.create_avg_psf_2DQ(x, y, pixel_scale_arcsec, x_offsets, y_offsets, x_grid, y_grid, x_phasor, y_phasor, reshaped_psfs)

Creates and returns the PSF at the specified off-axis position using JAX, for 2D quarter-symmetric input data.

Parameters:
  • x (float)

  • y (float)

  • pixel_scale_arcsec (float)

  • x_offsets (jax.numpy.ndarray)

  • y_offsets (jax.numpy.ndarray)

  • x_grid (jax.numpy.ndarray)

  • y_grid (jax.numpy.ndarray)

  • x_phasor (jax.numpy.ndarray)

  • y_phasor (jax.numpy.ndarray)

  • reshaped_psfs (jax.numpy.ndarray)

yippy.jax_funcs.get_near_inds_offsets_2D(x_offsets, y_offsets, _x, _y)

Computes the nearest indices and offsets for the given (_x, _y) position in 2D.

Args:
x_offsets (jnp.ndarray):

1D array of x offsets, sorted in ascending order.

y_offsets (jnp.ndarray):

1D array of y offsets, sorted in ascending order.

_x (float):

The x-coordinate for which to find nearby offsets.

_y (float):

The y-coordinate for which to find nearby offsets.

Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
  • near_inds: Array of shape (4, 2) containing the indices of the four surrounding points.

  • near_offsets: Array of shape (4, 2) containing the corresponding (x, y) offsets.

Parameters:
  • x_offsets (jax.numpy.ndarray)

  • y_offsets (jax.numpy.ndarray)

  • _x (float)

  • _y (float)
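The following sketch produces arrays with the (4, 2) shapes described above. It assumes one binary search (`jnp.searchsorted`) per sorted axis, with clamping so that queries off the grid use the edge cell. The helper names are hypothetical, and the corner ordering (lo/lo, hi/lo, lo/hi, hi/hi) is an assumption; the real function may order the four points differently.

```python
import jax.numpy as jnp


def _upper_index(offsets, v):
    """Index of the upper neighbour, clamped so a lower neighbour always exists."""
    return jnp.clip(jnp.searchsorted(offsets, v), 1, offsets.shape[0] - 1)


def bracket_2d(x_offsets, y_offsets, x, y):
    """Hypothetical helper: the four grid corners surrounding (x, y)."""
    xi = _upper_index(x_offsets, x)
    yi = _upper_index(y_offsets, y)
    # Corners ordered (x_lo, y_lo), (x_hi, y_lo), (x_lo, y_hi), (x_hi, y_hi).
    ix = jnp.stack([xi - 1, xi, xi - 1, xi])
    iy = jnp.stack([yi - 1, yi - 1, yi, yi])
    near_inds = jnp.stack([ix, iy], axis=1)                          # shape (4, 2)
    near_offsets = jnp.stack([x_offsets[ix], y_offsets[iy]], axis=1)  # shape (4, 2)
    return near_inds, near_offsets
```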

yippy.jax_funcs.shift_and_mask(near_psf, shift_x, shift_y, weight, x_grid, y_grid, x_phasor, y_phasor)

Shifts the PSF in x and y directions and applies a weight mask.

Args:
near_psf (jax.numpy.ndarray):

The PSF image to shift.

shift_x (float):

Shift in the x-direction.

shift_y (float):

Shift in the y-direction.

weight (float):

Weight for the PSF.

x_grid (jax.numpy.ndarray):

The x-coordinate grid.

y_grid (jax.numpy.ndarray):

The y-coordinate grid.

x_phasor (jax.numpy.ndarray):

Precomputed components for the Fourier shift.

y_phasor (jax.numpy.ndarray):

Precomputed components for the Fourier shift.

Returns:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]:
  • Weighted shifted PSF.

  • Weight mask.
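Returning the weighted PSF and the weight mask separately suggests a mask-aware weighted average downstream. The weighted PSFs are summed, the masks are summed, and the two are divided pixel by pixel, so that a pixel that is empty in one shifted PSF is averaged only over the PSFs that cover it. The sketch below assumes that usage. `combine_weighted` is a hypothetical name, not part of yippy.

```python
import jax.numpy as jnp


def combine_weighted(weighted_psfs, weight_masks):
    """Hypothetical helper: mask-aware weighted average over a stack of PSFs."""
    num = jnp.sum(weighted_psfs, axis=0)  # sum of weight * shifted PSF
    den = jnp.sum(weight_masks, axis=0)   # sum of weight over valid pixels only
    # safe_den avoids dividing by zero (and NaN gradients) where no PSF is valid.
    safe_den = jnp.where(den > 0, den, 1.0)
    return jnp.where(den > 0, num / safe_den, 0.0)
```

Consider two equally weighted 1 x 3 PSFs of constant value 1 and 3, where the first is empty in its last pixel and the second in its first pixel. The result is `[1, 2, 3]`: each end pixel comes from one PSF, and the middle pixel is their average.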

yippy.jax_funcs.convert_xy_1D(x, y)

Converts x and y to 1D coordinates.

yippy.jax_funcs.convert_xy_2DQ(x, y)

Converts x and y to 2D quarter symmetric coordinates.

yippy.jax_funcs.convert_xy_2D(x, y)

Converts x and y to 2D coordinates.

yippy.jax_funcs.basic_shift_val(input_val, converted_val, pixel_scale_arcsec)

Calculates the shift in pixels for a basic shift.

yippy.jax_funcs.sym_shift_val(input_val, converted_val, pixel_scale_arcsec)

Calculates the shift in pixels for a symmetric shift.

yippy.jax_funcs.x_basic_shift(input_val, converted_val, PSF, pixel_scale_arcsec, x_phasor)

Shifts the PSF to the specified x position.

yippy.jax_funcs.y_basic_shift(input_val, converted_val, PSF, pixel_scale_arcsec, y_phasor)

Shifts the PSF to the specified y position.

yippy.jax_funcs.x_symmetric_shift(input_val, converted_val, PSF, pixel_scale_arcsec, x_phasor)

Shifts the PSF to the specified position assuming symmetry about x=0.

yippy.jax_funcs.y_symmetric_shift(input_val, converted_val, PSF, pixel_scale_arcsec, y_phasor)

Shifts the PSF to the specified position assuming symmetry about y=0.

yippy.jax_funcs.fft_rotate_jax(image, rot_deg)

Rotate an image by a specified angle using Fourier-based shear operations.

This function performs an image rotation by decomposing the rotation into three sequential shear operations in the Fourier domain. For more details see Larkin et al. (1997).

Args:
image (jax.numpy.ndarray):

The input image to be rotated.

rot_deg (float):

The rotation angle in degrees. Positive values rotate the image counterclockwise, and negative values rotate it clockwise.

Returns:
jax.numpy.ndarray:

The rotated image.
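Below is a compact sketch of the three-shear decomposition. It assumes the standard factorization R(θ) = Sx(-tan(θ/2)) · Sy(sin θ) · Sx(-tan(θ/2)), with each shear applied as a row- or column-dependent Fourier phase ramp. The sketch omits the zero padding and the 90-degree pre-rotation that a robust implementation needs, so it is only reliable for small angles and for content well inside the frame. The helper names are hypothetical. In this sketch, positive angles rotate counterclockwise when the row axis points up (for example, `origin="lower"` in Matplotlib), which may differ from the convention of `fft_rotate_jax`.

```python
import jax.numpy as jnp


def _shear_x(image, a):
    """Move row y by a * y pixels along x (y measured from the array centre)."""
    ny, nx = image.shape
    kx = jnp.fft.fftfreq(nx)
    y = jnp.arange(ny) - (ny - 1) / 2
    ramp = jnp.exp(-2j * jnp.pi * a * y[:, None] * kx[None, :])
    return jnp.fft.ifft(jnp.fft.fft(image, axis=1) * ramp, axis=1).real


def _shear_y(image, b):
    """Move column x by b * x pixels along y (x measured from the array centre)."""
    ny, nx = image.shape
    ky = jnp.fft.fftfreq(ny)
    x = jnp.arange(nx) - (nx - 1) / 2
    ramp = jnp.exp(-2j * jnp.pi * b * x[None, :] * ky[:, None])
    return jnp.fft.ifft(jnp.fft.fft(image, axis=0) * ramp, axis=0).real


def three_shear_rotate(image, rot_deg):
    """Rotate about the array centre: x-shear, then y-shear, then x-shear."""
    theta = jnp.deg2rad(rot_deg)
    a = -jnp.tan(theta / 2)
    b = jnp.sin(theta)
    return _shear_x(_shear_y(_shear_x(image, a), b), a)
```

The shear factor -tan(θ/2) grows to magnitude 1 as θ approaches 90 degrees, so splitting off whole quarter turns first, as `decompose_angle_jax` does, keeps every shear small.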

yippy.jax_funcs.rotate_with_shear(image, rot_deg)

Rotate an image by a specified angle using Fourier-based shear operations.

This helper keeps fft_rotate_jax simple by shortening the lambda passed to its lax.cond call.

Args:
image (jax.numpy.ndarray):

The input image to be rotated.

rot_deg (float):

The rotation angle in degrees.

Returns:
jax.numpy.ndarray:

The rotated image.

yippy.jax_funcs.fft_shear_setup(image)

Compute the frequency and distance arrays used by the Fourier-domain shear operations.

Args:
image (jax.numpy.ndarray):

The input image to be sheared.

Returns:
tuple:
  • jax.numpy.ndarray: x frequencies used for the Fourier transform.

  • jax.numpy.ndarray: x distances from the center of the image.

  • jax.numpy.ndarray: y frequencies used for the Fourier transform.

  • jax.numpy.ndarray: y distances from the center of the image.

yippy.jax_funcs.fft_shear_x(image, shear_factor, x_freqs, x_dists)

Perform a shear operation in the Fourier domain along the x-axis.

Uses JAX functions to perform the shear operation in the Fourier domain along the x-axis.

Args:
image (jax.numpy.ndarray):

The input image to be sheared.

shear_factor (float):

The shear factor.

x_freqs (jax.numpy.ndarray):

x frequencies used for the Fourier transform.

x_dists (jax.numpy.ndarray):

x distances from the center of the image.

Returns:
jax.numpy.ndarray:

The sheared image with the zero padding removed.
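Fourier shifts are circular, so without padding, content pushed past one edge reappears on the opposite edge. The sketch below applies an x-shear to a zero-padded copy and then crops it back. It assumes symmetric padding of `pad` pixels and distances measured from the centre of the padded array. `padded_shear_x` is a hypothetical helper, and its frequency and distance arrays only play roughly the role of those returned by `fft_shear_setup`.

```python
import jax.numpy as jnp


def padded_shear_x(image, shear_factor, pad):
    """Hypothetical helper: Fourier x-shear on a zero-padded copy, then crop."""
    ny, nx = image.shape
    padded = jnp.pad(image, pad)                  # `pad` zeros on every side
    py, px = padded.shape
    freqs = jnp.fft.fftfreq(px)                   # cycles per pixel along x
    dists = jnp.arange(py) - (py - 1) / 2         # row distance from the centre
    # The row at distance d moves by shear_factor * d pixels along x.
    ramp = jnp.exp(-2j * jnp.pi * shear_factor * dists[:, None] * freqs[None, :])
    sheared = jnp.fft.ifft(jnp.fft.fft(padded, axis=1) * ramp, axis=1).real
    # Remove the padding; anything pushed into it is discarded, not wrapped.
    return sheared[pad:pad + ny, pad:pad + nx]
```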

yippy.jax_funcs.fft_shear_y(image, shear_factor, y_freqs, y_dists)

Perform a shear operation in the Fourier domain along the y-axis.

Uses JAX operations.

Args:
image (jax.numpy.ndarray):

The input image to be sheared.

shear_factor (float):

The shear factor.

y_freqs (jax.numpy.ndarray):

y frequencies used for the Fourier transform.

y_dists (jax.numpy.ndarray):

y distances from the center of the image.

Returns:
jax.numpy.ndarray:

The sheared image with the zero padding removed.

yippy.jax_funcs.decompose_angle_jax(angle)

Decompose an angle in [0, 360) into a residual angle in (-45, 45] and a number of 90-degree rotations.

Args:
angle (float):

The input angle in degrees.

Returns:
tuple:
  • float: The rotation angle in the range (-45, 45] degrees.

  • int: The number of 90-degree rotations to apply.
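The sketch below assumes that the residual and the quarter turns should add back to the input angle modulo 360 degrees. Choosing k = ceil((angle - 45) / 90) places the residual in (-45, 45], and reducing k modulo 4 maps a full turn back to zero. `decompose_angle` is a hypothetical name, and the real function may return k in a different range.

```python
import jax.numpy as jnp


def decompose_angle(angle):
    """Hypothetical helper: split angle (deg, in [0, 360)) into residual + quarter turns."""
    # Smallest integer k with angle - 90 * k <= 45; this puts the residual in (-45, 45].
    k = jnp.ceil((angle - 45.0) / 90.0)
    residual = angle - 90.0 * k
    # k == 4 is a full turn, equivalent to no quarter turns at all.
    return residual, k.astype(jnp.int32) % 4
```

For example, 100 degrees becomes one quarter turn plus a 10-degree residual, and 359 degrees becomes no quarter turns plus a -1-degree residual.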

yippy.jax_funcs.rot90_traceable(m, k=1, axes=(0, 1))

Rotate an array by 90 degrees in the plane specified by axes.

This function is a traceable version of numpy.rot90, taken from the JAX GitHub issue tracker.

Args:
m (jax.numpy.ndarray):

The input array to be rotated.

k (int, optional):

The number of 90-degree rotations to apply. Defaults to 1.

axes (tuple, optional):

The two axes that define the plane of rotation. Defaults to (0, 1).

Returns:
jax.numpy.ndarray:

The rotated array.
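`jnp.rot90` requires a concrete `k`, so it cannot take a `k` that is traced under `jax.jit`. One alternative with a similar effect (not necessarily the GitHub-issue version used here) dispatches over the four possible quarter turns with `jax.lax.switch`. All branches must return the same shape, so this sketch assumes a square 2D array. `rot90_switch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rot90_switch(m, k):
    """Hypothetical helper: rotate a square 2D array by k quarter turns; k may be traced."""
    # One branch per possible quarter turn; inside each branch the count is a plain int.
    branches = [lambda a, n=n: jnp.rot90(a, n) for n in range(4)]
    # lax.switch selects the branch at run time, so k can be a traced value.
    return jax.lax.switch(k % 4, branches, m)
```

Wrapped in `jax.jit`, the function accepts `k` as an ordinary traced argument, and negative values such as `k = -1` select the three-quarter-turn branch.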

yippy.jax_funcs.synthesize_psf_separable(x, y, pixel_scale_arcsec, flat_psfs, x_offsets, y_offsets, offset_to_flat_idx, kx, ky, n_pad, x_symmetric, y_symmetric, input_type, max_offset)

Synthesizes a PSF using separable 1D FFTs.

Optimized for speed and stability with even-sized padding. Uses flat_psfs with offset_to_flat_idx mapping for memory efficiency.

Args:
x (float):

X coordinate in lambda/D.

y (float):

Y coordinate in lambda/D.

pixel_scale_arcsec (float):

Pixel scale in lambda/D per pixel.

flat_psfs (jnp.ndarray):

Flat array of PSFs with shape (N_psfs, H, W).

x_offsets (jnp.ndarray):

Sorted array of unique x offsets.

y_offsets (jnp.ndarray):

Sorted array of unique y offsets.

offset_to_flat_idx (jnp.ndarray):

2D array mapping (x_idx, y_idx) -> flat_psfs index.

kx (jnp.ndarray):

FFT frequencies for x-axis.

ky (jnp.ndarray):

FFT frequencies for y-axis.

n_pad (int):

Number of padding pixels.

x_symmetric (bool):

Whether PSFs are symmetric about x=0.

y_symmetric (bool):

Whether PSFs are symmetric about y=0.

input_type (str):

Type of input data (“1d”, “2dq”, or “2df”).

max_offset (float):

Maximum valid offset distance (-1 to disable bounds check).

Returns:
jnp.ndarray:

Synthesized PSF at the requested (x, y) position.
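This sketch covers only the separable part: a sub-pixel shift done as one 1D FFT pass along x and one along y on a zero-padded copy, rather than a single 2D FFT. It assumes symmetric padding of `n_pad` pixels per side. It omits symmetry handling, interpolation between neighbouring PSFs, and the `max_offset` check. `separable_shift` is a hypothetical helper, not the yippy code. With the index map described above, the PSF for grid cell (ix, iy) would be looked up as `flat_psfs[offset_to_flat_idx[ix, iy]]` before shifting.

```python
import jax.numpy as jnp


def separable_shift(psf, dx, dy, n_pad):
    """Hypothetical helper: shift psf by (dx, dy) pixels using two 1D FFT passes."""
    ny, nx = psf.shape
    padded = jnp.pad(psf, n_pad)  # zero padding on every side
    py, px = padded.shape
    kx = jnp.fft.fftfreq(px)      # cycles per pixel along x
    ky = jnp.fft.fftfreq(py)      # cycles per pixel along y
    # Shift along x: 1D FFT over each row, multiplied by a linear phase ramp.
    ramp_x = jnp.exp(-2j * jnp.pi * kx * dx)
    out = jnp.fft.ifft(jnp.fft.fft(padded, axis=1) * ramp_x[None, :], axis=1)
    # Shift along y: 1D FFT over each column.
    ramp_y = jnp.exp(-2j * jnp.pi * ky * dy)
    out = jnp.fft.ifft(jnp.fft.fft(out, axis=0) * ramp_y[:, None], axis=0)
    # Keep the real part and crop away the padding.
    return out.real[n_pad:n_pad + ny, n_pad:n_pad + nx]
```

In this sketch, a positive `dx` moves content toward larger column indices. Content shifted into the padding is cropped away rather than wrapped around to the far edge.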

yippy.jax_funcs.synthesize_psf_idw(x, y, pixel_scale_arcsec, flat_psfs, flat_x_offsets, flat_y_offsets, kx, ky, n_pad, x_symmetric, y_symmetric, input_type, k_neighbors=4)

Synthesizes a PSF using power-2 polar inverse distance weighting (IDW) for irregular grids.

Distance is the tangent-plane arc length at the query, d = sqrt(dr**2 + (r_q * dtheta)**2), and weights use 1 / d**2. The polar metric matches the radial/azimuthal anisotropy of coronagraph PSF structure, and the squared exponent acts as a soft nearest-neighbor on smooth grids. See yippy-paper Section 4 for the kernel-selection rationale.
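This sketch implements the weighting described above under the following assumptions:
offsets are given as Cartesian (x, y) pairs, the `k_neighbors` nearest points are chosen by the same polar distance, and a small `eps` guards the division when the query lands exactly on a grid point.
It skips the symmetry folding and the Fourier shifting of the neighbouring PSFs. `polar_idw_weights` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def polar_idw_weights(x, y, xs, ys, k_neighbors=4, eps=1e-12):
    """Hypothetical helper: neighbour indices and normalized power-2 polar IDW weights."""
    r_q = jnp.hypot(x, y)
    theta_q = jnp.arctan2(y, x)
    r = jnp.hypot(xs, ys)
    theta = jnp.arctan2(ys, xs)
    dr = r - r_q
    # Wrap the angular difference into [-pi, pi) so that 359 deg and 1 deg are close.
    dtheta = jnp.mod(theta - theta_q + jnp.pi, 2 * jnp.pi) - jnp.pi
    # Squared tangent-plane arc length at the query: d**2 = dr**2 + (r_q * dtheta)**2.
    d2 = dr**2 + (r_q * dtheta) ** 2
    # The k nearest neighbours are the k largest values of -d2.
    _, idx = jax.lax.top_k(-d2, k_neighbors)
    # Power-2 IDW: w = 1 / d**2, normalized to sum to 1.
    w = 1.0 / (d2[idx] + eps)
    return idx, w / jnp.sum(w)
```

A query that coincides with a grid point receives essentially all of the weight, which matches the soft nearest-neighbour behaviour described above.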