Source code for spacecore.linop._algebra

from __future__ import annotations

from math import prod
from numbers import Number
from typing import Any, Callable, Sequence, cast

from ._base import LinOp, Domain, Codomain
from ._metric import _requires_euclidean_or_riesz, metric_rapply, metric_rvapply
from .._checks import checked_method
from .._contextual import resolve_context_priority
from .._contextual._bound import _same_math_context
from .._repr import summarize_value
from ..backend import Context, jax_pytree_class
from ..kernels import core_kernels
from ..kernels.core.algebra import (
    batched_zeros as _batched_zeros,
    compose_chain as _compose_chain,
    conjugate_scalar as _conjugate_scalar,
    leading_shape as _leading_shape,
)


def is_scalar_like(value: Any) -> bool:
    """
    Decide whether ``value`` may act as a scalar coefficient of a ``LinOp``.

    Anything registered as a :class:`numbers.Number` qualifies immediately.
    Otherwise an object qualifies when it exposes an empty ``shape``; objects
    without a ``shape`` attribute qualify when their ``ndim`` equals ``0``.

    Parameters
    ----------
    value : Any
        Candidate multiplier.

    Returns
    -------
    bool
        ``True`` for a number or a zero-dimensional array-like.
    """
    if isinstance(value, Number):
        return True
    shape = getattr(value, "shape", None)
    if shape is None:
        # No shape attribute: fall back to a rank query.
        return getattr(value, "ndim", None) == 0
    return tuple(shape) == ()


def _scalar_eq(a: Any, b: Any) -> bool:
    """Return whether two scalar-likes are equal, NaN-reflexive.

    Mirrors the ``equal_nan=True`` used for array values: two matching NaN
    scalars compare equal so a NaN-scaled operator equals itself. Always returns
    a real Python ``bool`` (a 0-d backend-array ``==`` would otherwise yield
    ``np.bool_``, which leaks through the ``and`` combinator of any container).
    """
    if bool(a == b):
        return True
    try:
        # ``x != x`` is True only for NaN (including a complex value with a NaN
        # component), so this branch matches NaN against NaN.
        return bool(a != a) and bool(b != b)
    except Exception:
        return False


def _require_same_context(ops: Sequence[LinOp]) -> Context:
    """Return the common context for algebra operands or raise."""
    ctx = ops[0].ctx
    for i, op in enumerate(ops[1:], start=1):
        if not _same_math_context(ops[0].ctx, op.ctx):
            raise ValueError(
                "All LinOp operands in an algebraic expression must have the same ctx; "
                f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}."
            )
    return ctx


def _same_space_for_algebra(left: Any, right: Any) -> bool:
    """Return whether two spaces are compatible for algebraic composition."""
    if left == right:
        return True
    if type(left) is not type(right):
        return False
    if tuple(left.shape) != tuple(right.shape):
        return False
    if not _same_math_context(left.ctx, right.ctx):
        return False
    try:
        return left.convert(right.ctx) == right
    except Exception:
        return False


def _require_linop(op: Any, name: str) -> LinOp:
    """
    Validate that ``op`` is a linear operator and hand it back.

    Parameters
    ----------
    op : Any
        Candidate operand.
    name : str
        Label used in the error message.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` instance.
    """
    if isinstance(op, LinOp):
        return op
    raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.")


def _scalar_equal(value: Any, target: Any) -> bool:
    """Return whether two scalar-like values compare equal."""
    try:
        return bool(value == target)
    except Exception:
        return False


def _is_zero_scalar(value: Any) -> bool:
    """Return whether ``value`` is scalar-like zero."""
    return _scalar_equal(value, 0)


def _is_one_scalar(value: Any) -> bool:
    """Return whether ``value`` is scalar-like one."""
    return _scalar_equal(value, 1)


def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]:
    """
    Expand nested ``SumLinOp`` nodes into a single flat tuple of summands.

    Each top-level operand is validated with ``_require_linop``; nested sums
    are expanded recursively in their original order.
    """
    flat: list[LinOp] = []
    for index, candidate in enumerate(ops):
        checked = _require_linop(candidate, f"ops[{index}]")
        if not isinstance(checked, SumLinOp):
            flat.append(checked)
            continue
        flat.extend(_flatten_sum_terms(checked.parts))
    return tuple(flat)


def make_sum(ops: Sequence[LinOp]) -> LinOp:
    """
    Build a locally simplified lazy sum of linear operators.

    Simplification is deliberately local. Nested ``SumLinOp`` nodes are
    flattened and ``ZeroLinOp`` summands are dropped. Like terms are not
    collected, operands are not reordered, and no global symbolic
    optimisation is attempted. Every operand must share the same context,
    domain and codomain.

    Parameters
    ----------
    ops : sequence of LinOp
        Nonempty sequence of operators with a common domain and codomain.

    Returns
    -------
    LinOp
        A ``SumLinOp`` over the surviving terms, the single surviving term, or
        a ``ZeroLinOp`` when every term was zero.

    Raises
    ------
    ValueError
        If ``ops`` is empty, or the operands disagree on context, domain or
        codomain.
    """
    if not ops:
        raise ValueError("make_sum requires a nonempty sequence of LinOp operands.")
    terms = _flatten_sum_terms(ops)
    ctx = _require_same_context(terms)
    first = terms[0]
    domain, codomain = first.domain, first.codomain
    for i, op in enumerate(terms):
        if i == 0:
            continue
        spaces_match = _same_space_for_algebra(op.domain, domain) and _same_space_for_algebra(
            op.codomain, codomain
        )
        if not spaces_match:
            raise ValueError(
                "All SumLinOp operands must have the same domain and codomain; "
                f"operand 0 maps {domain!r} -> {codomain!r}, "
                f"operand {i} maps {op.domain!r} -> {op.codomain!r}."
            )
    kept = tuple(term for term in terms if not isinstance(term, ZeroLinOp))
    if len(kept) == 1:
        return kept[0]
    if kept:
        return SumLinOp(kept)
    # Every summand was a zero map.
    return ZeroLinOp(domain, codomain, ctx)
