# Source code for spectral_unmixing.registration

"""
Registration helpers for TZCYX microscopy stacks.

Author: Fabrizio Musacchio
Date: June 2026
"""
# %% IMPORTS
from __future__ import annotations

from collections.abc import Sequence

import numpy as np
from scipy.ndimage import median_filter, shift as ndi_shift
from skimage.registration import phase_cross_correlation

from .io import CANONICAL_AXIS_ORDER
# %% CONSTANTS
# Names of the shift-estimation backends accepted by the ``method`` arguments
# in this module; membership is checked by ``_normalize_registration_method``.
SUPPORTED_REGISTRATION_METHODS = {"phase_cross_correlation", "pystackreg"}
# Strategies for building the per-slice reference image during intra-stack
# drift correction; checked by ``_normalize_intra_stack_reference_mode``.
SUPPORTED_INTRA_STACK_REFERENCE_MODES = {"neighbor", "full_projection"}

# %% INTERNAL HELPERS
def _normalize_registration_method(method: str) -> str:
    """Return the canonical name of a registration backend.

    The input is converted to ``str``, stripped of surrounding whitespace and
    lower-cased before it is compared with ``SUPPORTED_REGISTRATION_METHODS``.

    Raises
    ------
    ValueError
        If the canonical name is not a supported method.
    """

    canonical = str(method).strip().lower()
    if canonical in SUPPORTED_REGISTRATION_METHODS:
        return canonical

    # Sorted so that the error message is deterministic (sets are unordered).
    supported = sorted(SUPPORTED_REGISTRATION_METHODS)
    raise ValueError(
        f"Unsupported registration method {method!r}. "
        f"Supported methods: {supported}."
    )


def _normalize_zrange(zrange: tuple[int, int] | Sequence[int] | None, z_count: int) -> tuple[int, int]:
    """Validate a strict half-open Z range for registration projections."""

    if zrange is None:
        return 0, z_count

    if len(zrange) != 2:
        raise ValueError("zrange must be None or a tuple/list with exactly two integers.")

    start = int(zrange[0])
    stop = int(zrange[1])
    if not 0 <= start < stop <= z_count:
        raise ValueError(
            f"zrange must satisfy 0 <= start < stop <= {z_count}. Got {(start, stop)!r}."
        )
    return start, stop


def _normalize_intra_stack_reference_mode(reference_mode: str) -> str:
    """Return the canonical name of an intra-stack reference strategy.

    The input is converted to ``str``, stripped of surrounding whitespace and
    lower-cased before it is compared with
    ``SUPPORTED_INTRA_STACK_REFERENCE_MODES``.

    Raises
    ------
    ValueError
        If the canonical name is not a supported reference mode.
    """

    canonical = str(reference_mode).strip().lower()
    if canonical in SUPPORTED_INTRA_STACK_REFERENCE_MODES:
        return canonical

    # Sorted so that the error message is deterministic (sets are unordered).
    supported = sorted(SUPPORTED_INTRA_STACK_REFERENCE_MODES)
    raise ValueError(
        f"Unsupported intra-stack reference mode {reference_mode!r}. "
        f"Supported modes: {supported}."
    )


def _normalize_neighbor_window_size(neighbor_window_size: int) -> int:
    """Validate the odd-sized neighborhood used for local intra-stack references."""

    neighbor_window_size = int(neighbor_window_size)
    if neighbor_window_size < 1:
        raise ValueError(
            f"neighbor_window_size must be >= 1. Got {neighbor_window_size!r}."
        )
    if neighbor_window_size % 2 == 0:
        raise ValueError(
            "neighbor_window_size must be odd so that the current z-slice stays centered."
        )
    return neighbor_window_size


def _ensure_tzcyx_stack(stack) -> np.ndarray:
    """Validate that the input already follows canonical ``TZCYX`` order."""

    stack = np.asarray(stack)
    if stack.ndim != 5:
        raise ValueError(
            f"Expected a {CANONICAL_AXIS_ORDER} stack with 5 dimensions. Got shape {stack.shape!r}."
        )
    return stack


def _apply_median_to_zyx(volume_zyx: np.ndarray, kernel_size: int) -> np.ndarray:
    """Apply a 2D median filter independently to each Z plane of a ``ZYX`` volume."""

    filtered = np.empty_like(volume_zyx, dtype=np.float32)
    for z in range(volume_zyx.shape[0]):
        filtered[z, :, :] = median_filter(volume_zyx[z, :, :], size=(kernel_size, kernel_size))
    return filtered


