Source code for spectral_unmixing.filters

"""
Filtering and projection helpers for spectral unmixing workflows.

Author: Fabrizio Musacchio
Date: June 2026
"""
# %% IMPORTS
from __future__ import annotations

from collections.abc import Sequence

import numpy as np
from scipy.ndimage import gaussian_filter, median_filter
from skimage.exposure import match_histograms

from .io import CANONICAL_AXIS_ORDER

# %% DEFAULTS AND CONSTANTS
# Filter names accepted by ``_normalize_filter_sequence``; any other name is
# rejected with a ``ValueError``.
SUPPORTED_FILTERS = {"median", "gaussian"}
# Channel index that receives the optional second-channel filter overrides in
# ``_apply_filter_sequences_tzcyx``.
SECOND_CHANNEL_INDEX = 1

# %% INTERNAL HELPERS
def _normalize_filter_sequence(filters: str | Sequence[str]) -> list[str]:
    """Return the requested filter names as a validated, lower-cased list.

    A single string is treated as a one-element sequence. Every entry is
    stringified, stripped of surrounding whitespace and lower-cased before it
    is checked against ``SUPPORTED_FILTERS``. The order of the input is kept.

    Raises
    ------
    ValueError
        If no filter name is given, or if a name is not supported.
    """

    requested = [filters] if isinstance(filters, str) else list(filters)
    if len(requested) == 0:
        raise ValueError("filters must contain at least one filter name.")

    validated: list[str] = []
    for raw_name in requested:
        candidate = str(raw_name).strip().lower()
        if candidate in SUPPORTED_FILTERS:
            validated.append(candidate)
            continue
        # Report the name exactly as the caller passed it, not the cleaned form.
        raise ValueError(
            f"Unsupported filter {raw_name!r}. Supported filters: {sorted(SUPPORTED_FILTERS)}."
        )
    return validated


def _normalize_optional_filter_sequence(filters: str | Sequence[str] | None) -> list[str] | None:
    """Validate ``filters`` via :func:`_normalize_filter_sequence`, passing ``None`` through."""

    return None if filters is None else _normalize_filter_sequence(filters)


def _ensure_tzcyx_stack(stack) -> np.ndarray:
    """Return ``stack`` as a five-dimensional ``TZCYX`` array.

    ``YX`` input gains singleton ``T``, ``Z`` and ``C`` axes, ``ZYX`` input
    gains singleton ``T`` and ``C`` axes, and five-dimensional input is
    returned unchanged.

    Raises
    ------
    ValueError
        If the input has any other number of dimensions.
    """

    array = np.asarray(stack)
    rank = array.ndim

    if rank == 5:
        return array
    if rank == 3:
        # ZYX -> (1, Z, 1, Y, X)
        return array[None, :, None, :, :]
    if rank == 2:
        # YX -> (1, 1, 1, Y, X)
        return array[None, None, None, :, :]

    raise ValueError(
        "Expected a stack with shape YX, ZYX, or TZCYX. "
        f"Got shape {array.shape!r}."
    )


def _restore_original_shape(filtered_stack: np.ndarray, original_ndim: int) -> np.ndarray:
    """Strip the singleton axes that :func:`_ensure_tzcyx_stack` added.

    A stack that started as ``YX`` (``original_ndim == 2``) or ``ZYX``
    (``original_ndim == 3``) is indexed back to that layout; any other value
    returns ``filtered_stack`` untouched.
    """

    if original_ndim not in (2, 3):
        return filtered_stack

    # ZYX input keeps its whole Z axis; YX input drops it along with T and C.
    z_selector = 0 if original_ndim == 2 else slice(None)
    return filtered_stack[0, z_selector, 0, :, :]


