Source code for spacecore.linop._diagonal

from __future__ import annotations

from functools import cached_property
from math import prod
from typing import Any, cast

from ._base import LinOp
from ._metric import _metric_is_hermitian_by_basis, _requires_euclidean_or_riesz
from .._checks import checked_method
from ..backend import Context, jax_pytree_class
from ..space import (
    CoordinateSpace,
    DenseCoordinateSpace,
    DenseVectorSpace,
    ElementwiseJordanSpace,
    WeightedInnerProduct,
)
from ..types import DenseArray
from .._contextual import resolve_context_priority
from ..kernels import core_kernels
from ..kernels.core.diagonal import _DiagonalMode


[docs] @core_kernels("diagonal") @jax_pytree_class class DiagonalLinOp(LinOp[CoordinateSpace, CoordinateSpace]): r""" Represent a coordinatewise diagonal linear operator. ``DiagonalLinOp(diagonal, space)`` maps ``x`` to ``diagonal * x`` in coordinates. The adjoint is metric-aware: Euclidean spaces use the complex conjugate of the diagonal, while non-Euclidean spaces use their Riesz maps as ``R_X^{-1} D^dagger R_X``. Parameters ---------- diagonal : DenseArray Dense backend array with shape ``space.shape``. space : Space or None, optional Domain and codomain space. If omitted, a vector space is inferred from ``diagonal.shape``. ctx : Context, str, or None, optional Backend context specification. Default is resolved from ``space``. Attributes ---------- diagonal : DenseArray Stored diagonal values. Examples -------- >>> import numpy as np >>> import spacecore as sc >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) >>> X = sc.DenseCoordinateSpace((2,), ctx) >>> D = sc.DiagonalLinOp(ctx.asarray([2.0, 3.0]), X, ctx) >>> D.apply(ctx.asarray([4.0, 5.0])) array([ 8., 15.]) """ def __init__( self, diagonal: DenseArray, space: CoordinateSpace | None = None, ctx: Context | str | None = None, ) -> None: ctx = resolve_context_priority(ctx, space) ctx.assert_dense(diagonal) if space is None: space = DenseCoordinateSpace(tuple(diagonal.shape), ctx) _requires_euclidean_or_riesz(space, space, "DiagonalLinOp") super().__init__(space, space, ctx) expected = tuple(self.domain.shape) if tuple(diagonal.shape) != expected: raise TypeError( f"Expected diagonal.shape == space.shape == {expected}, got {diagonal.shape}" ) self.diagonal = diagonal self._diag_flat = diagonal.reshape((prod(self.domain.shape),)) dtype = self.ops.get_dtype(diagonal) self._diag_adjoint = ( self.ops.conj(diagonal) if self.ops.is_complex_dtype(dtype) else diagonal ) self._diag_adjoint_flat = self._diag_adjoint.reshape((prod(self.domain.shape),)) self._mode = self._select_mode() def _select_mode(self) -> _DiagonalMode: """Select the diagonal computation mode once for this operator.""" if ( type(self.domain) in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace) ) and cast(Any, self.domain).is_euclidean: return _DiagonalMode.EUCLIDEAN if ( type(self.domain) in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace) ) and type(cast(Any, self.domain).geometry) is WeightedInnerProduct: return _DiagonalMode.WEIGHTED_FUSED return _DiagonalMode.GENERAL_METRIC @cached_property def A(self) -> DenseArray: """Dense tensor representation of this diagonal operator.""" return self.to_dense()
[docs] @checked_method(in_space="domain", out_space="codomain") def apply(self, x: DenseArray) -> DenseArray: """Apply the diagonal operator to ``x``.""" return self._apply_core(x)
[docs] @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: DenseArray) -> DenseArray: """Apply the adjoint diagonal operator to ``y``.""" return self._rapply_core(y)
[docs] @checked_method(in_space="domain", in_batched=True) def vapply(self, xs: DenseArray) -> DenseArray: """Apply over a leading batch axis. Input must have shape ``(N,) + domain.shape``; use ``moveaxis`` for other layouts.""" return self._vapply_core(xs)
[docs] @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True) def rvapply(self, ys: DenseArray) -> DenseArray: """Apply the adjoint over a leading batch axis. Input must have shape ``(N,) + codomain.shape``; use ``moveaxis`` for other layouts.""" return self._rvapply_core(ys)
[docs] def to_matrix(self) -> DenseArray: """Return the flattened dense diagonal matrix representation.""" return self.ops.diag(self._diag_flat)
[docs] def to_dense(self) -> DenseArray: """Return a dense tensor representation of this diagonal operator.""" matrix = self.to_matrix() return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape))
[docs] def is_hermitian(self) -> bool | None: """ Return whether this diagonal operator is structurally self-adjoint. Returns ------- bool or None ``True`` or ``False`` when the structure can be checked, otherwise ``None``. """ if not cast(Any, self.domain).is_euclidean: return _metric_is_hermitian_by_basis(self) try: return bool(self.ops.allclose(self.diagonal, self._diag_adjoint)) except Exception: return None
def __eq__(self, other: Any) -> bool: """Return whether another diagonal operator has the same space and values.""" if not self._eq_backend_compatible(other): # Tier 1: backend return NotImplemented if self.domain != other.domain: # Tier 2: space before allclose return False return bool(self.ops.allclose(self.diagonal, other.diagonal, equal_nan=True)) # Tier 3 def tree_flatten(self): """Flatten this operator for pytree registration.""" children = (self.diagonal,) aux = (self.domain, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): """Rebuild this operator from pytree data.""" domain, ctx = aux return cls(children[0], domain, ctx) def _convert(self, new_ctx: Context) -> DiagonalLinOp: """Convert the stored diagonal and space to ``new_ctx``.""" return DiagonalLinOp( new_ctx.asarray(self.diagonal), self.domain.convert(new_ctx), new_ctx, )