def _build_intra_stack_reference_image(
    volume_zyx: np.ndarray,
    *,
    z_index: int,
    reference_mode: str,
    neighbor_window_size: int,
) -> np.ndarray:
    """Build the per-slice registration reference used for intra-stack drift correction."""

    if reference_mode == "full_projection":
        return np.max(volume_zyx, axis=0)

    half_window = neighbor_window_size // 2
    start = max(0, z_index - half_window)
    stop = min(volume_zyx.shape[0], z_index + half_window + 1)
    return np.max(volume_zyx[start:stop, :, :], axis=0)


def _build_registration_projections(
    stack: np.ndarray,
    *,
    registration_channel: int,
    zrange: tuple[int, int] | None,
    pre_median_filter: bool,
    post_median_filter: bool,
    median_kernel_size: int,
) -> np.ndarray:
    """Build one 2D max-Z projection per time point for shift estimation.

    The registration channel is cut out of ``stack`` (indexed as T, Z, C, Y, X)
    over the validated ``zrange``, copied to ``float32``, optionally
    median-filtered plane by plane, max-projected along Z, and optionally
    median-filtered again. ``stack`` itself is never modified.

    Returns
    -------
    np.ndarray
        ``float32`` array with one projection per time point (axes T, Y, X).
    """

    first_z, end_z = _normalize_zrange(zrange, stack.shape[1])
    selection = stack[:, first_z:end_z, registration_channel, :, :]
    # Explicit copy: the filters below write in place and must not touch the
    # caller's data when ``selection`` is already a float32 view.
    volumes = np.asarray(selection, dtype=np.float32).copy()

    if pre_median_filter:
        for time_index in range(volumes.shape[0]):
            volumes[time_index, :, :, :] = _apply_median_to_zyx(
                volumes[time_index, :, :, :], median_kernel_size
            )

    projections = np.max(volumes, axis=1)

    if post_median_filter:
        window = (median_kernel_size, median_kernel_size)
        for time_index in range(projections.shape[0]):
            projections[time_index, :, :] = median_filter(
                projections[time_index, :, :], size=window
            )

    return projections


def _phase_cross_correlation_shift(reference_projection: np.ndarray, moving_projection: np.ndarray) -> np.ndarray:
    """Estimate the shift between two images via :func:`skimage.registration.phase_cross_correlation`.

    Only the first element of the returned triple (the shift vector) is kept;
    the remaining two values are discarded. The shift is returned as a
    ``float32`` array.
    """

    estimated_shift, _unused_second, _unused_third = phase_cross_correlation(
        reference_projection,
        moving_projection,
    )
    return np.asarray(estimated_shift, dtype=np.float32)


def _pystackreg_shift(reference_projection: np.ndarray, moving_projection: np.ndarray) -> np.ndarray:
    """Estimate a 2D translation with :mod:`pystackreg` in translation mode.

    Both images are cast to ``float32`` before registration. The returned
    ``float32`` array holds two values taken from the transformation matrix,
    ordered as (row 1, row 0) of its last column and negated.
    """

    # Imported lazily so that pystackreg is only needed when this backend is used.
    from pystackreg import StackReg  # pylint: disable=import-outside-toplevel

    registrar = StackReg(StackReg.TRANSLATION)
    reference_f32 = reference_projection.astype(np.float32)
    moving_f32 = moving_projection.astype(np.float32)
    transform = registrar.register(reference_f32, moving_f32)

    # NOTE(review): presumably [1, 2] / [0, 2] are the Y / X translation terms
    # and the negation converts them to the (y, x) shift convention used by the
    # phase-correlation backend -- confirm against the pystackreg documentation.
    shift_y = -transform[1, 2]
    shift_x = -transform[0, 2]
    return np.asarray([shift_y, shift_x], dtype=np.float32)


def _apply_translation_to_tzyx(stack_tzyx: np.ndarray, shift_yx: np.ndarray) -> np.ndarray:
    """Apply one XY translation to all channels and Z slices of a single time point."""

    shifted = np.empty_like(stack_tzyx, dtype=np.float32)
    for z in range(stack_tzyx.shape[0]):
        for c in range(stack_tzyx.shape[1]):
            shifted[z, c, :, :] = ndi_shift(
                np.asarray(stack_tzyx[z, c, :, :], dtype=np.float32),
                shift=tuple(float(v) for v in shift_yx),
                order=1,
                mode="constant",
                cval=0.0,
                prefilter=True,
            )
    return shifted


def _apply_translation_to_cyx(slice_cyx: np.ndarray, shift_yx: np.ndarray) -> np.ndarray:
    """Apply one XY translation to all channels of a single Z slice."""

    shifted = np.empty_like(slice_cyx, dtype=np.float32)
    for c in range(slice_cyx.shape[0]):
        shifted[c, :, :] = ndi_shift(
            np.asarray(slice_cyx[c, :, :], dtype=np.float32),
            shift=tuple(float(v) for v in shift_yx),
            order=1,
            mode="constant",
            cval=0.0,
            prefilter=True,
        )
    return shifted