def _normalize_zrange(
    zrange: tuple[int, int] | Sequence[int] | None,
    z_count: int,
) -> tuple[int, int]:
    """Turn an optional Z range into a usable half-open ``(start, stop)`` pair.

    ``None`` selects the full extent ``(0, z_count)``. Otherwise both bounds
    are cast to ``int``, clamped to ``[0, z_count]`` and put in ascending
    order. A range that ends up empty is widened to one plane: the last plane
    when it sits at the upper end of the stack, else the plane at ``start``.

    Raises
    ------
    ValueError
        If ``zrange`` is not ``None`` and does not hold exactly two entries.
    """

    if zrange is None:
        return 0, z_count

    if len(zrange) != 2:
        raise ValueError("zrange must be None or a tuple/list with exactly two integers.")

    # Clamp each bound into the stack, then order them so lower <= upper.
    lower, upper = sorted(
        max(0, min(int(bound), z_count)) for bound in (zrange[0], zrange[1])
    )

    if lower != upper:
        return lower, upper

    # Empty selection: fall back to a single plane.
    if lower >= z_count:
        return max(0, z_count - 1), z_count
    return lower, min(z_count, lower + 1)


def _normalize_time_dependent_parameter(
    value,
    *,
    time_count: int,
    name: str,
    cast,
):
    """Expand a scalar or sequence into one value per time point.

    Parameters
    ----------
    value
        A scalar, a sequence (list, tuple, ...) or a NumPy array of values.
    time_count : int
        Number of time points; the returned list has this length.
    name : str
        Parameter name used in error messages.
    cast : callable
        Conversion applied to every returned entry (e.g. ``int`` or ``float``).

    Returns
    -------
    list
        ``time_count`` cast values. A sequence whose length equals
        ``time_count`` is used entry by entry; a sequence of any other length
        contributes only its first entry, repeated for every time point.

    Raises
    ------
    ValueError
        If ``value`` is an empty sequence or an empty array.
    """

    # NumPy arrays are not registered as ``collections.abc.Sequence``, so
    # without this step a per-time-point array would be treated as a scalar
    # and fail inside ``cast``. Zero-dimensional arrays stay on the scalar path.
    if isinstance(value, np.ndarray) and value.ndim > 0:
        value = value.tolist()

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        values = list(value)
        if not values:
            raise ValueError(f"{name} must not be an empty list.")
        if len(values) == time_count:
            return [cast(item) for item in values]
        # Length mismatch: broadcast the first entry to every time point.
        return [cast(values[0])] * time_count

    return [cast(value)] * time_count


def _resolve_channel2_sequence(
    primary_sequence: list,
    channel2_value,
    *,
    time_count: int,
    name: str,
    cast,
) -> list:
    """Return the per-time parameter list to use for the second channel.

    When ``channel2_value`` is ``None`` a copy of ``primary_sequence`` is
    returned. Otherwise ``channel2_value`` is expanded with
    :func:`_normalize_time_dependent_parameter`.
    """

    if channel2_value is not None:
        return _normalize_time_dependent_parameter(
            channel2_value,
            time_count=time_count,
            name=name,
            cast=cast,
        )

    # No override given: hand back a fresh list rather than the caller's own.
    return list(primary_sequence)