def make_scaled(scalar: Any, op: LinOp) -> LinOp:
    """
    Build a locally simplified scalar multiple of a linear operator.

    Simplification is deliberately local:

    - a zero scalar yields a ``ZeroLinOp``;
    - a unit scalar, or a zero operand, returns ``op`` unchanged;
    - a nested ``ScaledLinOp`` has its coefficient folded into one scalar.

    Scaling is not distributed over sums. Complex coefficients keep the usual
    conjugated coefficient in ``rapply`` via ``ScaledLinOp``.

    Parameters
    ----------
    scalar : scalar-like
        Coefficient multiplying ``op``.
    op : LinOp
        Operator to scale.

    Returns
    -------
    LinOp
        Simplified scalar multiple.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` or ``scalar`` is not scalar-like.
    """
    op = _require_linop(op, "op")
    if not is_scalar_like(scalar):
        raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.")
    if _is_zero_scalar(scalar):
        return ZeroLinOp(op.domain, op.codomain, op.ctx)
    if _is_one_scalar(scalar) or isinstance(op, ZeroLinOp):
        return op
    if isinstance(op, ScaledLinOp):
        # Fold c * (d * A) into (c * d) * A and re-simplify the product.
        return make_scaled(scalar * op.scalar, op.op)
    return ScaledLinOp(scalar, op)
def make_composed(left: LinOp, right: LinOp) -> LinOp:
    """
    Build a locally simplified composition ``left @ right``.

    Simplification is deliberately local. Identity factors are dropped, and
    composing with a zero map produces a zero map. The binary
    ``ComposedLinOp`` representation is kept; multi-factor chains are not
    flattened here.

    Parameters
    ----------
    left : LinOp
        Operator applied second.
    right : LinOp
        Operator applied first.

    Returns
    -------
    LinOp
        Simplified lazy composition representing ``left @ right``.

    Raises
    ------
    TypeError
        If either operand is not a ``LinOp``.
    ValueError
        If the contexts differ or ``right.codomain`` is incompatible with
        ``left.domain``.
    """
    left = _require_linop(left, "left")
    right = _require_linop(right, "right")
    _require_same_context((left, right))
    if not _same_space_for_algebra(right.codomain, left.domain):
        raise ValueError(
            "ComposedLinOp requires right.codomain == left.domain; "
            f"got {right.codomain!r} and {left.domain!r}."
        )
    if isinstance(right, IdentityLinOp):
        return left
    if isinstance(left, IdentityLinOp):
        return right
    if isinstance(left, ZeroLinOp) or isinstance(right, ZeroLinOp):
        return ZeroLinOp(right.domain, left.codomain, left.ctx)
    return ComposedLinOp(left, right)
