Source code for spacecore.linop._dense

from __future__ import annotations

from functools import cached_property
from math import prod
from typing import Any, cast

from ._base import Codomain, Domain, LinOp
from .._checks import checked_method
from ._metric import _metric_is_hermitian_by_basis, _requires_euclidean_or_riesz
from ..space import (
    DenseCoordinateSpace,
    DenseVectorSpace,
    ElementwiseJordanSpace,
    WeightedInnerProduct,
)
from ..types import DenseArray
from ..backend import jax_pytree_class, Context
from .._contextual import resolve_context_priority
from ..kernels import core_kernels
from ..kernels.core.dense import _DenseMode


@core_kernels("dense")
@jax_pytree_class
class DenseLinOp(LinOp[Domain, Codomain]):
    r"""
    Represent a dense coordinate tensor-backed linear operator.

    ``DenseLinOp(A, dom, cod)`` represents a linear map
    :math:`A \colon X \to Y` where the stored dense array has shape
    ``cod.shape + dom.shape``. Forward application is the raw coordinate
    matrix action. Adjoint application is metric-aware: Euclidean spaces use
    the conjugate transpose fast path, while non-Euclidean spaces use their
    Riesz maps as ``R_X^{-1} A^dagger R_Y``.

    DenseLinOp does not copy or cast the input array. The caller is
    responsible for passing an array compatible with `ctx`. This avoids
    duplicating large dense operators in memory.

    Parameters
    ----------
    A : DenseArray
        Dense backend array with shape ``cod.shape + dom.shape``.
    dom : Space
        Domain space.
    cod : Space or None, optional
        Codomain space. If omitted, it is inferred from the leading axes of
        ``A``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from the spaces.

    Attributes
    ----------
    A : DenseArray
        Stored dense operator tensor.

    Raises
    ------
    TypeError
        If ``A.shape`` is not ``cod.shape + dom.shape``, or if ``cod`` is
        omitted and ``A`` has fewer axes than ``dom``.

    Examples
    --------
    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 3.0]]), X, X, ctx)
    >>> A.apply(ctx.asarray([1.0, 2.0]))
    array([2., 6.])
    """

    # NOTE(review): the ``_apply_core`` / ``_rapply_core`` / ``_vapply_core`` /
    # ``_rvapply_core`` methods called below are not defined in this class
    # body; presumably they are supplied by ``core_kernels("dense")`` or by
    # the ``LinOp`` base -- confirm.

    def __init__(
        self,
        A: DenseArray,
        dom: Domain,
        cod: Codomain | None = None,
        ctx: Context | str | None = None,
    ) -> None:
        ctx = resolve_context_priority(ctx, dom, cod)
        ctx.assert_dense(A)  # Check if A is ndarray of ctx

        if cod is None:
            # The codomain occupies the leading axes of A; the trailing
            # ``len(dom.shape)`` axes belong to the domain.
            cod_rank = len(A.shape) - len(dom.shape)
            if cod_rank < 0:
                # Without this guard the negative slice below would silently
                # infer a wrong codomain and the shape check further down
                # would report a misleading "expected" shape.
                raise TypeError(
                    "Cannot infer cod: A has fewer axes than dom "
                    f"(A.shape == {tuple(A.shape)}, dom.shape == {tuple(dom.shape)})"
                )
            cod = cast(Codomain, DenseCoordinateSpace(tuple(A.shape[:cod_rank]), ctx))

        _requires_euclidean_or_riesz(dom, cod, "DenseLinOp")
        # Explicit two-argument super() is kept from the original on purpose:
        # the class is wrapped by decorators.
        super(DenseLinOp, self).__init__(dom, cod, ctx)

        expected = tuple(self.cod.shape) + tuple(self.dom.shape)
        if tuple(A.shape) != expected:
            raise TypeError(
                f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}"
            )

        self._A = A  # Intentionally no dtype conversion to avoid extra memory use.

        # Flattened (matrix) view of the tensor: rows index the codomain,
        # columns index the domain.
        self._cod_size = prod(self.cod.shape)
        self._dom_size = prod(self.dom.shape)
        self._matrix_shape = (self._cod_size, self._dom_size)
        self._A2 = self.A.reshape(self._matrix_shape)

        dtype = self.ops.get_dtype(self.A)
        is_complex = self.ops.is_complex_dtype(dtype)
        self._A2T = self._A2.T
        # Conjugation is skipped for real dtypes, where it would be a no-op.
        self._A2H = self._A2T.conj() if is_complex else self._A2T

        # A space is "flat" when its shape is already one-dimensional.
        self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,)

        self._mode = self._select_mode()
        if self._mode is _DenseMode.WEIGHTED_FUSED:
            self._dom_weights = self.dom.geometry.weights
            self._cod_weights = self.cod.geometry.weights
            # Fold the weights into the adjoint matrix once: columns of A^H
            # are scaled by the codomain weights and rows are divided by the
            # domain weights (the ``R_X^{-1} A^dagger R_Y`` form, assuming the
            # weights are the diagonals of the Riesz maps).
            self._weighted_A2H = (
                self._A2H * self._cod_weights.reshape((1, self._cod_size))
            ) / self._dom_weights.reshape((self._dom_size, 1))

    def _select_mode(self) -> _DenseMode:
        """Select the dense computation mode once for this operator.

        Returns
        -------
        _DenseMode
            ``WEIGHTED_FUSED`` when both spaces are flat plain vector spaces
            with a ``WeightedInnerProduct`` geometry; ``EUCLIDEAN_FLAT`` or
            ``EUCLIDEAN_TENSOR`` when both are Euclidean plain vector spaces;
            ``GENERAL_METRIC`` otherwise.
        """
        # Exact type matches (not isinstance): subclasses of these spaces are
        # routed to the general path.
        plain_vector_types = (
            DenseCoordinateSpace,
            DenseVectorSpace,
            ElementwiseJordanSpace,
        )
        vector_dom = type(self.dom) in plain_vector_types
        vector_cod = type(self.cod) in plain_vector_types

        if (
            vector_dom
            and vector_cod
            and self._dom_is_flat
            and self._cod_is_flat
            and type(self.dom.geometry) is WeightedInnerProduct
            and type(self.cod.geometry) is WeightedInnerProduct
        ):
            return _DenseMode.WEIGHTED_FUSED

        if (
            vector_dom
            and vector_cod
            and cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
        ):
            if self._dom_is_flat and self._cod_is_flat:
                return _DenseMode.EUCLIDEAN_FLAT
            return _DenseMode.EUCLIDEAN_TENSOR

        return _DenseMode.GENERAL_METRIC

    @cached_property
    def A(self) -> DenseArray:
        """
        Stored dense tensor representation of this operator.

        The returned array has shape ``self.codomain.shape + self.domain.shape``
        and is the same object supplied at construction.
        """
        return self._A

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: DenseArray) -> DenseArray:
        """Apply the dense operator to ``x``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: DenseArray) -> DenseArray:
        r"""Apply the adjoint dense operator to ``y``.

        Euclidean spaces use the conjugate transpose of the flattened matrix.
        Non-Euclidean spaces apply the codomain and domain Riesz maps around
        that Euclidean adjoint.
        """
        return self._rapply_core(y)

    # The output side is declared as well so that the batched forward result
    # is checked the same way as in ``apply`` and ``rvapply``.
    @checked_method(
        in_space="domain",
        out_space="codomain",
        in_batched=True,
        out_batched=True,
    )
    def vapply(self, xs: DenseArray) -> DenseArray:
        """Apply over a leading batch axis.

        Input must have shape ``(N,) + domain.shape``; use ``moveaxis`` for
        other layouts.
        """
        return self._vapply_core(xs)

    @checked_method(
        in_space="codomain",
        out_space="domain",
        in_batched=True,
        out_batched=True,
    )
    def rvapply(self, ys: DenseArray) -> DenseArray:
        """Apply the adjoint over a leading batch axis.

        Input must have shape ``(N,) + codomain.shape``; use ``moveaxis`` for
        other layouts.
        """
        return self._rvapply_core(ys)

    def to_dense(self) -> DenseArray:
        """
        Return the stored dense tensor representation of this operator.

        The returned array has shape ``self.codomain.shape + self.domain.shape``.
        """
        return self.A

    def to_matrix(self) -> DenseArray:
        """
        Return the flattened dense matrix representation.

        The returned array has shape
        ``(prod(self.codomain.shape), prod(self.domain.shape))``. It is a
        reshape/view of the stored dense tensor when the backend permits.
        """
        return self._A2

    def is_hermitian(self) -> bool | None:
        """
        Return whether this dense operator is structurally self-adjoint.

        Returns
        -------
        bool or None
            ``True`` or ``False`` when the structure can be checked, otherwise
            ``None``.
        """
        if self.dom != self.cod:
            return False

        both_euclidean = (
            cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
        )
        if not both_euclidean:
            return _metric_is_hermitian_by_basis(self)

        try:
            return bool(self.ops.allclose(self._A2, self._A2H))
        except Exception:
            # Deliberate best-effort: if the comparison cannot be reduced to a
            # Python bool, report "unknown" instead of failing.
            return None

    # NOTE(review): defining ``__eq__`` without ``__hash__`` makes Python set
    # ``__hash__`` to None on this class; confirm that unhashable operators
    # are intended.
    def __eq__(self, other: Any) -> bool:
        """Return whether another dense operator has the same spaces and values."""
        # Tier 1: backend compatibility.
        if not self._eq_backend_compatible(other):
            return NotImplemented
        # Tier 2: spaces are compared before the value comparison.
        if self.dom != other.dom or self.cod != other.cod:
            return False
        # Tier 3: values, treating NaNs in matching positions as equal.
        return bool(self.ops.allclose(self.A, other.A, equal_nan=True))

    def tree_flatten(self):
        """Flatten this operator for pytree registration.

        The dense tensor is the only child; the spaces and the context travel
        as auxiliary data.
        """
        children = (self.A,)
        aux = (self.dom, self.cod, self.ctx)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data.

        The constructor is re-run, so its validation applies to ``children[0]``.
        """
        dom, cod, ctx = aux
        (A,) = children[:1]
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> DenseLinOp:
        """Convert spaces and stored dense tensor to ``new_ctx``."""
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.asarray(self.A)
        return DenseLinOp(new_A, new_dom, new_cod, new_ctx)