def _apply_filter_sequence_to_volume(
    volume_zyx: np.ndarray,
    *,
    filter_sequence: Sequence[str],
    median_size: int,
    gaussian_sigma: float,
    apply_3d: bool,
) -> np.ndarray:
    """Run the given filters, in order, over a single ``ZYX`` volume.

    The volume is copied to ``float32`` first, so the input is never modified.
    With ``apply_3d`` each filter uses an isotropic 3D kernel over the whole
    volume; otherwise each Z plane is filtered on its own in 2D. The name
    ``"median"`` selects the median filter and any other name the Gaussian
    filter, so the sequence is expected to be validated beforehand.
    """

    result = np.asarray(volume_zyx, dtype=np.float32).copy()

    for name in filter_sequence:
        use_median = name == "median"

        if not apply_3d:
            # Plane-wise: each Z slice is filtered independently in YX.
            planes = np.empty_like(result, dtype=np.float32)
            for z_index in range(result.shape[0]):
                source_plane = result[z_index, :, :]
                if use_median:
                    planes[z_index, :, :] = median_filter(
                        source_plane,
                        size=(median_size, median_size),
                    )
                else:
                    planes[z_index, :, :] = gaussian_filter(
                        source_plane,
                        sigma=(gaussian_sigma, gaussian_sigma),
                    )
            result = planes
            continue

        # Volumetric: the same kernel extent is used along Z, Y and X.
        if use_median:
            smoothed = median_filter(
                result,
                size=(median_size, median_size, median_size),
            )
        else:
            smoothed = gaussian_filter(
                result,
                sigma=(gaussian_sigma, gaussian_sigma, gaussian_sigma),
            )
        result = smoothed.astype(np.float32, copy=False)

    return result


def _apply_filter_sequences_tzcyx(
    stack: np.ndarray,
    *,
    filter_sequence: Sequence[str],
    second_channel_filter_sequence: Sequence[str],
    median_sizes: Sequence[int],
    gaussian_sigmas: Sequence[float],
    second_channel_median_sizes: Sequence[int],
    second_channel_gaussian_sigmas: Sequence[float],
    apply_3d: bool,
) -> np.ndarray:
    """Filter every ``(T, C)`` volume of a ``TZCYX`` stack into a new ``float32`` array.

    The channel at ``SECOND_CHANNEL_INDEX`` uses the ``second_channel_*``
    arguments; every other channel uses the primary ones. Per-time parameters
    are looked up by time index.

    Raises
    ------
    ValueError
        If the median size selected for a volume is below 1 or the Gaussian
        sigma is not positive. Both are checked for every volume, whichever
        filters are in the sequence.
    """

    output = np.empty_like(stack, dtype=np.float32)
    time_count, _, channel_count = stack.shape[:3]

    for t in range(time_count):
        for c in range(channel_count):
            # Pick the parameter set that applies to this channel.
            if c == SECOND_CHANNEL_INDEX:
                names = second_channel_filter_sequence
                sizes = second_channel_median_sizes
                sigmas = second_channel_gaussian_sigmas
            else:
                names = filter_sequence
                sizes = median_sizes
                sigmas = gaussian_sigmas

            size_at_t = int(sizes[t])
            sigma_at_t = float(sigmas[t])

            if size_at_t < 1:
                raise ValueError(
                    f"median_size must be >= 1. Got {size_at_t!r} at t={t}, c={c}."
                )
            if sigma_at_t <= 0:
                raise ValueError(
                    "gaussian_sigma must be > 0. "
                    f"Got {sigma_at_t!r} at t={t}, c={c}."
                )

            output[t, :, c, :, :] = _apply_filter_sequence_to_volume(
                np.asarray(stack[t, :, c, :, :], dtype=np.float32),
                filter_sequence=names,
                median_size=size_at_t,
                gaussian_sigma=sigma_at_t,
                apply_3d=apply_3d,
            )

    return output