@core_kernels("scaled")
@jax_pytree_class
class ScaledLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy scalar multiple of a linear operator.

    ``ScaledLinOp(alpha, A)`` represents the mathematical operator
    ``alpha * A``. Its context is exactly ``A.ctx``; its domain is
    ``A.domain`` and its codomain is ``A.codomain``. No dense matrix
    representation is formed.

    The forward action is ``apply(x) = alpha * A.apply(x)`` for
    ``x in A.domain``. The reverse action is
    ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so
    complex scalars use the conjugated coefficient.

    Parameters
    ----------
    scalar : scalar-like
        Scalar multiplier.
    op : LinOp
        Operator being scaled.

    Attributes
    ----------
    scalar : scalar-like
        Stored scalar multiplier.
    op : LinOp
        Stored operand.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` or ``scalar`` is not scalar-like.
    """

    def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None:
        op = _require_linop(op, "op")
        if not is_scalar_like(scalar):
            raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.")
        super().__init__(op.domain, op.codomain, op.ctx)
        self.scalar = scalar
        self.op = op

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``scalar * op.apply(x)``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return ``conj(scalar) * op.rapply(y)``."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``scalar * op.vapply(xs)``."""
        return self._vapply_core(xs)

    def rvapply(self, ys: Any) -> Any:
        """Return ``conj(scalar) * op.rvapply(ys)``."""
        xs = self.op.rvapply(ys)
        return self.domain.scale_batch(_conjugate_scalar(self.scalar), xs)

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Fuse the operand and fold the scalar into a dense matrix (ADR-021).

        When the fused operand is dense, replace ``c · A`` with one
        :class:`DenseLinOp` holding ``c · M_A``. Otherwise the scaling stays
        lazy, so a matrix-free operand is never densified. This is
        adjoint-consistent: the fused operator's adjoint is ``conj(c) · M_A^*``
        (with the metric), exactly the lazy ``ScaledLinOp`` adjoint.
        """
        from ._dense import DenseLinOp

        op = self.op.fuse(materialize=materialize)
        if isinstance(op, DenseLinOp):
            matrix = self.scalar * op.to_matrix()
            tensor = self.ops.reshape(
                matrix, tuple(op.codomain.shape) + tuple(op.domain.shape)
            )
            return DenseLinOp(tensor, op.domain, op.codomain, self.ctx)
        return make_scaled(self.scalar, op)

    def is_hermitian(self) -> bool | None:
        """
        Return whether this scaled operator is structurally Hermitian.

        There are three cases:

        - A zero scalar makes ``0 · A`` the zero map. That map is Hermitian
          exactly when domain and codomain are the same space, regardless of
          ``A``, matching ``ZeroLinOp.is_hermitian``.
        - A nonzero real scalar ``s`` satisfies ``(s A)* = s A*``, so ``s A``
          is self-adjoint exactly when ``A`` is, and the operand's verdict is
          propagated.
        - A non-real scalar gives ``(s A)* = conj(s) A*``, which generally
          differs from ``s A``. This cannot be decided cheaply, so ``None`` is
          returned.

        Returns
        -------
        bool | None
            ``self.domain == self.codomain`` for a zero scalar,
            ``self.op.is_hermitian()`` for a nonzero real scalar, otherwise
            ``None`` for unknown.
        """
        if _is_zero_scalar(self.scalar):
            # The operand is irrelevant here: 0 * A is the zero map, even when A
            # itself is not Hermitian.
            return self.domain == self.codomain
        if _conjugate_scalar(self.scalar) == self.scalar:
            return self.op.is_hermitian()
        return None

    def __eq__(self, other: Any) -> bool:
        """Return whether another scaled operator has the same scalar and operand."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        # NaN-reflexive, returns a real Python bool (no np.bool_ leak).
        if not _scalar_eq(self.scalar, other.scalar):  # Tier 3: scalar value
            return False
        return self.op == other.op  # operand (own gate)

    def _repr_body(self) -> str:
        return f"{summarize_value(self.scalar)} · {self.op._short_repr()}"

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        children = (self.scalar, self.op)
        aux = ()
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        scalar, op = children
        return cls(scalar, op)

    def _convert(self, new_ctx: Context) -> ScaledLinOp:
        """Convert the operand to ``new_ctx`` while preserving the scalar."""
        return ScaledLinOp(self.scalar, self.op.convert(new_ctx))
@core_kernels("sum")
@jax_pytree_class
class SumLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy finite sum of linear operators that share their spaces.

    ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty
    sequence of ``LinOp`` instances. Every operand must agree on ``ctx``,
    domain and codomain at construction time; the sum inherits that common
    context, domain and codomain.

    Forward action: ``apply(x) = sum_i Ai.apply(x)`` for ``x`` in the shared
    domain. Reverse action: ``rapply(y) = sum_i Ai.rapply(y)`` for ``y`` in
    the shared codomain.

    Parameters
    ----------
    ops : sequence of LinOp
        Nonempty sequence of operators with common context, domain, and
        codomain.

    Attributes
    ----------
    parts : tuple of LinOp
        Stored operands in the lazy sum.

    Raises
    ------
    ValueError
        If ``ops`` is empty or the operands disagree on context or spaces.
    TypeError
        If an operand is not a ``LinOp``.
    """

    def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None:
        if not ops:
            raise ValueError("SumLinOp requires a nonempty sequence of LinOp operands.")
        members = tuple([_require_linop(term, f"ops[{k}]") for k, term in enumerate(ops)])
        shared_ctx = _require_same_context(members)
        head = members[0]
        domain, codomain = head.domain, head.codomain
        for k, term in enumerate(members):
            if k == 0:
                continue
            aligned = _same_space_for_algebra(term.domain, domain) and _same_space_for_algebra(
                term.codomain, codomain
            )
            if not aligned:
                raise ValueError(
                    "All SumLinOp operands must have the same domain and codomain; "
                    f"operand 0 maps {domain!r} -> {codomain!r}, "
                    f"operand {k} maps {term.domain!r} -> {term.codomain!r}."
                )
        super().__init__(domain, codomain, shared_ctx)
        # ``self.ops`` is the backend namespace, hence the distinct attribute name.
        self.ops_tuple = members

    @property
    def parts(self) -> tuple[LinOp[Domain, Codomain], ...]:
        """Operators in this lazy sum."""
        return self.ops_tuple

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``sum_i ops[i].apply(x)``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return ``sum_i ops[i].rapply(y)``."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``sum_i ops[i].vapply(xs)``."""
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
    def rvapply(self, ys: Any) -> Any:
        """Return ``sum_i ops[i].rvapply(ys)``."""
        combine = self.domain.add_batch
        first, *remaining = self.ops_tuple
        total = first.rvapply(ys)
        for term in remaining:
            total = combine(total, term.rvapply(ys))
        return total

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Fuse every term and merge the dense ones into one ``DenseLinOp`` (ADR-021).

        Each term is fused first. The matrices of those that fused to a
        :class:`DenseLinOp` are summed into a single dense operator. Remaining
        matrix-free or structured terms stay as lazy summands, so nothing
        matrix-free is densified. Additivity of the adjoint,
        ``(A + B)^* = A^* + B^*``, keeps this adjoint-consistent. Because the
        terms are reassociated, equality holds only up to rounding.
        """
        from ._dense import DenseLinOp

        fused_terms = [term.fuse(materialize=materialize) for term in self.parts]
        dense_terms = [term for term in fused_terms if isinstance(term, DenseLinOp)]
        if len(dense_terms) < 2:
            return make_sum(fused_terms)
        anchor = dense_terms[0]
        accumulated = anchor.to_matrix()
        for extra in dense_terms[1:]:
            accumulated = accumulated + extra.to_matrix()
        target_shape = tuple(anchor.codomain.shape) + tuple(anchor.domain.shape)
        merged = DenseLinOp(
            self.ops.reshape(accumulated, target_shape), anchor.domain, anchor.codomain, self.ctx
        )
        lazy_terms = [term for term in fused_terms if not isinstance(term, DenseLinOp)]
        return make_sum([merged, *lazy_terms])

    def is_hermitian(self) -> bool | None:
        """
        Return whether this lazy sum is structurally Hermitian.

        The adjoint is additive, ``(A1 + ... + Ak)* = A1* + ... + Ak*``, so the
        sum is self-adjoint when every term is. When every operand is provably
        Hermitian this returns ``True``. Otherwise the answer is not cheaply
        decidable, because non-Hermitian terms may still add up to a Hermitian
        sum, and ``None`` is returned. ``False`` is never returned.

        Returns
        -------
        bool | None
            ``True`` when every part is provably Hermitian, otherwise ``None``.
        """
        for term in self.parts:
            if term.is_hermitian() is not True:
                return None
        return True

    def __eq__(self, other: Any) -> bool:
        """Return whether another sum has the same operands, in order."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        mine, theirs = self.ops_tuple, other.ops_tuple
        if len(mine) != len(theirs):  # Tier 2: operand count before zip
            return False
        # Ordered, structural: A + B != B + A. Commutative equivalence is a
        # separate concern (a future equiv()), not __eq__.
        return all(a == b for a, b in zip(mine, theirs))

    def _repr_body(self) -> str:
        from .._repr import truncated_join

        return truncated_join((term._short_repr() for term in self.ops_tuple), " + ")

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        return self.ops_tuple, ()

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        return cls(tuple(children))

    def _convert(self, new_ctx: Context) -> SumLinOp:
        """Convert all operands to ``new_ctx``."""
        converted = tuple(term.convert(new_ctx) for term in self.ops_tuple)
        return SumLinOp(converted)
@core_kernels("composed")
@jax_pytree_class
class ComposedLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy composition of two linear operators.

    ``ComposedLinOp(A, B)`` stands for ``A @ B``, i.e. ``A`` after ``B``.
    Both operands must share ``ctx`` and ``B.codomain`` must match
    ``A.domain``. The composite maps ``B.domain`` into ``A.codomain``.

    Forward action: ``apply(x) = A.apply(B.apply(x))`` for ``x in B.domain``.
    Reverse action: ``rapply(z) = B.rapply(A.rapply(z))`` for
    ``z in A.codomain``.

    Parameters
    ----------
    left : LinOp
        Operator applied second.
    right : LinOp
        Operator applied first.

    Attributes
    ----------
    left : LinOp
        Left operand.
    right : LinOp
        Right operand.

    Raises
    ------
    TypeError
        If either operand is not a ``LinOp``.
    ValueError
        If the contexts differ or the middle spaces are incompatible.
    """

    def __init__(self, left: LinOp, right: LinOp) -> None:
        outer = _require_linop(left, "left")
        inner = _require_linop(right, "right")
        _require_same_context((outer, inner))
        if not _same_space_for_algebra(inner.codomain, outer.domain):
            raise ValueError(
                "ComposedLinOp requires right.codomain == left.domain; "
                f"got {inner.codomain!r} and {outer.domain!r}."
            )
        super().__init__(inner.domain, outer.codomain, outer.ctx)
        self.left = outer
        self.right = inner
        # Flatten the (possibly nested) binary tree once, at construction,
        # into a single chain of leaf operators in application order: the
        # right factor's leaves run first, then the left factor's. Each apply
        # then walks one check-free list instead of recursing.
        self._apply_chain = _compose_chain(inner) + _compose_chain(outer)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``left.apply(right.apply(x))``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, z: Any) -> Any:
        """Return ``right.rapply(left.rapply(z))``."""
        return self._rapply_core(z)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``left.vapply(right.vapply(xs))``."""
        return self._vapply_core(xs)

    def rvapply(self, zs: Any) -> Any:
        """Return ``right.rvapply(left.rvapply(zs))``."""
        intermediate = self.left.rvapply(zs)
        return self.right.rvapply(intermediate)

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Collapse a dense-by-dense composition into one ``DenseLinOp`` (ADR-021).

        Both operands are fused first. Only when both come back as
        :class:`DenseLinOp` is ``A @ B`` replaced by a single dense operator
        holding the matrix product :math:`M_A M_B`. Otherwise the composition
        stays lazy, so a matrix-free operand is never densified. The product is
        adjoint-consistent on any geometry: the shared middle-space Riesz maps
        cancel, so its metric adjoint equals ``B* @ A*`` up to rounding.
        """
        from ._dense import DenseLinOp

        fused_left = self.left.fuse(materialize=materialize)
        fused_right = self.right.fuse(materialize=materialize)
        both_dense = isinstance(fused_left, DenseLinOp) and isinstance(fused_right, DenseLinOp)
        if not both_dense:
            return make_composed(fused_left, fused_right)
        backend = self.ops
        product = backend.matmul(fused_left.to_matrix(), fused_right.to_matrix())
        target_shape = tuple(fused_left.codomain.shape) + tuple(fused_right.domain.shape)
        tensor = backend.reshape(product, target_shape)
        return DenseLinOp(tensor, fused_right.domain, fused_left.codomain, self.ctx)

    def is_hermitian(self) -> bool | None:
        """
        Return whether this composition is structurally Hermitian.

        A Gram product ``R* @ R`` (equivalently ``L @ L*``) is self-adjoint in
        every geometry, since ``<R* R x, y> = <R x, R y> = <x, R* R y>``. It is
        recognised structurally by testing ``self.left == self.right.H``. The
        adjoint view compares its wrapped operand, so ``L @ L*`` is matched
        too. Anything else is not cheaply decidable; non-Hermiticity is never
        asserted.

        Returns
        -------
        bool | None
            ``True`` for a Gram product, otherwise ``None``.
        """
        return True if self.left == self.right.H else None

    def __eq__(self, other: Any) -> bool:
        """Return whether another composition has the same operands, in order."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.left == other.left and self.right == other.right

    def _repr_body(self) -> str:
        return f"{self.left._short_repr()}{self.right._short_repr()}"

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        return (self.left, self.right), ()

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        outer, inner = children
        return cls(outer, inner)

    def _convert(self, new_ctx: Context) -> ComposedLinOp:
        """Convert both operands to ``new_ctx``."""
        outer = self.left.convert(new_ctx)
        inner = self.right.convert(new_ctx)
        return ComposedLinOp(outer, inner)
@core_kernels("zero")
@jax_pytree_class
class ZeroLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy zero map between two spaces.

    ``ZeroLinOp(X, Y)`` is the linear map ``0 : X -> Y``. The context comes
    from the optional ``ctx`` argument together with the two spaces, and both
    spaces are converted into that context. The domain is ``X`` and the
    codomain is ``Y``, both in the resolved context.

    Forward action: ``apply(x) = 0_Y`` for ``x in X``. Reverse action:
    ``rapply(y) = 0_X`` for ``y in Y``.

    Parameters
    ----------
    dom : Space
        Domain space.
    cod : Space
        Codomain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from the spaces.
    """

    def __init__(
        self,
        dom: Domain,
        cod: Codomain,
        ctx: Context | str | None = None,
    ) -> None:
        super().__init__(dom, cod, ctx)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return the zero element of the codomain."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return the zero element of the domain."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", in_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return the batched zero element of the codomain."""
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", in_batched=True)
    def rvapply(self, ys: Any) -> Any:
        """Return the batched zero element of the domain."""
        batch_shape = _leading_shape(self.codomain, ys)
        return _batched_zeros(self.domain, batch_shape)

    def to_dense(self) -> Any:
        """
        Return the dense tensor representation of the zero map.

        The returned array has shape ``self.codomain.shape + self.domain.shape``.
        """
        target_shape = (*self.codomain.shape, *self.domain.shape)
        return self.ops.zeros(target_shape, dtype=self.dtype)

    def is_hermitian(self) -> bool:
        """
        Return whether the zero map is Hermitian.

        Returns
        -------
        bool
            ``True`` exactly when domain and codomain are the same space.
        """
        return self.domain == self.codomain

    def __eq__(self, other: Any) -> bool:
        """Return whether another zero map has the same spaces."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.domain == other.domain and self.codomain == other.codomain  # Tier 2

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        # No array leaves: everything lives in the static auxiliary data.
        return (), (self.domain, self.codomain, self.ctx)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        domain, codomain, ctx = aux
        return cls(domain, codomain, ctx)

    def _convert(self, new_ctx: Context) -> ZeroLinOp:
        """Convert domain and codomain spaces to ``new_ctx``."""
        new_domain = self.domain.convert(new_ctx)
        new_codomain = self.codomain.convert(new_ctx)
        return ZeroLinOp(new_domain, new_codomain, new_ctx)
@core_kernels("identity")
@jax_pytree_class
class IdentityLinOp(LinOp[Domain, Domain]):
    r"""
    Lazy identity map on a space.

    ``IdentityLinOp(X)`` is the identity operator ``I_X : X -> X``. The
    context comes from the optional ``ctx`` argument together with the space.
    Domain and codomain are both ``X`` in that context.

    Forward action: ``apply(x) = x`` for ``x in X``. Reverse action:
    ``rapply(x) = x`` for ``x in X``.

    Parameters
    ----------
    space : Space
        Domain and codomain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``space``.
    """

    def __init__(self, space: Domain, ctx: Context | str | None = None) -> None:
        super().__init__(space, space, ctx)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``x`` after domain validation."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, x: Any) -> Any:
        """Return ``x`` after codomain validation."""
        return self._rapply_core(x)

    @checked_method(in_space="domain", in_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``xs`` after batched domain validation."""
        return xs

    @checked_method(in_space="codomain", in_batched=True)
    def rvapply(self, xs: Any) -> Any:
        """Return ``xs`` after batched codomain validation."""
        return xs

    def to_dense(self) -> Any:
        """
        Return the dense tensor representation of this identity map.

        The returned array has shape ``self.codomain.shape + self.domain.shape``.
        """
        # Total number of coordinates; prod(()) == 1 for a scalar space.
        n = prod(self.domain.shape)
        identity_matrix = self.ops.eye(n, dtype=self.dtype)
        target_shape = tuple(self.codomain.shape) + tuple(self.domain.shape)
        return self.ops.reshape(identity_matrix, target_shape)

    def is_hermitian(self) -> bool:
        """
        Return whether this identity operator is Hermitian.

        Returns
        -------
        bool
            Always ``True``.
        """
        return True

    def __eq__(self, other: Any) -> bool:
        """Return whether another identity map has the same space."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.domain == other.domain  # Tier 2 (square: cod == dom)

    def _repr_body(self) -> str:
        from .._repr import describe_space

        return describe_space(self.domain)

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        # No array leaves: the space and context are static auxiliary data.
        return (), (self.domain, self.ctx)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        domain, ctx = aux
        return cls(domain, ctx)

    def _convert(self, new_ctx: Context) -> IdentityLinOp:
        """Convert the identity space to ``new_ctx``."""
        converted_space = self.domain.convert(new_ctx)
        return IdentityLinOp(converted_space, new_ctx)
[docs] @core_kernels("matrixfree") @jax_pytree_class class MatrixFreeLinOp(LinOp[Domain, Codomain]): """ Linear operator defined by user-supplied forward and reverse callables. ``MatrixFreeLinOp(apply, rapply, X, Y)`` represents a matrix-free map ``A : X -> Y`` without storing or materializing a matrix. The context is resolved from the optional ``ctx`` argument and the spaces, then the spaces are converted to that context. The forward action is ``apply(x) = apply_fn(x)`` for ``x in X``. The reverse action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. The supplied ``rapply`` callable must already be the true adjoint with respect to the declared domain and codomain inner products: ``<apply(x), y>_Y = <x, rapply(y)>_X``. It is not automatically corrected with Riesz maps. If you only have a Euclidean coordinate adjoint in non-Euclidean spaces, compute the metric adjoint outside SpaceCore and pass that callable as ``rapply``. When checks are enabled, inputs and callable outputs are validated against the corresponding domain and codomain, but construction does not run adjoint dot-tests. Parameters ---------- apply : callable Callable with signature ``apply(x: Any) -> Any`` implementing the forward map from ``dom`` to ``cod``. rapply : callable Callable with signature ``rapply(y: Any) -> Any`` implementing the true space adjoint map from ``cod`` back to ``dom``. For non-Euclidean spaces this is generally not the same as the Euclidean coordinate adjoint. dom : Space Domain space containing valid inputs for ``apply`` and outputs from ``rapply``. cod : Space Codomain space containing outputs from ``apply`` and valid inputs for ``rapply``. ctx : Context, str, or None, optional Optional context specification. An explicit context wins over inferred contexts from ``dom`` and ``cod``. vapply : callable or None, optional Optional callable with signature ``vapply(xs: Any) -> Any`` for batched forward application. If omitted, backend ``vmap`` fallback is used. 
rvapply : callable or None, optional Optional callable with signature ``rvapply(ys: Any) -> Any`` for batched adjoint application. If omitted, backend ``vmap`` fallback is used. Returns ------- MatrixFreeLinOp Operator using the supplied callables for forward, adjoint, and optionally batched actions. Notes ----- See ``docs/dev/adr/009_metric_adjoint.md`` for the full design rationale for metric adjoints and the distinction between direct matrix-free adjoints and coordinate-adjoint wrapping. """
[docs] def __init__( self, apply: Callable[[Any], Any], rapply: Callable[[Any], Any], dom: Domain, cod: Codomain, ctx: Context | str | None = None, vapply: Callable[[Any], Any] | None = None, rvapply: Callable[[Any], Any] | None = None, ) -> None: """ Initialize a matrix-free linear operator. Parameters ---------- apply: Callable ``apply(x)`` that accepts an element of ``dom`` and returns an element of ``cod``. rapply: Callable ``rapply(y)`` that accepts an element of ``cod`` and returns an element of ``dom``. dom: Domain space of the operator. cod: Codomain space of the operator. ctx: Optional context specification for the operator and converted spaces. vapply: Optional callable for batched forward application over ``dom`` batches. rvapply: Optional callable for batched adjoint application over ``cod`` batches. Returns ------- None The initializer stores the callables and converted spaces on ``self``. """ if not callable(apply): raise TypeError(f"apply must be callable, got {type(apply).__name__}.") if not callable(rapply): raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.") if vapply is not None and not callable(vapply): raise TypeError(f"vapply must be callable, got {type(vapply).__name__}.") if rvapply is not None and not callable(rvapply): raise TypeError(f"rvapply must be callable, got {type(rvapply).__name__}.") super().__init__(dom, cod, ctx) self.apply_fn = apply self.rapply_fn = rapply self.vapply_fn = vapply self.rvapply_fn = rvapply if self._checks_at_least("strict"): self._check_adjoint_consistency()
def _check_adjoint_consistency(self) -> None:
    """
    Verify ``<A x, y> == <x, A* y>`` on deterministic probe elements.

    The check is skipped (returns silently) unless both ``self.domain`` and
    ``self.codomain`` expose an ``inner`` method. Each probe is the space's
    ``ones()`` element when available. Otherwise it is a flat vector of ones
    with ``prod(space.shape)`` entries passed through ``space.unflatten``.

    Raises
    ------
    ValueError
        If building a probe, applying ``apply_fn`` / ``rapply_fn``, the
        membership checks, or the inner products raise. The original error is
        chained as ``__cause__``. Also raised if the two inner products are not
        ``allclose``.
    """
    if any(not hasattr(space, "inner") for space in (self.domain, self.codomain)):
        return

    def make_probe(space: Any) -> Any:
        # Guard-style probe construction: prefer ``ones()``, else unflatten ones.
        if hasattr(space, "ones"):
            return space.ones()
        can_unflatten = hasattr(space, "unflatten") and hasattr(space, "shape")
        if not can_unflatten:
            raise TypeError(f"{type(space).__name__} cannot build a strict probe element.")
        size = prod(space.shape)
        return space.unflatten(self.ops.ones((size,), dtype=self.dtype))

    # NOTE(review): both probes are all-ones. For real Euclidean inner products
    # on a square operator, an ``rapply`` that computes ``A`` instead of ``A^T``
    # passes this check, because both sides reduce to the sum of A's entries.
    # A non-constant probe would catch forgotten transposes.
    try:
        x_probe = make_probe(self.domain)
        y_probe = make_probe(self.codomain)
        forward = self.apply_fn(x_probe)
        backward = self.rapply_fn(y_probe)
        self.codomain._check_member(forward)
        self.domain._check_member(backward)
        lhs = cast(Any, self.codomain).inner(forward, y_probe)
        rhs = cast(Any, self.domain).inner(x_probe, backward)
        consistent = bool(self.ops.allclose(lhs, rhs))
    except Exception as exc:
        # Any failure while probing counts as "could not be completed".
        raise ValueError(
            "Strict matrix-free adjoint consistency check could not be completed."
        ) from exc

    if consistent:
        return
    raise ValueError(
        "Strict matrix-free adjoint consistency check failed: "
        "<A x, y> != <x, A* y> for the deterministic probe."
    )
@classmethod
def from_coordinate_adjoint(
    cls,
    apply: Callable[[Any], Any],
    coordinate_rapply: Callable[[Any], Any],
    dom: Domain,
    cod: Codomain,
    ctx: Context | str | None = None,
    vapply: Callable[[Any], Any] | None = None,
    coordinate_rvapply: Callable[[Any], Any] | None = None,
) -> MatrixFreeLinOp:
    r"""
    Build a matrix-free operator from a Euclidean *coordinate* adjoint.

    ``coordinate_rapply`` must be the plain Euclidean coordinate adjoint
    ``A^dagger`` of ``apply``. The operator stores, as its ``rapply``, the
    metric adjoint ``A^sharp(y) = R_X^-1 A^dagger R_Y y``. Here ``X`` and ``Y``
    are the declared domain and codomain after conversion to the resolved
    context.

    Euclidean spaces have identity Riesz maps, so the wrapper then behaves
    like ``coordinate_rapply`` itself. Non-Euclidean spaces must provide usable
    Riesz maps when this constructor runs. Otherwise construction is refused,
    rather than storing an adjoint that disagrees with the space geometry.

    When ``coordinate_rvapply`` is supplied, it is treated as the batched
    Euclidean coordinate adjoint and wrapped with batched Riesz maps. If
    batched Riesz application is unavailable, the metric-adjoint helper
    consistently falls back to vectorizing the scalar ``rapply``. When
    ``coordinate_rvapply`` is omitted, ``rvapply_fn`` stays ``None`` and the
    generic ``rvapply`` fallback vectorizes the wrapped scalar adjoint.

    Design rationale: ``docs/dev/adr/009_metric_adjoint.md``.

    Parameters
    ----------
    apply : callable
        Forward coordinate action from ``dom`` to ``cod``.
    coordinate_rapply : callable
        Euclidean coordinate adjoint, mapping ``cod`` coordinates to ``dom``
        coordinates.
    dom, cod : Space
        Domain and codomain; both are converted to the resolved context.
    ctx : Context, str, or None, optional
        Optional context specification. It is resolved together with the
        contexts of ``dom`` and ``cod``.
    vapply : callable or None, optional
        Optional batched forward application, passed through unchanged.
    coordinate_rvapply : callable or None, optional
        Optional batched Euclidean coordinate adjoint. If omitted, batched
        adjoints use backend ``vmap`` over the wrapped scalar ``rapply``.

    Returns
    -------
    MatrixFreeLinOp
        An instance of ``cls`` on the converted spaces. Its ``rapply`` applies
        the metric adjoint; so does its ``rvapply`` when ``coordinate_rvapply``
        was supplied.

    Raises
    ------
    TypeError
        If ``apply`` or ``coordinate_rapply`` is not callable, or if
        ``coordinate_rvapply`` is neither ``None`` nor callable. The
        ``MatrixFreeLinOp`` initializer also rejects a non-callable ``vapply``.
    ValueError
        If the converted spaces are neither Euclidean nor Riesz-capable. The
        underlying ``TypeError`` is re-raised as ``ValueError``. Under strict
        checks, also raised when the initializer's adjoint-consistency probe
        fails.
    """
    if not callable(apply):
        raise TypeError(f"apply must be callable, got {type(apply).__name__}.")
    if not callable(coordinate_rapply):
        raise TypeError(
            f"coordinate_rapply must be callable, got {type(coordinate_rapply).__name__}."
        )
    if coordinate_rvapply is not None and not callable(coordinate_rvapply):
        raise TypeError(
            f"coordinate_rvapply must be callable, got {type(coordinate_rvapply).__name__}."
        )

    opname = "MatrixFreeLinOp.from_coordinate_adjoint"
    target_ctx = resolve_context_priority(ctx, dom, cod)
    space_in = dom.convert(target_ctx)
    space_out = cod.convert(target_ctx)

    # Public contract: an unsupported geometry surfaces as ValueError.
    try:
        _requires_euclidean_or_riesz(space_in, space_out, opname)
    except TypeError as exc:
        raise ValueError(str(exc)) from exc

    def wrapped_rapply(y: Any) -> Any:
        # A^sharp(y) = R_X^-1 A^dagger R_Y y on the converted spaces.
        return metric_rapply(space_in, space_out, coordinate_rapply, y)

    def _wrapped_rvapply(ys: Any) -> Any:
        # Batched counterpart. Only reachable when coordinate_rvapply was given.
        return metric_rvapply(
            space_in,
            space_out,
            coordinate_rapply,
            coordinate_rvapply,
            ys,
            opname=opname,
            ops=target_ctx.ops,
        )

    batched_adjoint: Callable[[Any], Any] | None = (
        None if coordinate_rvapply is None else _wrapped_rvapply
    )
    return cls(apply, wrapped_rapply, space_in, space_out, target_ctx, vapply, batched_adjoint)
@checked_method(in_space="domain", out_space="codomain")
def apply(self, x: Any) -> Any:
    """
    Evaluate the forward action ``A x``.

    Parameters
    ----------
    x:
        Element of ``self.domain``. The ``checked_method`` decorator is
        configured with ``in_space="domain"``.

    Returns
    -------
    Any
        Element of ``self.codomain``, matching ``out_space="codomain"`` on the
        decorator.

    Notes
    -----
    Evaluation goes through ``self._apply_core`` (the core dispatch path)
    rather than calling ``self.apply_fn`` directly in this method.
    """
    image = self._apply_core(x)
    return image
@checked_method(in_space="codomain", out_space="domain")
def rapply(self, y: Any) -> Any:
    """
    Evaluate the adjoint action ``A* y``.

    Parameters
    ----------
    y:
        Element of ``self.codomain``. The ``checked_method`` decorator is
        configured with ``in_space="codomain"``.

    Returns
    -------
    Any
        Element of ``self.domain``, matching ``out_space="domain"`` on the
        decorator.

    Notes
    -----
    Evaluation goes through ``self._rapply_core`` (the core dispatch path)
    rather than calling ``self.rapply_fn`` directly in this method.
    """
    preimage = self._rapply_core(y)
    return preimage
@checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
def vapply(self, xs: Any) -> Any:
    """
    Apply the forward action to a batch of domain elements.

    Parameters
    ----------
    xs:
        Batched element of ``self.domain``.

    Returns
    -------
    Any
        Batched element of ``self.codomain``. It comes from the user-supplied
        ``vapply_fn`` when one was given, otherwise from the inherited
        ``vapply`` fallback.
    """
    batched_forward = self.vapply_fn
    if batched_forward is not None:
        return batched_forward(xs)
    # No batched callable supplied: defer to the base-class batching.
    return super().vapply(xs)
@checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
def rvapply(self, ys: Any) -> Any:
    """
    Apply the adjoint action to a batch of codomain elements.

    Parameters
    ----------
    ys:
        Batched element of ``self.codomain``.

    Returns
    -------
    Any
        Batched element of ``self.domain``. It comes from the user-supplied
        ``rvapply_fn`` when one was given, otherwise from the inherited
        ``rvapply`` fallback.
    """
    batched_adjoint = self.rvapply_fn
    if batched_adjoint is not None:
        return batched_adjoint(ys)
    # No batched adjoint supplied: defer to the base-class batching.
    return super().rvapply(ys)
def fuse(self, *, materialize: bool = False) -> LinOp:
    """Return ``self`` unless densification is explicitly requested (ADR-021/ADR-008).

    A matrix-free operator is never densified implicitly, so the default call
    is the identity. With ``materialize=True`` the caller opts in: the
    operator is probed into a dense tensor via ``to_dense`` (a basis sweep).
    The result is wrapped as a :class:`DenseLinOp` on the same spaces and
    context, so an enclosing expression can collapse into one dense operator.
    """
    if materialize:
        # Local import, as in the original, keeps the module-level import graph unchanged.
        from ._dense import DenseLinOp

        dense_tensor = self.to_dense()
        return DenseLinOp(dense_tensor, self.domain, self.codomain, self.ctx)
    return self
def __eq__(self, other: Any) -> bool:
    """Structural equality for matrix-free operators.

    Returns ``NotImplemented`` when ``_eq_backend_compatible`` rejects
    ``other``. Otherwise returns ``True`` only if both spaces compare equal and
    all four stored callables are the identical objects.
    """
    if not self._eq_backend_compatible(other):  # Tier 1: backend
        return NotImplemented
    # Tier 2: spaces + callable identity. Extensional equality of callables
    # is undecidable, so 'is' is the only sound comparison.
    # NOTE(review): the ``other.apply_fn`` etc. accesses assume
    # ``_eq_backend_compatible`` already rejects operators of other types.
    # Otherwise a non-matrix-free operator with equal spaces would raise
    # AttributeError here. Confirm in ``_eq_backend_compatible``.
    # NOTE(review): ``from_coordinate_adjoint`` creates a fresh adjoint closure
    # on every call. Two operators built from the same coordinate callables
    # therefore compare unequal.
    # NOTE(review): defining ``__eq__`` without ``__hash__`` makes instances
    # unhashable unless ``__hash__`` is assigned elsewhere on this class.
    # Confirm that is intended.
    if self.domain != other.domain or self.codomain != other.codomain:
        return False
    return (
        self.apply_fn is other.apply_fn
        and self.vapply_fn is other.vapply_fn
        and self.rapply_fn is other.rapply_fn
        and self.rvapply_fn is other.rvapply_fn
    )

def tree_flatten(self) -> tuple[tuple[()], tuple[Any, ...]]:
    """Flatten into (no children, static aux data).

    Every field goes into ``aux``. The positional order here must match the
    unpacking order in ``tree_unflatten`` exactly.
    """
    children = ()
    aux = (
        self.apply_fn,
        self.rapply_fn,
        self.domain,
        self.codomain,
        self.ctx,
        self.vapply_fn,
        self.rvapply_fn,
    )
    return children, aux

@classmethod
def tree_unflatten(cls, aux: tuple[Any, ...], children: Sequence[Any]) -> MatrixFreeLinOp:
    """Rebuild an operator from the aux tuple produced by ``tree_flatten``.

    ``children`` is unused because ``tree_flatten`` emits none.
    """
    (
        apply_fn,
        rapply_fn,
        domain,
        codomain,
        ctx,
        vapply_fn,
        rvapply_fn,
    ) = aux
    # NOTE(review): going through ``cls(...)`` re-runs ``__init__``. That
    # repeats the callable validation and, at check level "strict" or above,
    # the adjoint probe in ``_check_adjoint_consistency`` on every unflatten.
    # If unflattening happens inside a JAX trace, that probe's ``bool(...)``
    # may fail on traced values. Verify under jit with strict checks enabled.
    return cls(apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn)

def _convert(self, new_ctx: Context) -> MatrixFreeLinOp:
    """
    Convert this matrix-free operator to ``new_ctx``.

    Parameters
    ----------
    new_ctx:
        Concrete target context for converted domain and codomain spaces.

    Returns
    -------
    MatrixFreeLinOp
        Operator with converted spaces and the same user-supplied callables.
    """
    # NOTE(review): the stored callables are reused unchanged. For operators
    # built by ``from_coordinate_adjoint``, ``rapply_fn`` / ``rvapply_fn`` are
    # closures over the spaces (and ``ops``) of the ORIGINAL resolved context.
    # The metric adjoint therefore keeps using those after conversion.
    # Confirm this is acceptable across backends.
    return MatrixFreeLinOp(
        self.apply_fn,
        self.rapply_fn,
        self.domain.convert(new_ctx),
        self.codomain.convert(new_ctx),
        new_ctx,
        self.vapply_fn,
        self.rvapply_fn,
    )
@core_kernels("adjoint")
@jax_pytree_class
class _AdjointViewLinOp(LinOp[Codomain, Domain]):
    """
    Hermitian-adjoint view ``A*`` of a linear operator ``A``.

    ``A.H`` represents this view. It carries exactly ``A.ctx`` and swaps the
    spaces: its domain is ``A.codomain`` and its codomain is ``A.domain``.
    Taking ``.H`` of the view returns the wrapped ``A`` itself, so no second
    wrapper is built.

    Action mapping onto the wrapped operator:

    * ``apply(y) = A.rapply(y)`` for ``y`` in ``A.codomain``;
    * ``rapply(x) = A.apply(x)`` for ``x`` in ``A.domain``;
    * ``vapply`` / ``rvapply`` forward to ``A.rvapply`` / ``A.vapply``.
    """

    def __init__(self, op: LinOp[Domain, Codomain]) -> None:
        wrapped = _require_linop(op, "op")
        # Spaces are swapped relative to the wrapped operator; context is shared.
        super().__init__(wrapped.codomain, wrapped.domain, wrapped.ctx)
        self.op = wrapped

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, y: Any) -> Any:
        """Forward action of the view, ``op.rapply(y)``, via the core dispatch path."""
        image = self._apply_core(y)
        return image

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, x: Any) -> Any:
        """Adjoint action of the view, ``op.apply(x)``, via the core dispatch path."""
        preimage = self._rapply_core(x)
        return preimage

    def vapply(self, ys: Any) -> Any:
        """Batched forward action of the view: delegates to ``op.rvapply(ys)``."""
        wrapped = self.op
        return wrapped.rvapply(ys)

    def rvapply(self, xs: Any) -> Any:
        """Batched adjoint action of the view: delegates to ``op.vapply(xs)``."""
        wrapped = self.op
        return wrapped.vapply(xs)

    def fuse(self, *, materialize: bool = False) -> LinOp:
        """Fuse the wrapped operand, then take the adjoint of the result (ADR-021).

        ``A.H.fuse()`` equals ``A.fuse().H``. The inner expression is fused first
        (a composition may, for example, collapse into one dense operator), and
        the adjoint view is applied to what comes back. This keeps the metric
        adjoint correct on any geometry. A matrix-free operand stays
        matrix-free under its adjoint unless ``materialize=True`` densifies it.
        """
        fused_operand = self.op.fuse(materialize=materialize)
        return fused_operand.H

    @property
    def H(self) -> LinOp[Domain, Codomain]:
        """The wrapped operator, since the adjoint of an adjoint view is the original."""
        return self.op

    def __eq__(self, other: Any) -> bool:
        # Tier 1: backend. Incompatible operands defer to the other side.
        if self._eq_backend_compatible(other):
            # Tier 2: two adjoint views are equal exactly when their operands are.
            # NOTE(review): assumes ``_eq_backend_compatible`` only admits objects
            # that carry ``.op``. Confirm. Also note that defining ``__eq__``
            # without ``__hash__`` makes instances unhashable.
            return self.op == other.op
        return NotImplemented

    def _repr_class_name(self) -> str:
        """Public-facing label used in reprs instead of the private class name."""
        return "AdjointLinOp"

    def _repr_body(self) -> str:
        return f"{self.op._short_repr()}.H"

    def tree_flatten(self):
        # The wrapped operator is the single pytree child; nothing is static.
        return (self.op,), ()

    @classmethod
    def tree_unflatten(cls, aux, children):
        operand = children[0]
        return cls(operand)

    def _convert(self, new_ctx: Context) -> _AdjointViewLinOp:
        converted_operand = self.op.convert(new_ctx)
        return _AdjointViewLinOp(converted_operand)