"""Coronagraph performance metric computation.
This module provides standalone functions for computing coronagraph performance
curves: throughput, raw contrast, core area, occulter transmission, and core
mean intensity. Each metric has its own compute function for clarity, plus a
``compute_all_performance_curves`` orchestrator.
All functions operate on a Coronagraph instance passed as the first argument,
keeping the Coronagraph class itself slim while preserving backward-compatible
delegation methods on it.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import jax.numpy as jnp
import numpy as np
from hwoutils.constants import GAUSSIAN_FWHM_FACTOR
from hwoutils.radial import radial_profile
from hwoutils.transforms import resample_flux
from lod_unit import lod
from scipy.interpolate import RegularGridInterpolator, make_interp_spline
from scipy.optimize import root_scalar
from .logger import logger
from .util import (
convert_to_pix,
extract_and_oversample_subarray,
load_coro_performance_from_fits,
measure_flux_in_oversampled_aperture,
save_coro_performance_to_fits,
)
if TYPE_CHECKING:
from .coronagraph import Coronagraph
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OffAxisPosition:
"""Data for a single off-axis PSF position along the performance axis."""
separation: float # |x| in lam/D
psf: np.ndarray # the off-axis PSF image
px: int # pixel x position in the image
py: int # pixel y position in the image
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
[docs]
def _iter_xaxis_positions(coro: Coronagraph):
"""Yield :class:`OffAxisPosition` for each valid x-offset (y ~= 0).
Iterates over each valid x-offset position along the performance-curve
axis. All native YIP offsets are included, even those where the nominal
point-source location lies outside the image, because the truncation-ratio
throughput remains meaningful as long as the translated PSF retains
significant in-image power.
"""
y_offsets = np.array(coro.offax.y_offsets)
y_idx = int(np.argmin(np.abs(y_offsets)))
y_lod_val = y_offsets[y_idx]
if y_lod_val != 0:
logger.warning(f"No PSF at y=0, using closest y offset: {y_lod_val:.3f} lam/D")
for i, x_lod_val in enumerate(np.array(coro.offax.x_offsets)):
if x_lod_val < 0:
continue
r = abs(x_lod_val)
planet_psf = coro.offax.get_psf_by_offset_idx(i, y_idx)
if planet_psf is None:
continue
px = convert_to_pix(
x_lod_val, coro.offax.center_x, coro.pixel_scale_arcsec
).value.astype(int)
py = convert_to_pix(
y_lod_val, coro.offax.center_y, coro.pixel_scale_arcsec
).value.astype(int)
yield OffAxisPosition(separation=r, psf=planet_psf, px=px, py=py)
[docs]
def _collect_and_sort(separations: list, values: list) -> tuple[np.ndarray, np.ndarray]:
"""Convert lists to sorted arrays by separation."""
sep = np.array(separations)
val = np.array(values)
order = np.argsort(sep)
return sep[order], val[order]
[docs]
def _point_source_in_image(pos: OffAxisPosition) -> bool:
"""Return True if the nominal point-source pixel lies inside the PSF image."""
ny, nx = pos.psf.shape
return 0 <= pos.px < nx and 0 <= pos.py < ny
[docs]
def _oversample_psf(
psf: np.ndarray, pixel_scale_arcsec: float, oversample: int
) -> np.ndarray:
"""Oversample a PSF using flux-conserving resampling.
Uses ``resample_flux`` from hwoutils which converts to surface brightness
before interpolation and back to integrated flux after, guaranteeing
per-pixel flux accuracy.
Args:
psf: 2D PSF image.
pixel_scale_arcsec: Pixel scale of the input PSF (lam/D per pixel).
oversample: Oversampling factor.
Returns:
Oversampled PSF with flux conserved and negative values clamped.
"""
os_pix = pixel_scale_arcsec / oversample
ny_os = psf.shape[0] * oversample
nx_os = psf.shape[1] * oversample
psf_os = np.asarray(
resample_flux(
jnp.asarray(np.asarray(psf, dtype=np.float64)),
pixel_scale_arcsec,
os_pix,
(ny_os, nx_os),
)
)
return np.maximum(psf_os, 0.0)
[docs]
def _threshold_mask(psf_os: np.ndarray, trunc_ratio: float) -> np.ndarray:
"""Create boolean mask of pixels exceeding ``trunc_ratio * peak``.
Falls back to peak-only mask if no pixels exceed the threshold.
"""
peak = psf_os.max()
mask = psf_os > trunc_ratio * peak
if not mask.any():
mask = psf_os == peak
return mask
[docs]
def _compute_iwa_owa(
coro: Coronagraph, sep: np.ndarray, throughput: np.ndarray
) -> None:
"""Compute IWA and OWA from throughput curve, storing on *coro*."""
valid_mask = throughput > 0
half_max_throughput = max(throughput[valid_mask]) / 2
def iwa_func(x):
return coro.throughput_interp(x) - half_max_throughput
iwa_bracket = None
for i in range(1, len(sep)):
if iwa_func(sep[i - 1]) * iwa_func(sep[i]) < 0:
iwa_bracket = [sep[i - 1], sep[i]]
break
if iwa_bracket is not None:
coro.IWA = root_scalar(iwa_func, bracket=iwa_bracket).root * lod
else:
first_valid_idx = np.argmax(valid_mask)
coro.IWA = sep[first_valid_idx] * lod
logger.warning(f"Could not find IWA bracket, using first valid sep: {coro.IWA}")
if hasattr(coro.offax, "max_offset_in_image"):
coro.OWA = coro.offax.max_offset_in_image
logger.info(
f"OWA set to max_offset_in_image: {coro.OWA.to(lod).value:.2f} lam/D"
)
else:
coro.OWA = np.max(sep) * lod
logger.warning(
"max_offset_in_image not available, using maximum separation as OWA"
)
# ---------------------------------------------------------------------------
# Public utility
# ---------------------------------------------------------------------------
[docs]
def compute_radial_average(
image: np.ndarray,
pixel_scale_value: float,
center: tuple[float, float] | None = None,
nbins: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute radial average of a 2D image.
Thin wrapper around :func:`hwoutils.radial.radial_profile` for
backward compatibility.
Args:
image: 2D numpy array.
pixel_scale_value: Pixel scale in lambda/D per pixel (float).
center: ``(cy, cx)`` pixel coordinates of centre. Defaults to
image centre.
nbins: Number of radial bins. Defaults to ``floor(max_dim / 2)``.
Returns:
``(separations_lod, radial_profile)`` - 1-D arrays.
"""
kwargs = {}
if center is not None:
# Incoming convention is [x, y]; radial_profile expects (cy, cx)
if isinstance(center, list):
kwargs["center"] = (center[1], center[0])
else:
kwargs["center"] = center
if nbins is not None:
kwargs["nbins"] = nbins
seps, prof = radial_profile(
jnp.asarray(np.asarray(image, dtype=np.float64)),
pixel_scale_arcsec=pixel_scale_value,
**kwargs,
)
return np.asarray(seps), np.asarray(prof)
# ---------------------------------------------------------------------------
# Individual metric functions
# ---------------------------------------------------------------------------
[docs]
def compute_throughput_curve(
coro: Coronagraph,
aperture_radius_lod: float = 0.7,
oversample: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute coronagraph throughput vs separation.
Throughput is the fraction of the total flux (planet PSF normalised to 1)
that lands inside a photometric aperture at each off-axis position.
Args:
coro: Coronagraph instance.
aperture_radius_lod: Aperture radius in lam/D.
oversample: Oversampling factor.
Returns:
``(separations, throughputs)`` - sorted 1-D arrays.
"""
separations, throughputs = [], []
radius_pix = aperture_radius_lod / coro.pixel_scale_arcsec.value
for pos in _iter_xaxis_positions(coro):
if not _point_source_in_image(pos):
separations.append(pos.separation)
throughputs.append(0.0)
continue
sub_os, px_os, py_os, r_os, sub_orig = extract_and_oversample_subarray(
pos.psf, pos.px, pos.py, radius_pix, oversample
)
# Center aperture on PSF peak to match AYO's behavior
peak_ij = np.unravel_index(sub_os.argmax(), sub_os.shape)
py_os, px_os = float(peak_ij[0]), float(peak_ij[1])
flux = measure_flux_in_oversampled_aperture(
sub_os, px_os, py_os, r_os, sub_orig
)
separations.append(pos.separation)
throughputs.append(flux)
return _collect_and_sort(separations, throughputs)
[docs]
def compute_raw_contrast_curve(
coro: Coronagraph,
stellar_diam=0.0 * lod,
aperture_radius_lod: float = 0.7,
oversample: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute raw contrast curve vs separation.
Contrast is the ratio of stellar flux to planet flux within the same
photometric aperture at each off-axis position.
Args:
coro: Coronagraph instance.
stellar_diam: Stellar angular diameter for the on-axis PSF.
aperture_radius_lod: Aperture radius in lam/D.
oversample: Oversampling factor.
Returns:
``(separations, contrasts)`` - sorted 1-D arrays.
"""
star_psf = coro.stellar_intens(stellar_diam)
separations, contrasts = [], []
radius_pix = aperture_radius_lod / coro.pixel_scale_arcsec.value
for pos in _iter_xaxis_positions(coro):
if not _point_source_in_image(pos):
# Raw contrast is undefined outside the image. Use a large finite
# sentinel so downstream spline construction stays well-defined;
# the point lies beyond OWA (max_offset_in_image) so it will not
# be queried in physically meaningful regimes.
separations.append(pos.separation)
contrasts.append(1e20)
continue
# Planet flux
sub_os_p, px_os_p, py_os_p, r_os_p, sub_orig_p = (
extract_and_oversample_subarray(
pos.psf, pos.px, pos.py, radius_pix, oversample
)
)
planet_flux = measure_flux_in_oversampled_aperture(
sub_os_p, px_os_p, py_os_p, r_os_p, sub_orig_p
)
# Star flux
sub_os_s, sx_os, sy_os, r_os_s, sub_orig_s = extract_and_oversample_subarray(
star_psf, pos.px, pos.py, radius_pix, oversample
)
star_flux = measure_flux_in_oversampled_aperture(
sub_os_s, sx_os, sy_os, r_os_s, sub_orig_s
)
contrast_val = star_flux / planet_flux if planet_flux > 0 else np.inf
separations.append(pos.separation)
contrasts.append(contrast_val)
return _collect_and_sort(separations, contrasts)
[docs]
def compute_core_area_curve(
coro: Coronagraph,
aperture_radius_lod: float = 0.7,
fit_gaussian: bool = False,
use_phot_aperture_as_min: bool = False,
oversample: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute core area vs separation.
The core area represents the effective area of the PSF core. If
*fit_gaussian* is True a 2-D Gaussian is fitted to each PSF; otherwise a
fixed circular aperture is used.
Args:
coro: Coronagraph instance.
aperture_radius_lod: Aperture radius in lam/D.
fit_gaussian: Whether to fit a 2-D Gaussian.
use_phot_aperture_as_min: Use aperture area as a floor when fitting.
oversample: Oversampling factor.
Returns:
``(separations, core_areas)`` - sorted 1-D arrays, area in (lam/D)**2.
"""
if fit_gaussian:
from scipy.optimize import curve_fit
def gaussian_2d(coords, amplitude, x0, y0, sigma_x, sigma_y):
x, y = coords
return (
amplitude
* np.exp(-(((x - x0) / sigma_x) ** 2 + ((y - y0) / sigma_y) ** 2) / 2)
).ravel()
separations, core_areas = [], []
for pos in _iter_xaxis_positions(coro):
if fit_gaussian and not _point_source_in_image(pos):
separations.append(pos.separation)
core_areas.append(np.pi * aperture_radius_lod**2)
continue
if fit_gaussian:
sub_os, px_os, py_os, _, _ = extract_and_oversample_subarray(
pos.psf,
pos.px,
pos.py,
aperture_radius_lod / coro.pixel_scale_arcsec.value * 3,
oversample,
)
y_grid, x_grid = np.indices(sub_os.shape)
amplitude = sub_os.max()
sigma_init = 1.0 * oversample
popt, _ = curve_fit(
gaussian_2d,
(x_grid, y_grid),
sub_os.ravel(),
p0=[amplitude, px_os, py_os, sigma_init, sigma_init],
)
_, _, _, sigma_x, sigma_y = popt
sigma_x_lod = sigma_x / oversample * coro.pixel_scale_arcsec.value
sigma_y_lod = sigma_y / oversample * coro.pixel_scale_arcsec.value
fwhm_x = GAUSSIAN_FWHM_FACTOR * sigma_x_lod
fwhm_y = GAUSSIAN_FWHM_FACTOR * sigma_y_lod
area = np.pi * fwhm_x * fwhm_y / 4
if use_phot_aperture_as_min:
area = max(area, np.pi * aperture_radius_lod**2)
else:
area = np.pi * aperture_radius_lod**2
separations.append(pos.separation)
core_areas.append(area)
return _collect_and_sort(separations, core_areas)
[docs]
def compute_truncation_throughput_curve(
coro: Coronagraph,
psf_trunc_ratio: float = 0.5,
oversample: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute throughput using a PSF-truncation-ratio aperture.
Instead of a fixed circular aperture, selects all pixels where the
oversampled PSF exceeds ``psf_trunc_ratio * peak``. Throughput is the
sum of those pixels (after flux-conserving resampling). This
matches the ``photap_frac`` calculation in AYO's ``load_coronagraph.pro``.
Args:
coro: Coronagraph instance.
psf_trunc_ratio: Fraction of the PSF peak used as threshold
(e.g. 0.5 keeps all pixels > 50% of the peak).
oversample: Oversampling factor. ``None`` uses AYO's rule:
``ceil(pixscale / 0.05)``.
Returns:
``(separations, throughputs)`` - sorted 1-D arrays.
"""
pix_lod = coro.pixel_scale_arcsec.value
if oversample is None:
oversample = int(np.ceil(pix_lod / 0.05))
separations, throughputs = [], []
for pos in _iter_xaxis_positions(coro):
psf_os = _oversample_psf(pos.psf, pix_lod, oversample)
mask = _threshold_mask(psf_os, psf_trunc_ratio)
throughput = psf_os[mask].sum()
separations.append(pos.separation)
throughputs.append(throughput)
return _collect_and_sort(separations, throughputs)
[docs]
def compute_truncation_core_area_curve(
coro: Coronagraph,
psf_trunc_ratio: float = 0.5,
oversample: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute core area using a PSF-truncation-ratio aperture.
The core area is the solid angle (in (lam/D)**2) of all oversampled pixels
that exceed ``psf_trunc_ratio * peak``. This matches AYO's
``omega_lod`` calculation.
Args:
coro: Coronagraph instance.
psf_trunc_ratio: Fraction of PSF peak used as threshold.
oversample: Oversampling factor. ``None`` uses AYO's rule.
Returns:
``(separations, core_areas)`` - sorted 1-D arrays, area in (lam/D)^2.
"""
pix_lod = coro.pixel_scale_arcsec.value
if oversample is None:
oversample = int(np.ceil(pix_lod / 0.05))
# Solid angle of one oversampled pixel in (lam/D)**2
os_pix_lod = pix_lod / oversample
pix_solid_angle = os_pix_lod**2
separations, core_areas = [], []
for pos in _iter_xaxis_positions(coro):
psf_os = _oversample_psf(pos.psf, pix_lod, oversample)
mask = _threshold_mask(psf_os, psf_trunc_ratio)
area = mask.sum() * pix_solid_angle
separations.append(pos.separation)
core_areas.append(area)
return _collect_and_sort(separations, core_areas)
[docs]
def compute_occ_trans_curve(
coro: Coronagraph,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute occulter (sky) transmission curve.
This is the radial profile of the sky transmission mask.
Args:
coro: Coronagraph instance.
Returns:
``(separations_lod, occ_trans)`` - 1-D arrays.
"""
sky_trans_data = coro.sky_trans()
return compute_radial_average(sky_trans_data, coro.pixel_scale_arcsec.value)
[docs]
def compute_core_mean_intensity_curve(
coro: Coronagraph,
stellar_diam_values=None,
) -> tuple[np.ndarray, dict]:
"""Compute core mean intensity curves for different stellar diameters.
Args:
coro: Coronagraph instance.
stellar_diam_values:
List of stellar diameters. ``None`` -> use all available.
Returns:
``(separations_lod, intensities_dict)`` where *intensities_dict*
maps each stellar diameter to its radial intensity profile.
"""
available_diams = coro.stellar_intens.diams
if stellar_diam_values is None:
stellar_diam_values = available_diams
else:
for diam in stellar_diam_values:
if diam not in available_diams:
raise ValueError(
f"Requested stellar diameter {diam}"
f" not found in available diameters"
)
center = (coro.stellar_intens.center_y, coro.stellar_intens.center_x)
pix_scale = coro.pixel_scale_arcsec.value
# Use the first diameter to determine bin count
stellar_psf = coro.stellar_intens(stellar_diam_values[0])
nbins = int(np.floor(np.max(stellar_psf.shape) / 2))
intensities: dict = {}
separations = None
for diam in stellar_diam_values:
psf = coro.stellar_intens(diam)
seps, profile = radial_profile(
jnp.asarray(np.asarray(psf, dtype=np.float64)),
pixel_scale_arcsec=pix_scale,
center=center,
nbins=nbins,
)
if separations is None:
separations = np.asarray(seps)
intensities[diam] = np.asarray(profile)
return separations, intensities
# ---------------------------------------------------------------------------
# Interpolator builder (shared by compute_all_performance_curves and
# Coronagraph.set_psf_trunc_ratio)
# ---------------------------------------------------------------------------
[docs]
def _build_perf_interps(
coro, sep, throughput, raw_contrast, sep_ca, core_area, interp_order
):
"""Assign throughput, contrast, and core area spline interpolators to *coro*."""
coro.interp_order = interp_order
coro.throughput_interp = make_interp_spline(sep, throughput, k=interp_order)
log_contrast = np.log10(np.abs(raw_contrast) + 1e-20)
coro._log_contrast_interp = make_interp_spline(sep, log_contrast, k=interp_order)
coro.raw_contrast_interp = make_interp_spline(sep, raw_contrast, k=interp_order)
coro.core_area_interp = make_interp_spline(sep_ca, core_area, k=interp_order)
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------