def apply_filters(
    stack,
    filters: str | Sequence[str],
    *,
    filters_channel2: str | Sequence[str] | None = None,
    median_size: int | Sequence[int] = 3,
    gaussian_sigma: float | Sequence[float] = 1.0,
    median_size_channel2: int | Sequence[int] | None = None,
    gaussian_sigma_channel2: float | Sequence[float] | None = None,
    apply_3d: bool = False,
) -> np.ndarray:
    """
    Apply one or more filters to a microscopy stack.

    Parameters
    ----------
    stack : array-like
        Input stack in canonical ``TZCYX`` order, or a simpler ``ZYX`` / ``YX`` array.
    filters : str or sequence of str
        Either a single filter name or a sequence such as ``["median", "gaussian"]``.
        Filters are applied in the order provided.
    filters_channel2 : str or sequence of str or None, optional
        Optional filter sequence applied only to the second channel (index ``1``).
        If ``None``, the same ``filters`` sequence is used for all channels.
    median_size : int or sequence of int, optional
        Median kernel size. If a sequence with length ``T`` is provided, the value
        is applied per time point. If the sequence length does not match ``T``,
        only the first entry is used for all time points.
    gaussian_sigma : float or sequence of float, optional
        Gaussian sigma. If a sequence with length ``T`` is provided, the value is
        applied per time point. If the sequence length does not match ``T``, only
        the first entry is used for all time points.
    median_size_channel2 : int or sequence of int or None, optional
        Optional median kernel size override for the second channel (index ``1``).
        If ``None``, the values from ``median_size`` are reused.
    gaussian_sigma_channel2 : float or sequence of float or None, optional
        Optional Gaussian sigma override for the second channel (index ``1``).
        If ``None``, the values from ``gaussian_sigma`` are reused.
    apply_3d : bool, optional
        If True, apply filters in 3D over ``ZYX`` for each ``T`` and ``C`` volume.
        If False, apply them plane-wise in ``YX`` for each available ``T`` and ``Z``.

    Returns
    -------
    np.ndarray
        Filtered ``float32`` stack with the same shape as the input.

    Raises
    ------
    ValueError
        If a filter name is unsupported or no filter is given, if the input is
        not a ``YX``, ``ZYX`` or ``TZCYX`` array, or if a median size is below 1
        or a Gaussian sigma is not positive. The size and sigma checks run for
        every volume, whichever filters were requested.

    Notes
    -----
    Channel-specific overrides currently target the second channel, i.e. channel
    index ``1``. This matches the common two-channel use case in the unmixing
    workflow where the second channel may need different smoothing strength.
    """

    filter_sequence = _normalize_filter_sequence(filters)
    second_channel_filter_sequence = _normalize_optional_filter_sequence(filters_channel2)
    if second_channel_filter_sequence is None:
        # No channel-specific sequence: the second channel follows the primary one.
        second_channel_filter_sequence = filter_sequence

    original_stack = np.asarray(stack)
    original_ndim = original_stack.ndim
    working_stack = _ensure_tzcyx_stack(original_stack).astype(np.float32, copy=True)
    time_count = int(working_stack.shape[0])

    median_sizes = _normalize_time_dependent_parameter(
        median_size,
        time_count=time_count,
        name="median_size",
        cast=int,
    )
    gaussian_sigmas = _normalize_time_dependent_parameter(
        gaussian_sigma,
        time_count=time_count,
        name="gaussian_sigma",
        cast=float,
    )
    second_channel_median_sizes = _resolve_channel2_sequence(
        median_sizes,
        median_size_channel2,
        time_count=time_count,
        name="median_size_channel2",
        cast=int,
    )
    second_channel_gaussian_sigmas = _resolve_channel2_sequence(
        gaussian_sigmas,
        gaussian_sigma_channel2,
        time_count=time_count,
        name="gaussian_sigma_channel2",
        cast=float,
    )

    working_stack = _apply_filter_sequences_tzcyx(
        working_stack,
        filter_sequence=filter_sequence,
        second_channel_filter_sequence=second_channel_filter_sequence,
        median_sizes=median_sizes,
        gaussian_sigmas=gaussian_sigmas,
        second_channel_median_sizes=second_channel_median_sizes,
        second_channel_gaussian_sigmas=second_channel_gaussian_sigmas,
        apply_3d=apply_3d,
    )

    # Drop the axes added for YX / ZYX input so the caller gets its own layout back.
    return _restore_original_shape(working_stack, original_ndim)