def _print_verbose(verbose: bool, message: str) -> None:
    """Print a progress message only when verbose mode is enabled."""

    if verbose:
        print(message)


def correct_intra_stack_z_drift(
    stack,
    *,
    registration_channel: int = 0,
    method: str = "phase_cross_correlation",
    reference_mode: str = "neighbor",
    neighbor_window_size: int = 3,
    pre_median_filter: bool = False,
    post_median_filter: bool = False,
    median_kernel_size: int = 3,
    verbose: bool = True,
) -> np.ndarray:
    """
    Correct XY drift between Z slices within each time point of a TZCYX stack.

    Parameters
    ----------
    stack : array-like
        Input stack in canonical ``TZCYX`` order. It is never modified.
    registration_channel : int, optional
        Channel used to estimate the slice-wise XY shifts. The computed shifts
        are then applied to all channels of the corresponding Z slice.
    method : {"phase_cross_correlation", "pystackreg"}, optional
        Backend used for shift estimation.
    reference_mode : {"neighbor", "full_projection"}, optional
        Strategy used to build the per-slice reference image. ``"neighbor"``
        uses a local max projection around each slice. ``"full_projection"``
        uses the max projection across the entire Z stack of the current time
        point.
    neighbor_window_size : int, optional
        Odd number of slices used for ``reference_mode="neighbor"``. ``3``
        means ``z-1, z, z+1``; ``5`` means ``z-2`` through ``z+2``.
    pre_median_filter : bool, optional
        If True, apply a slice-wise median filter to the registration channel
        before building the reference images. This affects only shift
        estimation.
    post_median_filter : bool, optional
        If True, apply a 2D median filter to both the moving slice and the
        reference image just before shift estimation.
    median_kernel_size : int, optional
        Median filter kernel size used by the optional pre/post filters.
    verbose : bool, optional
        If True, print the estimated shifts line-wise for each ``t`` and ``z``
        (or a single notice when the correction is skipped).

    Returns
    -------
    np.ndarray
        Z-drift-corrected ``float32`` stack with the same ``TZCYX`` shape as
        the input. When the stack has at most one Z slice, an unchanged
        ``float32`` copy of the input is returned.

    Raises
    ------
    ValueError
        If the stack is not 5-dimensional, ``method`` or ``reference_mode`` is
        unsupported, ``neighbor_window_size`` is even or smaller than 1,
        ``registration_channel`` is out of range, or ``median_kernel_size`` is
        smaller than 1.

    Notes
    -----
    This function estimates XY shifts independently for each Z slice within
    each time point. The shifts are computed from a user-selected registration
    channel, but are applied to all channels of the affected slice.
    """

    # ``astype(copy=True)`` already yields a private float32 copy, so the
    # result can be built inside it without a second full-size allocation.
    corrected = _ensure_tzcyx_stack(stack).astype(np.float32, copy=True)
    method = _normalize_registration_method(method)
    reference_mode = _normalize_intra_stack_reference_mode(reference_mode)
    neighbor_window_size = _normalize_neighbor_window_size(neighbor_window_size)

    if not 0 <= int(registration_channel) < corrected.shape[2]:
        raise ValueError(
            f"registration_channel must be between 0 and {corrected.shape[2] - 1}. "
            f"Got {registration_channel!r}."
        )
    if median_kernel_size < 1:
        raise ValueError(
            f"median_kernel_size must be >= 1. Got {median_kernel_size!r}."
        )

    if corrected.shape[1] <= 1:
        _print_verbose(verbose, "Skipping intra-stack Z drift correction because Z <= 1.")
        return corrected

    # Loop-invariant conversions and backend selection.
    channel = int(registration_channel)
    kernel = int(median_kernel_size)
    median_window = (kernel, kernel)
    if method == "phase_cross_correlation":
        estimate_shift = _phase_cross_correlation_shift
    else:
        estimate_shift = _pystackreg_shift

    for t in range(corrected.shape[0]):
        registration_volume = corrected[t, :, channel, :, :]
        # ``working_volume`` must be independent of ``corrected``: slices are
        # overwritten below, and later reference images must still be built
        # from the unshifted data of this time point.
        if pre_median_filter:
            working_volume = _apply_median_to_zyx(registration_volume, kernel)
        else:
            working_volume = registration_volume.copy()

        for z in range(corrected.shape[1]):
            moving_image = working_volume[z, :, :]
            reference_image = _build_intra_stack_reference_image(
                working_volume,
                z_index=z,
                reference_mode=reference_mode,
                neighbor_window_size=neighbor_window_size,
            ).astype(np.float32, copy=False)

            if post_median_filter:
                moving_image = median_filter(moving_image, size=median_window)
                reference_image = median_filter(reference_image, size=median_window)

            shift_yx = estimate_shift(reference_image, moving_image)
            _print_verbose(
                verbose,
                f"t={t} z={z} shift_y={float(shift_yx[0]):.3f} shift_x={float(shift_yx[1]):.3f}",
            )

            # The helper returns a new array, so reading and writing the same
            # slice of ``corrected`` is safe.
            corrected[t, z, :, :, :] = _apply_translation_to_cyx(
                corrected[t, z, :, :, :],
                shift_yx,
            )

    return corrected


