Source code for spacecore.linop._base

from __future__ import annotations

import warnings
from abc import abstractmethod
from functools import cached_property
from math import prod
from numbers import Number
from typing import Any, Generic, Self, TypeVar

from .._batching import _leading_batch_size, _warn_vmap_fallback_once
from .._checks import checked_method
from .._repr import describe_space
from ..space import CoordinateSpace
from ..backend import Context
from .._contextual import ContextBound

Domain = TypeVar("Domain", bound=CoordinateSpace)
Codomain = TypeVar("Codomain", bound=CoordinateSpace)


class LinOp(ContextBound, Generic[Domain, Codomain]):
    r"""
    Represent a linear map between two spaces.

    This class is intentionally small. It defines no storage assumptions and
    requires subclasses to provide forward and adjoint actions.

    The adjoint :math:`A^*` satisfies
    :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X`
    for :math:`x \in X` and :math:`y \in Y`. For complex operators this is the
    conjugate adjoint.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    cod : Space
        Codomain space ``Y``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom`` and
        ``cod``.

    Attributes
    ----------
    dom : Space
        Domain space converted to ``ctx``.
    cod : Space
        Codomain space converted to ``ctx``.
    ctx : Context
        Resolved backend context.

    Examples
    --------
    Use a concrete dense operator as a :class:`LinOp`.

    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> A = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 2.0]]), X, X, ctx)
    >>> A.apply(ctx.asarray([3.0, 4.0]))
    array([3., 8.])
    """

    def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None):
        # ``_bind_context`` is not defined in this class (presumably inherited
        # from ``ContextBound``); it returns the two spaces stored below.
        self.dom, self.cod = self._bind_context(ctx, dom, cod)

    @property
    def domain(self) -> Domain:
        """Domain space of this linear operator."""
        return self.dom

    @property
    def codomain(self) -> Codomain:
        """Codomain space of this linear operator."""
        return self.cod

    @cached_property
    def A(self) -> Any:
        """
        Native numerical representation of this operator.

        Concrete subclasses may choose the representation that best matches
        their storage model: for example, dense operators return a dense array
        while sparse operators return their sparse matrix. Matrix-free or lazy
        operators generally do not have such a representation and should leave
        this property unimplemented. Use :meth:`to_dense` when a dense tensor
        materialization is explicitly required.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not define a native numerical representation."
        )

    @abstractmethod
    def apply(self, x: Any) -> Any:
        """Apply the forward map to an element of ``self.domain``."""

    @abstractmethod
    def rapply(self, y: Any) -> Any:
        """Apply the adjoint map to an element of ``self.codomain``."""

    def _apply_core(self, x: Any) -> Any:
        """Apply without adding validation beyond the concrete implementation."""
        return self.apply(x)

    def _rapply_core(self, y: Any) -> Any:
        """Apply the adjoint without adding validation beyond the implementation."""
        return self.rapply(y)

    def _vapply_core(self, xs: Any) -> Any:
        """Apply to a batch without adding validation beyond the implementation."""
        return self.vapply(xs)

    def _rvapply_core(self, ys: Any) -> Any:
        """Apply the adjoint to a batch without adding validation beyond the implementation."""
        return self.rvapply(ys)

    def __call__(self, x: Any) -> Any:
        """Apply this linear operator to ``x``."""
        return self.apply(x)

    def adjoint_apply(self, y: Any) -> Any:
        """Apply the adjoint of this linear operator to ``y``."""
        return self.rapply(y)

    def is_hermitian(self) -> bool | None:
        """
        Return whether this operator is structurally Hermitian when known.

        Returns
        -------
        bool | None
            ``True`` or ``False`` when the subclass can verify the structure
            cheaply, otherwise ``None`` for unknown or matrix-free operators.
        """
        return None

    @checked_method(in_space="domain", in_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Apply over a leading batch axis.

        Input must have shape ``(N,) + domain.shape``; use ``moveaxis`` for
        other layouts.
        """
        _warn_vmap_fallback_once(self, "vapply", _leading_batch_size(self.domain, xs))
        return self.ops.vmap(self.apply, in_axes=0, out_axes=0)(xs)

    @checked_method(in_space="codomain", in_batched=True)
    def rvapply(self, ys: Any) -> Any:
        """Apply the adjoint over a leading batch axis.

        Input must have shape ``(N,) + codomain.shape``; use ``moveaxis`` for
        other layouts.
        """
        _warn_vmap_fallback_once(self, "rvapply", _leading_batch_size(self.codomain, ys))
        return self.ops.vmap(self.rapply, in_axes=0, out_axes=0)(ys)

    @property
    def H(self) -> LinOp:
        r"""Hermitian-adjoint view of this linear operator.

        Returns
        -------
        LinOp
            Adjoint view satisfying
            :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X`.
        """
        from ._algebra import _AdjointViewLinOp

        # The view is built on first access and stored on the instance so
        # repeated ``.H`` lookups return the same object.
        view = getattr(self, "_adjoint_view", None)
        if view is None:
            view = _AdjointViewLinOp(self)
            self._adjoint_view = view
        return view

    def __add__(self, other: Any) -> LinOp:
        """Return the lazy sum ``self + other`` of two compatible operators."""
        from ._algebra import make_sum

        if not isinstance(other, LinOp):
            return NotImplemented
        return make_sum((self, other))

    def __radd__(self, other: Any) -> LinOp:
        """Return the lazy sum ``other + self`` of two compatible operators."""
        from ._algebra import make_sum

        # A numeric zero on the left is treated as the additive identity, so
        # the built-in ``sum(...)`` (which starts from ``0``) works on operators.
        if isinstance(other, Number) and other == 0:
            return self
        if not isinstance(other, LinOp):
            return NotImplemented
        return make_sum((other, self))

    def __neg__(self) -> LinOp:
        """Return the lazy negation ``-self``."""
        from ._algebra import make_scaled

        return make_scaled(-1, self)

    def __sub__(self, other: Any) -> LinOp:
        """Return the lazy difference ``self - other`` of two compatible operators."""
        from ._algebra import make_scaled, make_sum

        if not isinstance(other, LinOp):
            return NotImplemented
        return make_sum((self, make_scaled(-1, other)))

    def __rsub__(self, other: Any) -> LinOp:
        """Return the lazy difference ``other - self`` of two compatible operators."""
        from ._algebra import make_scaled, make_sum

        # ``0 - self`` collapses to the plain negation.
        if isinstance(other, Number) and other == 0:
            return make_scaled(-1, self)
        if not isinstance(other, LinOp):
            return NotImplemented
        return make_sum((other, make_scaled(-1, self)))

    def __mul__(self, scalar: Any) -> LinOp:
        """Return the lazy right scalar multiple ``self * scalar``."""
        from ._algebra import is_scalar_like, make_scaled

        if not is_scalar_like(scalar):
            return NotImplemented
        return make_scaled(scalar, self)

    def __rmul__(self, scalar: Any) -> LinOp:
        """Return the lazy left scalar multiple ``scalar * self``."""
        from ._algebra import is_scalar_like, make_scaled

        if not is_scalar_like(scalar):
            return NotImplemented
        return make_scaled(scalar, self)

    def __matmul__(self, other: Any) -> LinOp:
        """Return the lazy composition ``self @ other`` of two compatible operators."""
        from ._algebra import make_composed

        if not isinstance(other, LinOp):
            return NotImplemented
        return make_composed(self, other)

    def adjoint(self) -> LinOp:
        """Return the Hermitian-adjoint view of this linear operator."""
        return self.H

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""
        Return an equivalent operator with fusible sub-expressions multiplied out.

        Tier-2 lazy-algebra simplification
        ([ADR-021](021_lazy_operator_algebra_and_simplification.md)): collapse
        each maximal subtree of densely-fusible operators into a single
        materialized operator — for example, a composition of dense operators
        becomes one :class:`DenseLinOp` holding the matrix product
        :math:`M_A M_B` — while leaving matrix-free and other
        non-materializable leaves intact.

        This is an **explicit, opt-in materialization**. The result is
        mathematically equal to ``self`` but only *within floating-point
        rounding*: fusing reassociates the arithmetic (multiplying matrices
        then applying differs from applying in sequence at the ulp level), so
        equality holds up to tolerance, not bit-for-bit. The fused operator
        preserves the domain, codomain, context, and scalar-field/dtype
        identity. A leaf operator returns itself.

        Parameters
        ----------
        materialize : bool, optional
            With the default ``False``, a matrix-free operand
            ([ADR-008](008_linop_subclasses.md)) is **never** densified: it
            remains a lazy leaf and only breaks a fusible run. With ``True``
            the caller explicitly accepts giving up the matrix-free contract:
            a matrix-free operand is densified into a :class:`DenseLinOp`
            (via its ``to_dense`` basis probe, which may be expensive),
            allowing the enclosing expression to collapse to a single dense
            operator.

        Returns
        -------
        LinOp
            A fused operator with the same action as ``self`` (up to rounding).
        """
        # Base implementation: treat the operator as a leaf, so ``materialize``
        # is intentionally unused here.
        return self

    def to_dense(self) -> Any:
        """
        Materialize this operator as a dense backend array.

        The returned array has shape
        ``self.codomain.shape + self.domain.shape``.

        The default implementation is intended for small problems, debugging,
        and tests. It materializes the full coordinate matrix, so subclasses
        that already store a dense or sparse matrix should override this
        method for efficiency.
        """
        return self.ops.reshape(
            self.to_matrix(), tuple(self.codomain.shape) + tuple(self.domain.shape)
        )

    def to_sparse(self) -> Any:
        """
        Materialize this operator in a sparse representation.

        The base class provides no sparse materialization.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(f"{type(self).__name__} does not define sparse materialization.")

    def to_matrix(self) -> Any:
        """
        Materialize this operator as a 2D dense coordinate matrix.

        The returned array has shape
        ``(prod(self.codomain.shape), prod(self.domain.shape))``.

        The default implementation builds a batch of standard basis vectors
        and calls :meth:`vapply` once. If a space cannot batch-flatten or
        batch-unflatten its representation, it falls back to a safe Python
        loop. This method is for small/testing use; concrete storage-backed
        subclasses should override it when they can expose a matrix directly.

        Warns
        -----
        RuntimeWarning
            When the batched path raises ``AttributeError``,
            ``NotImplementedError`` or ``TypeError`` and the per-column Python
            loop is used instead.
        """
        domain_size = prod(self.domain.shape)
        codomain_size = prod(self.codomain.shape)
        eye = self.ops.eye(domain_size, dtype=self.dtype)
        try:
            xs = self.domain.unflatten_batch(eye)
            ys = self.vapply(xs)
            ys_flat = self.codomain.flatten_batch(ys)
            # Rows of ``ys_flat`` are images of basis vectors; transpose so
            # they become the columns of the coordinate matrix.
            matrix = self.ops.transpose(ys_flat, (1, 0))
            return self.ops.reshape(matrix, (codomain_size, domain_size))
        except (AttributeError, NotImplementedError, TypeError) as exc:
            warnings.warn(
                (
                    f"{type(self).__name__}.to_matrix() could not use the batched "
                    f"materialization path and is falling back to a Python loop. "
                    f"This is slower and not JIT-friendly. Original error: "
                    f"{type(exc).__name__}: {exc}"
                ),
                RuntimeWarning,
                stacklevel=2,
            )
            columns = []
            for i in range(domain_size):
                basis_vector = eye[:, i]
                x = self.domain.unflatten(basis_vector)
                y = self.apply(x)
                columns.append(self.codomain.flatten(y))
            return self.ops.stack(tuple(columns), axis=1)

    def assert_domain(self, x: Any) -> None:
        """Raise if ``x`` is not in the domain."""
        self.dom.check_member(x)

    def assert_codomain(self, y: Any) -> None:
        """Raise if ``y`` is not in the codomain."""
        self.cod.check_member(y)

    # NOTE(review): defining ``__eq__`` in a class body without also defining
    # ``__hash__`` makes Python set ``__hash__`` to ``None`` for this class, so
    # instances are unhashable unless a subclass restores it — confirm that is
    # intended.
    def __eq__(self, other: Any) -> bool:
        """Return structural equality when implemented by a subclass."""
        return NotImplemented

    def _arrow(self) -> str:
        """Return the ``domain → codomain`` descriptor for this operator."""
        # The separator is required: without it the two space descriptions run
        # together and the ``ClassName(domain → codomain)`` form promised by
        # :meth:`_short_repr` is not produced.
        return f"{describe_space(self.dom)} → {describe_space(self.cod)}"

    def _repr_body(self) -> str:
        """Return the text placed inside the repr; the arrow form by default."""
        return self._arrow()

    def _short_repr(self) -> str:
        """Return a bounded ``ClassName(domain → codomain)`` form for nesting.

        Algebra operators show their operands in the full :meth:`__repr__` but
        collapse to this arrow form when nested, so deep trees never explode.
        """
        return f"{type(self).__name__}({self._arrow()})"

    @abstractmethod
    def tree_flatten(self) -> tuple[tuple[Any, ...], Any]:
        """Flatten this operator for backend pytree registration."""
        ...

    @classmethod
    @abstractmethod
    def tree_unflatten(cls, aux: Any, children: Any) -> Self:
        """Rebuild this operator from backend pytree data."""
        ...