def max_z_project(
    stack,
    *,
    zrange: tuple[int, int] | Sequence[int] | None = None,
) -> np.ndarray:
    """
    Compute a maximum-intensity projection over the Z axis while preserving ``T`` and ``C``.

    Parameters
    ----------
    stack : array-like
        Input stack in canonical ``TZCYX`` order, or a simpler ``ZYX`` / ``YX`` array.
    zrange : tuple[int, int] or None, optional
        Optional half-open Z range ``(start, stop)`` used for the projection. If
        the provided bounds fall outside the stack, they are clamped to the valid
        Z extent. If ``None``, the full Z range is used.

    Returns
    -------
    np.ndarray
        The returned stack stays in canonical ``TZCYX`` order with a singleton Z
        dimension. ``YX`` and ``ZYX`` inputs are promoted to ``TZCYX`` first, so
        the result is always five-dimensional.

    Raises
    ------
    ValueError
        If the input is not a ``YX``, ``ZYX`` or ``TZCYX`` array, if ``zrange``
        does not hold exactly two entries, or if the stack has no Z planes.
    """

    stack_tzcyx = _ensure_tzcyx_stack(stack)
    z_count = stack_tzcyx.shape[1]
    if z_count == 0:
        # A maximum over zero planes is undefined; fail with a clear message
        # instead of NumPy's generic zero-size reduction error.
        raise ValueError("Cannot compute a maximum projection over an empty Z axis.")

    z_start, z_stop = _normalize_zrange(zrange, z_count)
    return np.max(stack_tzcyx[:, z_start:z_stop, :, :, :], axis=1, keepdims=True)


def match_histograms_across_time(
    stack,
    *,
    reference_t: int = 0,
) -> np.ndarray:
    """
    Match each time point to a reference time point using per-channel histogram matching.

    Parameters
    ----------
    stack : array-like
        Input stack in canonical ``TZCYX`` order.
    reference_t : int, optional
        Reference time point used for histogram matching. Default is ``0``.

    Returns
    -------
    np.ndarray
        Histogram-matched ``float32`` stack with the same ``TZCYX`` shape as the
        input. The reference time point is copied through unchanged.

    Raises
    ------
    ValueError
        If the input is not a ``YX``, ``ZYX`` or ``TZCYX`` array, if it has
        fewer than two time points (this includes ``YX`` / ``ZYX`` input, which
        is promoted to a single time point), or if ``reference_t`` is outside
        ``[0, T - 1]``.

    Notes
    -----
    Matching is performed independently for each channel. If ``Z > 1``, the full
    ``ZYX`` volume of each time point is matched to the corresponding reference
    volume of the same channel.
    """

    # The helper returns a five-dimensional array or raises, so no separate
    # dimensionality check is needed afterwards.
    stack_tzcyx = _ensure_tzcyx_stack(stack)
    time_count = stack_tzcyx.shape[0]

    if time_count <= 1:
        raise ValueError("Histogram matching across time requires T > 1.")

    reference_index = int(reference_t)
    if not 0 <= reference_index < time_count:
        raise ValueError(
            f"reference_t must be between 0 and {time_count - 1}. Got {reference_t!r}."
        )

    matched = stack_tzcyx.astype(np.float32, copy=True)

    for c in range(stack_tzcyx.shape[2]):
        reference_volume = np.asarray(
            stack_tzcyx[reference_index, :, c, :, :], dtype=np.float32
        )
        for t in range(time_count):
            if t == reference_index:
                continue
            moving_volume = np.asarray(stack_tzcyx[t, :, c, :, :], dtype=np.float32)
            # channel_axis=None: treat the whole ZYX volume as one grayscale image.
            matched[t, :, c, :, :] = match_histograms(
                moving_volume,
                reference_volume,
                channel_axis=None,
            ).astype(np.float32, copy=False)

    return matched


# %% PUBLIC API
# Names exported by ``from spectral_unmixing.filters import *``.
__all__ = [
    "CANONICAL_AXIS_ORDER",
    "SUPPORTED_FILTERS",
    "apply_filters",
    "match_histograms_across_time",
    "max_z_project",
]

# %% END