def register_stack(
    stack,
    *,
    registration_channel: int,
    method: str = "phase_cross_correlation",
    zrange: tuple[int, int] | Sequence[int] | None = None,
    pre_median_filter: bool = False,
    post_median_filter: bool = False,
    median_kernel_size: int = 3,
    verbose: bool = True,
) -> np.ndarray:
    """
    Register a TZCYX stack across time using shifts estimated from Z projections.

    Parameters
    ----------
    stack : array-like
        Input stack in canonical ``TZCYX`` order. It is never modified.
    registration_channel : int
        Channel used to compute the time-wise registration shifts.
    method : {"phase_cross_correlation", "pystackreg"}, optional
        Backend used for shift estimation.
    zrange : tuple[int, int] or None, optional
        Optional half-open Z range ``(start, stop)`` used for the registration
        projection.
    pre_median_filter : bool, optional
        If True, apply a slice-wise median filter to the selected registration
        volume before max-Z projection. This affects only shift estimation,
        not the stack that is transformed.
    post_median_filter : bool, optional
        If True, apply a 2D median filter to each projection after max-Z
        projection. This affects only shift estimation, not the stack that is
        transformed.
    median_kernel_size : int, optional
        Median filter kernel size used by the optional pre/post filters.
    verbose : bool, optional
        If True, print the estimated shifts line-wise for each time point.

    Returns
    -------
    np.ndarray
        Registered ``float32`` stack with the same ``TZCYX`` shape as the
        input. Time point 0 is the reference and is returned unshifted.

    Raises
    ------
    ValueError
        If the stack is not 5-dimensional, ``method`` is unsupported, the
        stack has fewer than two time points, ``registration_channel`` is out
        of range, ``median_kernel_size`` is smaller than 1, or ``zrange`` is
        invalid.

    Notes
    -----
    Time-wise shifts are estimated from 2D max-Z projections of the selected
    registration channel. The resulting translations are then applied to the
    original unprojected ``TZCYX`` data.
    """

    # ``astype(copy=True)`` already yields a private float32 copy, so the
    # result can be built inside it without a second full-size allocation.
    registered = _ensure_tzcyx_stack(stack).astype(np.float32, copy=True)
    method = _normalize_registration_method(method)

    if registered.shape[0] <= 1:
        raise ValueError("Registration requires T > 1.")
    if not 0 <= int(registration_channel) < registered.shape[2]:
        raise ValueError(
            f"registration_channel must be between 0 and {registered.shape[2] - 1}. "
            f"Got {registration_channel!r}."
        )
    if median_kernel_size < 1:
        raise ValueError(
            f"median_kernel_size must be >= 1. Got {median_kernel_size!r}."
        )

    # All projections are computed up front from copied data, so overwriting
    # time points of ``registered`` below cannot influence shift estimation.
    projections = _build_registration_projections(
        registered,
        registration_channel=int(registration_channel),
        zrange=zrange,
        pre_median_filter=pre_median_filter,
        post_median_filter=post_median_filter,
        median_kernel_size=int(median_kernel_size),
    )
    reference_projection = projections[0, :, :]

    if method == "phase_cross_correlation":
        estimate_shift = _phase_cross_correlation_shift
    else:
        estimate_shift = _pystackreg_shift

    _print_verbose(
        verbose,
        (
            f"Registering stack with method='{method}', registration_channel="
            f"{int(registration_channel)}, reference_t=0"
        ),
    )
    _print_verbose(verbose, "t=0 shift_y=0.000 shift_x=0.000")

    for t in range(1, registered.shape[0]):
        shift_yx = estimate_shift(reference_projection, projections[t, :, :])
        _print_verbose(
            verbose,
            f"t={t} shift_y={float(shift_yx[0]):.3f} shift_x={float(shift_yx[1]):.3f}",
        )
        # The helper returns a new array, so reading and writing the same
        # time point of ``registered`` is safe.
        registered[t, :, :, :, :] = _apply_translation_to_tzyx(
            registered[t, :, :, :, :], shift_yx
        )

    return registered


# %% PUBLIC API
# Names exported by ``from <module> import *``: the two public entry points
# plus the sets of option values they accept.
__all__ = [
    "SUPPORTED_INTRA_STACK_REFERENCE_MODES",
    "SUPPORTED_REGISTRATION_METHODS",
    "correct_intra_stack_z_drift",
    "register_stack",
]

# %% END