from __future__ import annotations
from math import prod
from numbers import Number
from typing import Any, Callable, Sequence, cast
from ._base import LinOp, Domain, Codomain
from ._metric import _requires_euclidean_or_riesz, metric_rapply, metric_rvapply
from .._checks import checked_method
from .._contextual import resolve_context_priority
from .._contextual._bound import _same_math_context
from .._repr import summarize_value
from ..backend import Context, jax_pytree_class
from ..kernels import core_kernels
from ..kernels.core.algebra import (
batched_zeros as _batched_zeros,
compose_chain as _compose_chain,
conjugate_scalar as _conjugate_scalar,
leading_shape as _leading_shape,
)
def is_scalar_like(value: Any) -> bool:
    """
    Report whether ``value`` can act as the scalar factor of a ``LinOp``.

    Any :class:`numbers.Number` qualifies outright. Other objects qualify when
    they look 0-dimensional:

    - A non-``None`` ``shape`` attribute must be empty.
    - When ``shape`` is absent (or ``None``), an ``ndim`` attribute equal to
      ``0`` is consulted instead.

    Parameters
    ----------
    value : Any
        Candidate multiplier (Python number, 0-d array, array scalar, ...).

    Returns
    -------
    bool
        ``True`` when ``value`` is scalar-like, ``False`` otherwise.
    """
    if isinstance(value, Number):
        return True
    shape = getattr(value, "shape", None)
    if shape is None:
        # No shape metadata: fall back to the dimensionality attribute.
        return getattr(value, "ndim", None) == 0
    # An empty shape tuple means a 0-d (scalar) array.
    return not tuple(shape)
def _scalar_eq(a: Any, b: Any) -> bool:
"""Return whether two scalar-likes are equal, NaN-reflexive.
Mirrors the ``equal_nan=True`` used for array values: two matching NaN
scalars compare equal so a NaN-scaled operator equals itself. Always returns
a real Python ``bool`` (a 0-d backend-array ``==`` would otherwise yield
``np.bool_``, which leaks through the ``and`` combinator of any container).
"""
if bool(a == b):
return True
try:
# ``x != x`` is True only for NaN (including a complex value with a NaN
# component), so this branch matches NaN against NaN.
return bool(a != a) and bool(b != b)
except Exception:
return False
def _require_same_context(ops: Sequence[LinOp]) -> Context:
"""Return the common context for algebra operands or raise."""
ctx = ops[0].ctx
for i, op in enumerate(ops[1:], start=1):
if not _same_math_context(ops[0].ctx, op.ctx):
raise ValueError(
"All LinOp operands in an algebraic expression must have the same ctx; "
f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}."
)
return ctx
def _same_space_for_algebra(left: Any, right: Any) -> bool:
"""Return whether two spaces are compatible for algebraic composition."""
if left == right:
return True
if type(left) is not type(right):
return False
if tuple(left.shape) != tuple(right.shape):
return False
if not _same_math_context(left.ctx, right.ctx):
return False
try:
return left.convert(right.ctx) == right
except Exception:
return False
def _require_linop(op: Any, name: str) -> LinOp:
    """
    Validate that ``op`` is a ``LinOp`` and hand it back unchanged.

    Parameters
    ----------
    op : Any
        Candidate operand.
    name : str
        Label used in the error message (e.g. ``"ops[2]"``).

    Returns
    -------
    LinOp
        ``op`` itself.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` instance.
    """
    if isinstance(op, LinOp):
        return op
    raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.")
def _scalar_equal(value: Any, target: Any) -> bool:
"""Return whether two scalar-like values compare equal."""
try:
return bool(value == target)
except Exception:
return False
def _is_zero_scalar(value: Any) -> bool:
    """
    Report whether ``value`` compares equal to the scalar ``0``.

    Delegates to ``_scalar_equal``, so values whose comparison raises are
    reported as not zero.
    """
    return _scalar_equal(value=value, target=0)
def _is_one_scalar(value: Any) -> bool:
    """
    Report whether ``value`` compares equal to the scalar ``1``.

    Delegates to ``_scalar_equal``, so values whose comparison raises are
    reported as not one.
    """
    return _scalar_equal(value=value, target=1)
def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]:
    """
    Expand nested ``SumLinOp`` nodes into a flat tuple of summands.

    Summands keep their left-to-right order. Every operand, at every nesting
    level, is first validated with ``_require_linop``.

    Parameters
    ----------
    ops : sequence of LinOp
        Operands, possibly containing ``SumLinOp`` instances.

    Returns
    -------
    tuple of LinOp
        Non-sum leaves in the order they appear.

    Raises
    ------
    TypeError
        If any operand is not a ``LinOp``.
    """
    flat: list[LinOp] = []
    for index, candidate in enumerate(ops):
        term = _require_linop(candidate, f"ops[{index}]")
        if not isinstance(term, SumLinOp):
            flat.append(term)
            continue
        # Recurse into the nested sum so its own nested sums are expanded too.
        flat.extend(_flatten_sum_terms(term.parts))
    return tuple(flat)
def make_sum(ops: Sequence[LinOp]) -> LinOp:
    """
    Return a locally simplified lazy sum of linear operators.

    This factory performs only local algebraic canonicalization: nested
    ``SumLinOp`` nodes are flattened and ``ZeroLinOp`` terms are removed. It
    does not collect like terms, reorder operands, or attempt full symbolic
    optimization. All operands must have the same context, domain, and
    codomain before a simplified operator is returned.

    Parameters
    ----------
    ops : sequence of LinOp
        Nonempty sequence of operators with common domain and codomain.

    Returns
    -------
    LinOp
        One of the following:

        - ``ZeroLinOp`` on the shared spaces, when every term is a zero map;
        - the single remaining term, when exactly one nonzero term is left;
        - otherwise a ``SumLinOp`` over the nonzero terms in their original
          order.

    Raises
    ------
    ValueError
        If ``ops`` is empty, or the flattened terms disagree on context,
        domain, or codomain.
    TypeError
        If any operand is not a ``LinOp``.
    """
    if not ops:
        raise ValueError("make_sum requires a nonempty sequence of LinOp operands.")
    terms = _flatten_sum_terms(ops)
    ctx = _require_same_context(terms)
    domain = terms[0].domain
    codomain = terms[0].codomain
    # Validate every term (zero terms included) before any of them is dropped,
    # so a mismatched zero map is still reported.
    for i, op in enumerate(terms[1:], start=1):
        if not _same_space_for_algebra(op.domain, domain) or not _same_space_for_algebra(
            op.codomain, codomain
        ):
            raise ValueError(
                "All SumLinOp operands must have the same domain and codomain; "
                f"operand 0 maps {domain!r} -> {codomain!r}, "
                f"operand {i} maps {op.domain!r} -> {op.codomain!r}."
            )
    nonzero_terms = tuple(op for op in terms if not isinstance(op, ZeroLinOp))
    if not nonzero_terms:
        return ZeroLinOp(domain, codomain, ctx)
    if len(nonzero_terms) == 1:
        return nonzero_terms[0]
    return SumLinOp(nonzero_terms)
def make_scaled(scalar: Any, op: LinOp) -> LinOp:
    """
    Return a locally simplified scalar multiple of a linear operator.

    This factory performs only local algebraic canonicalization: zero and unit
    scalars are simplified, and nested ``ScaledLinOp`` nodes are folded into
    one scalar. It does not distribute scaling over sums or perform full
    symbolic optimization. Complex scalars retain the usual conjugated
    coefficient in ``rapply`` through ``ScaledLinOp``.

    Parameters
    ----------
    scalar : scalar-like
        Scalar coefficient multiplying ``op``.
    op : LinOp
        Operator to scale.

    Returns
    -------
    LinOp
        One of the following, checked in this order:

        - a ``ZeroLinOp`` on ``op``'s spaces when ``scalar == 0``;
        - ``op`` itself when ``scalar == 1`` or ``op`` is already a zero map;
        - the folded product when ``op`` is a ``ScaledLinOp``;
        - otherwise a new ``ScaledLinOp``.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` or ``scalar`` is not scalar-like.
    """
    op = _require_linop(op, "op")
    if not is_scalar_like(scalar):
        raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.")
    # Comparisons that cannot be decided (they raise) count as "not 0/1",
    # so such scalars fall through to the general case.
    if _is_zero_scalar(scalar):
        return ZeroLinOp(op.domain, op.codomain, op.ctx)
    if _is_one_scalar(scalar):
        return op
    if isinstance(op, ZeroLinOp):
        return op
    if isinstance(op, ScaledLinOp):
        # Fold nested scalings: s * (t * A) -> (s * t) * A, re-simplified.
        return make_scaled(scalar * op.scalar, op.op)
    return ScaledLinOp(scalar, op)
def make_composed(left: LinOp, right: LinOp) -> LinOp:
    """
    Return a locally simplified composition of two linear operators.

    This factory performs only local algebraic canonicalization: identity
    factors are removed and compositions with zero maps become zero maps. It
    preserves the binary ``ComposedLinOp`` representation and does not flatten
    multi-factor chains or attempt full symbolic optimization. Operands must
    have the same context and compatible middle spaces before a simplified
    operator is returned.

    Parameters
    ----------
    left : LinOp
        Operator applied second.
    right : LinOp
        Operator applied first.

    Returns
    -------
    LinOp
        Simplified lazy composition representing ``left @ right``.

    Raises
    ------
    TypeError
        If either operand is not a ``LinOp``.
    ValueError
        If the contexts differ or ``right.codomain`` is incompatible with
        ``left.domain``.
    """
    left = _require_linop(left, "left")
    right = _require_linop(right, "right")
    _require_same_context((left, right))
    if not _same_space_for_algebra(right.codomain, left.domain):
        raise ValueError(
            "ComposedLinOp requires right.codomain == left.domain; "
            f"got {right.codomain!r} and {left.domain!r}."
        )
    # Identity factors vanish; checked before zero maps, so I @ 0 keeps the 0.
    if isinstance(right, IdentityLinOp):
        return left
    if isinstance(left, IdentityLinOp):
        return right
    if isinstance(left, ZeroLinOp) or isinstance(right, ZeroLinOp):
        # Either zero factor annihilates the product: 0 : right.domain -> left.codomain.
        return ZeroLinOp(right.domain, left.codomain, left.ctx)
    return ComposedLinOp(left, right)
@core_kernels("scaled")
@jax_pytree_class
class ScaledLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy scalar multiple of a linear operator.

    ``ScaledLinOp(alpha, A)`` represents the mathematical operator
    ``alpha * A``. Its context is exactly ``A.ctx``; its domain is ``A.domain``
    and its codomain is ``A.codomain``. No dense matrix representation is
    formed.

    The forward action is ``apply(x) = alpha * A.apply(x)`` for
    ``x in A.domain``. The reverse action is
    ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so
    complex scalars use the conjugated coefficient.

    Parameters
    ----------
    scalar : scalar-like
        Scalar multiplier.
    op : LinOp
        Operator being scaled.

    Attributes
    ----------
    scalar : scalar-like
        Stored scalar multiplier.
    op : LinOp
        Stored operand.

    Raises
    ------
    TypeError
        If ``op`` is not a ``LinOp`` or ``scalar`` is not scalar-like.
    """

    def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None:
        op = _require_linop(op, "op")
        if not is_scalar_like(scalar):
            raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.")
        super().__init__(op.domain, op.codomain, op.ctx)
        self.scalar = scalar
        self.op = op

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``scalar * op.apply(x)``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return ``conj(scalar) * op.rapply(y)``."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``scalar * op.vapply(xs)``."""
        return self._vapply_core(xs)

    def rvapply(self, ys: Any) -> Any:
        """Return ``conj(scalar) * op.rvapply(ys)``."""
        # NOTE(review): unlike apply/rapply/vapply this method carries no
        # @checked_method; validation relies on ``op.rvapply`` — confirm intended.
        xs = self.op.rvapply(ys)
        return self.domain.scale_batch(_conjugate_scalar(self.scalar), xs)

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Fuse the operand and fold the scalar into a dense matrix (ADR-021).

        When the fused operand is dense, replace ``c · A`` with one
        :class:`DenseLinOp` holding ``c · M_A``. Otherwise keep the scaling
        lazy (via ``make_scaled``), so a matrix-free operand is never
        densified.

        Adjoint-consistent: the fused operator's adjoint is ``conj(c) · M_A^*``
        (with the metric), exactly the lazy ``ScaledLinOp`` adjoint.
        """
        from ._dense import DenseLinOp

        op = self.op.fuse(materialize=materialize)
        if isinstance(op, DenseLinOp):
            matrix = self.scalar * op.to_matrix()
            # Restore the tensor layout codomain.shape + domain.shape.
            tensor = self.ops.reshape(
                matrix, tuple(op.codomain.shape) + tuple(op.domain.shape)
            )
            return DenseLinOp(tensor, op.domain, op.codomain, self.ctx)
        return make_scaled(self.scalar, op)

    def is_hermitian(self) -> bool | None:
        """
        Return whether this scaled operator is structurally Hermitian.

        The verdict depends on the scalar ``s``:

        - ``s == 0``: ``0 · A`` is the zero map, which is Hermitian exactly
          when domain and codomain are the same space. This is the rule
          ``ZeroLinOp.is_hermitian`` uses. The operand's verdict is
          irrelevant here: a non-Hermitian ``A`` scaled by zero is still
          Hermitian on a square space.
        - Nonzero real ``s``: the adjoint satisfies ``(s A)* = s A*``, so
          ``s A`` is self-adjoint exactly when ``A`` is, and the operand's
          verdict is propagated.
        - Otherwise ``None`` is returned. This covers any ``s`` for which
          ``conj(s) == s`` does not hold (non-real scalars, NaN) or cannot be
          evaluated without raising, where Hermiticity is not cheaply
          decidable.

        Returns
        -------
        bool | None
            The zero-map verdict for a zero scalar, ``self.op.is_hermitian()``
            for a real scalar, otherwise ``None`` for unknown.
        """
        if _is_zero_scalar(self.scalar):
            return self.domain == self.codomain
        # ``_scalar_equal`` never raises and always yields a Python bool, so an
        # undecidable comparison degrades to "unknown" instead of an error.
        if _scalar_equal(_conjugate_scalar(self.scalar), self.scalar):
            return self.op.is_hermitian()
        return None

    def __eq__(self, other: Any) -> bool:
        """Return whether another scaled operator has the same scalar and operand."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        # NaN-reflexive, returns a real Python bool (no np.bool_ leak).
        if not _scalar_eq(self.scalar, other.scalar):  # Tier 3: scalar value
            return False
        return self.op == other.op  # operand (own gate)

    def _repr_body(self) -> str:
        return f"{summarize_value(self.scalar)} · {self.op._short_repr()}"

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        children = (self.scalar, self.op)
        aux = ()
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        # NOTE(review): this re-runs __init__ validation (is_scalar_like,
        # _require_linop). JAX may unflatten with placeholder leaves in some
        # transformations — confirm that path never reaches here.
        scalar, op = children
        return cls(scalar, op)

    def _convert(self, new_ctx: Context) -> ScaledLinOp:
        """Convert the operand to ``new_ctx`` while preserving the scalar."""
        return ScaledLinOp(self.scalar, self.op.convert(new_ctx))
@core_kernels("sum")
@jax_pytree_class
class SumLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy finite sum of linear operators with common spaces.

    ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty
    sequence of ``LinOp`` instances. All operands must have the same ``ctx``,
    the same domain, and the same codomain before construction. The resulting
    operator has that shared context, domain, and codomain. Operands are
    stored as given: nested sums are not flattened and zero terms are not
    removed (``make_sum`` does that).

    The forward action is ``apply(x) = sum_i Ai.apply(x)`` for the shared
    domain element ``x``. The reverse action is
    ``rapply(y) = sum_i Ai.rapply(y)`` for the shared codomain element ``y``.

    Parameters
    ----------
    ops : sequence of LinOp
        Nonempty sequence of operators with common context, domain, and
        codomain.

    Attributes
    ----------
    parts : tuple of LinOp
        Stored operands in the lazy sum.

    Raises
    ------
    ValueError
        If ``ops`` is empty or the operands disagree on context, domain, or
        codomain.
    TypeError
        If any operand is not a ``LinOp``.
    """

    def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None:
        if not ops:
            raise ValueError("SumLinOp requires a nonempty sequence of LinOp operands.")
        parts = tuple(_require_linop(op, f"ops[{i}]") for i, op in enumerate(ops))
        ctx = _require_same_context(parts)
        domain = parts[0].domain
        codomain = parts[0].codomain
        for i, op in enumerate(parts[1:], start=1):
            if not _same_space_for_algebra(op.domain, domain) or not _same_space_for_algebra(
                op.codomain, codomain
            ):
                raise ValueError(
                    "All SumLinOp operands must have the same domain and codomain; "
                    f"operand 0 maps {domain!r} -> {codomain!r}, "
                    f"operand {i} maps {op.domain!r} -> {op.codomain!r}."
                )
        super().__init__(domain, codomain, ctx)
        self.ops_tuple = parts

    @property
    def parts(self) -> tuple[LinOp[Domain, Codomain], ...]:
        """Operators in this lazy sum."""
        return self.ops_tuple

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``sum_i ops[i].apply(x)``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return ``sum_i ops[i].rapply(y)``."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``sum_i ops[i].vapply(xs)``."""
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
    def rvapply(self, ys: Any) -> Any:
        """Return ``sum_i ops[i].rvapply(ys)``."""
        add_batch = self.domain.add_batch
        # Left fold over the terms, in stored order.
        acc = self.ops_tuple[0].rvapply(ys)
        for op in self.ops_tuple[1:]:
            acc = add_batch(acc, op.rvapply(ys))
        return acc

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Fuse each term and combine the dense terms into one ``DenseLinOp`` (ADR-021).

        Fuse every term, sum the matrices of the densely-fusible ones into a
        single :class:`DenseLinOp`, and keep the remaining (matrix-free or
        structured) terms as lazy summands, so a matrix-free term is never
        densified. With fewer than two dense terms there is nothing to
        combine, and the fused terms are simply re-summed via ``make_sum``.

        Adjoint-consistent and additive: ``(A + B)^* = A^* + B^*``. Combining
        moves the dense block to the front and reassociates the term order,
        so equality holds up to rounding.
        """
        from ._dense import DenseLinOp

        fused = [p.fuse(materialize=materialize) for p in self.parts]
        dense = [p for p in fused if isinstance(p, DenseLinOp)]
        if len(dense) < 2:
            return make_sum(fused)
        matrix = dense[0].to_matrix()
        for d in dense[1:]:
            matrix = matrix + d.to_matrix()
        ref = dense[0]
        # Restore the tensor layout codomain.shape + domain.shape.
        tensor = self.ops.reshape(
            matrix, tuple(ref.codomain.shape) + tuple(ref.domain.shape)
        )
        combined = DenseLinOp(tensor, ref.domain, ref.codomain, self.ctx)
        others = [p for p in fused if not isinstance(p, DenseLinOp)]
        return make_sum([combined, *others])

    def is_hermitian(self) -> bool | None:
        """
        Return whether this lazy sum is structurally Hermitian.

        The adjoint is additive, ``(A1 + ... + Ak)* = A1* + ... + Ak*``, so a
        sum is self-adjoint when every term is. If all operands are provably
        Hermitian this returns ``True``. Otherwise the verdict is not cheaply
        decidable (a sum of non-Hermitian terms may still be Hermitian), and
        ``None`` is returned. ``False`` is never returned.

        Returns
        -------
        bool | None
            ``True`` when every part is provably Hermitian, otherwise ``None``.
        """
        # ``is True`` so that a part answering ``None`` (unknown) blocks the claim.
        if all(op.is_hermitian() is True for op in self.parts):
            return True
        return None

    def __eq__(self, other: Any) -> bool:
        """Return whether another sum has the same operands, in order."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        if len(self.ops_tuple) != len(other.ops_tuple):  # Tier 2: operand count before zip
            return False
        # Ordered, structural: A + B != B + A. Commutative equivalence is a
        # separate concern (a future equiv()), not __eq__.
        return all(a == b for a, b in zip(self.ops_tuple, other.ops_tuple))

    def _repr_body(self) -> str:
        from .._repr import truncated_join

        return truncated_join((op._short_repr() for op in self.ops_tuple), " + ")

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        children = self.ops_tuple
        aux = ()
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        return cls(tuple(children))

    def _convert(self, new_ctx: Context) -> SumLinOp:
        """Convert all operands to ``new_ctx``."""
        return SumLinOp(tuple(op.convert(new_ctx) for op in self.ops_tuple))
@core_kernels("composed")
@jax_pytree_class
class ComposedLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy composition of two linear operators.

    ``ComposedLinOp(A, B)`` represents ``A @ B = A circ B``. The operands must
    have the same ``ctx`` before construction, and ``B.codomain`` must equal
    ``A.domain``. The resulting operator has domain ``B.domain`` and codomain
    ``A.codomain``.

    The forward action is ``apply(x) = A.apply(B.apply(x))`` for
    ``x in B.domain``. The reverse action is ``rapply(z) = B.rapply(A.rapply(z))``
    for ``z in A.codomain``.

    Parameters
    ----------
    left : LinOp
        Operator applied second.
    right : LinOp
        Operator applied first.

    Attributes
    ----------
    left : LinOp
        Left operand.
    right : LinOp
        Right operand.

    Raises
    ------
    TypeError
        If either operand is not a ``LinOp``.
    ValueError
        If the contexts differ or ``right.codomain`` is incompatible with
        ``left.domain``.
    """

    def __init__(self, left: LinOp, right: LinOp) -> None:
        left = _require_linop(left, "left")
        right = _require_linop(right, "right")
        _require_same_context((left, right))
        if not _same_space_for_algebra(right.codomain, left.domain):
            raise ValueError(
                "ComposedLinOp requires right.codomain == left.domain; "
                f"got {right.codomain!r} and {left.domain!r}."
            )
        super().__init__(right.domain, left.codomain, left.ctx)
        self.left = left
        self.right = right
        # Fuse the (possibly nested) composition into one flat chain of leaf
        # operators in application order: right applied first, then left.
        # Cached at construction so every apply runs a single check-free loop
        # instead of re-walking the binary ComposedLinOp tree.
        self._apply_chain = _compose_chain(right) + _compose_chain(left)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``left.apply(right.apply(x))``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, z: Any) -> Any:
        """Return ``right.rapply(left.rapply(z))``."""
        return self._rapply_core(z)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``left.vapply(right.vapply(xs))``."""
        return self._vapply_core(xs)

    def rvapply(self, zs: Any) -> Any:
        """Return ``right.rvapply(left.rvapply(zs))``."""
        # NOTE(review): no @checked_method here, unlike the other three
        # actions; validation relies on the operands' own rvapply — confirm intended.
        return self.right.rvapply(self.left.rvapply(zs))

    def fuse(self, *, materialize: bool = False) -> LinOp:
        r"""Fuse a composition of dense operators into one ``DenseLinOp`` (ADR-021).

        Fuse each operand first. When both fused operands are dense, replace
        ``A @ B`` with a single :class:`DenseLinOp` holding the matrix product
        :math:`M_A M_B`. Any operand that does not fuse to a dense operator
        (matrix-free leaves, sparse, structured) keeps the composition lazy,
        via ``make_composed``, so a matrix-free operand is never densified.

        The fused matrix is adjoint-consistent on any geometry: the shared
        middle-space Riesz maps cancel, so the fused operator's metric adjoint
        equals ``B* @ A*`` up to floating-point rounding.
        """
        from ._dense import DenseLinOp

        left = self.left.fuse(materialize=materialize)
        right = self.right.fuse(materialize=materialize)
        if isinstance(left, DenseLinOp) and isinstance(right, DenseLinOp):
            ops = self.ops
            matrix = ops.matmul(left.to_matrix(), right.to_matrix())
            # Restore the tensor layout codomain.shape + domain.shape.
            tensor = ops.reshape(
                matrix, tuple(left.codomain.shape) + tuple(right.domain.shape)
            )
            return DenseLinOp(tensor, right.domain, left.codomain, self.ctx)
        return make_composed(left, right)

    def is_hermitian(self) -> bool | None:
        """
        Return whether this composition is structurally Hermitian.

        A Gram product ``R* @ R`` (equivalently ``L @ L*``) is self-adjoint in
        any geometry, since ``<R* R x, y> = <R x, R y> = <x, R* R y>``. This is
        detected structurally when ``self.left == self.right.H``. The adjoint
        view compares its wrapped operand, so this also matches ``L @ L*``.
        Any other composition is not cheaply decidable and returns ``None``;
        non-Hermiticity is never asserted.

        Returns
        -------
        bool | None
            ``True`` for a Gram product, otherwise ``None``.
        """
        if self.left == self.right.H:
            return True
        return None

    def __eq__(self, other: Any) -> bool:
        """Return whether another composition has the same operands, in order."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.left == other.left and self.right == other.right

    def _repr_body(self) -> str:
        return f"{self.left._short_repr()} ∘ {self.right._short_repr()}"

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        children = (self.left, self.right)
        aux = ()
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        left, right = children
        return cls(left, right)

    def _convert(self, new_ctx: Context) -> ComposedLinOp:
        """Convert both operands to ``new_ctx``."""
        return ComposedLinOp(self.left.convert(new_ctx), self.right.convert(new_ctx))
@core_kernels("zero")
@jax_pytree_class
class ZeroLinOp(LinOp[Domain, Codomain]):
    r"""
    Lazy zero map between two spaces.

    ``ZeroLinOp(X, Y)`` represents the linear map ``0 : X -> Y``. The context
    is resolved from the optional ``ctx`` argument and the two spaces, then
    both spaces are converted to that context. Its domain is ``X`` and its
    codomain is ``Y`` in the resolved context.

    The forward action is ``apply(x) = 0_Y`` for ``x in X``. The reverse
    action is ``rapply(y) = 0_X`` for ``y in Y``.

    Parameters
    ----------
    dom : Space
        Domain space.
    cod : Space
        Codomain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from the spaces.
    """

    def __init__(
        self,
        dom: Domain,
        cod: Codomain,
        ctx: Context | str | None = None,
    ) -> None:
        super().__init__(dom, cod, ctx)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return the zero element of the codomain."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Return the zero element of the domain."""
        return self._rapply_core(y)

    @checked_method(in_space="domain", in_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return the batched zero element of the codomain."""
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", in_batched=True)
    def rvapply(self, ys: Any) -> Any:
        """Return the batched zero element of the domain."""
        # The batch dimensions are read off ``ys`` relative to the codomain and
        # reused for the domain-shaped zeros.
        return _batched_zeros(self.domain, _leading_shape(self.codomain, ys))

    def to_dense(self) -> Any:
        """
        Return the dense tensor representation of the zero map.

        The returned array has shape ``self.codomain.shape + self.domain.shape``
        and this operator's ``dtype``.
        """
        return self.ops.zeros(
            tuple(self.codomain.shape) + tuple(self.domain.shape), dtype=self.dtype
        )

    def is_hermitian(self) -> bool:
        """
        Return whether the zero map is Hermitian.

        Returns
        -------
        bool
            ``True`` exactly when domain and codomain are the same space.
        """
        return self.domain == self.codomain

    def __eq__(self, other: Any) -> bool:
        """Return whether another zero map has the same spaces."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.domain == other.domain and self.codomain == other.codomain  # Tier 2

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        # No array leaves: the spaces and context travel as static aux data.
        children = ()
        aux = (self.domain, self.codomain, self.ctx)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        domain, codomain, ctx = aux
        return cls(domain, codomain, ctx)

    def _convert(self, new_ctx: Context) -> ZeroLinOp:
        """Convert domain and codomain spaces to ``new_ctx``."""
        return ZeroLinOp(self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx)
@core_kernels("identity")
@jax_pytree_class
class IdentityLinOp(LinOp[Domain, Domain]):
    r"""
    Lazy identity map on a space.

    ``IdentityLinOp(X)`` represents the identity operator ``I_X : X -> X``.
    The context is resolved from the optional ``ctx`` argument and the space,
    and the resulting operator has domain and codomain equal to ``X`` in that
    context.

    The forward action is ``apply(x) = x`` for ``x in X``. The reverse action
    is ``rapply(x) = x`` for ``x in X``.

    Parameters
    ----------
    space : Space
        Domain and codomain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``space``.
    """

    def __init__(self, space: Domain, ctx: Context | str | None = None) -> None:
        super().__init__(space, space, ctx)

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Return ``x`` after domain validation."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, x: Any) -> Any:
        """Return ``x`` after codomain validation."""
        return self._rapply_core(x)

    @checked_method(in_space="domain", in_batched=True)
    def vapply(self, xs: Any) -> Any:
        """Return ``xs`` after batched domain validation."""
        return xs

    @checked_method(in_space="codomain", in_batched=True)
    def rvapply(self, xs: Any) -> Any:
        """Return ``xs`` after batched codomain validation."""
        return xs

    def to_dense(self) -> Any:
        """
        Return the dense tensor representation of this identity map.

        Builds the ``N x N`` identity, where ``N`` is the product of the
        domain shape (``1`` for a 0-d space). It is then reshaped to
        ``self.codomain.shape + self.domain.shape``.
        """
        size = prod(self.domain.shape)
        eye = self.ops.eye(size, dtype=self.dtype)
        return self.ops.reshape(eye, tuple(self.codomain.shape) + tuple(self.domain.shape))

    def is_hermitian(self) -> bool:
        """
        Return whether this identity operator is Hermitian.

        Returns
        -------
        bool
            Always ``True``.
        """
        return True

    def __eq__(self, other: Any) -> bool:
        """Return whether another identity map has the same space."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        return self.domain == other.domain  # Tier 2 (square: cod == dom)

    def _repr_body(self) -> str:
        from .._repr import describe_space

        return describe_space(self.domain)

    def tree_flatten(self):
        """Flatten this operator for pytree registration."""
        # No array leaves: the space and context travel as static aux data.
        children = ()
        aux = (self.domain, self.ctx)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data."""
        domain, ctx = aux
        return cls(domain, ctx)

    def _convert(self, new_ctx: Context) -> IdentityLinOp:
        """Convert the identity space to ``new_ctx``."""
        return IdentityLinOp(self.domain.convert(new_ctx), new_ctx)
@core_kernels("matrixfree")
@jax_pytree_class
class MatrixFreeLinOp(LinOp[Domain, Codomain]):
"""
Linear operator defined by user-supplied forward and reverse callables.
``MatrixFreeLinOp(apply, rapply, X, Y)`` represents a matrix-free map
``A : X -> Y`` without storing or materializing a matrix. The context is
resolved from the optional ``ctx`` argument and the spaces, then the spaces
are converted to that context.
The forward action is ``apply(x) = apply_fn(x)`` for ``x in X``. The
reverse action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. The supplied
``rapply`` callable must already be the true adjoint with respect to the
declared domain and codomain inner products:
``<apply(x), y>_Y = <x, rapply(y)>_X``. It is not automatically corrected
with Riesz maps. If you only have a Euclidean coordinate adjoint in
non-Euclidean spaces, compute the metric adjoint outside SpaceCore and pass
that callable as ``rapply``. When checks are enabled, inputs and callable
outputs are validated against the corresponding domain and codomain, but
construction does not run adjoint dot-tests.
Parameters
----------
apply : callable
Callable with signature ``apply(x: Any) -> Any`` implementing the
forward map from ``dom`` to ``cod``.
rapply : callable
Callable with signature ``rapply(y: Any) -> Any`` implementing the
true space adjoint map from ``cod`` back to ``dom``. For
non-Euclidean spaces this is generally not the same as the Euclidean
coordinate adjoint.
dom : Space
Domain space containing valid inputs for ``apply`` and outputs from
``rapply``.
cod : Space
Codomain space containing outputs from ``apply`` and valid inputs for
``rapply``.
ctx : Context, str, or None, optional
Optional context specification. An explicit context wins over inferred
contexts from ``dom`` and ``cod``.
vapply : callable or None, optional
Optional callable with signature ``vapply(xs: Any) -> Any`` for batched
forward application. If omitted, backend ``vmap`` fallback is used.
rvapply : callable or None, optional
Optional callable with signature ``rvapply(ys: Any) -> Any`` for
batched adjoint application. If omitted, backend ``vmap`` fallback is
used.
Returns
-------
MatrixFreeLinOp
Operator using the supplied callables for forward, adjoint, and
optionally batched actions.
Notes
-----
See ``docs/dev/adr/009_metric_adjoint.md`` for the full design rationale
for metric adjoints and the distinction between direct matrix-free adjoints
and coordinate-adjoint wrapping.
"""
[docs]
def __init__(
self,
apply: Callable[[Any], Any],
rapply: Callable[[Any], Any],
dom: Domain,
cod: Codomain,
ctx: Context | str | None = None,
vapply: Callable[[Any], Any] | None = None,
rvapply: Callable[[Any], Any] | None = None,
) -> None:
"""
Initialize a matrix-free linear operator.
Parameters
----------
apply:
Callable ``apply(x)`` that accepts an element of ``dom`` and returns
an element of ``cod``.
rapply:
Callable ``rapply(y)`` that accepts an element of ``cod`` and
returns an element of ``dom``.
dom:
Domain space of the operator.
cod:
Codomain space of the operator.
ctx:
Optional context specification for the operator and converted
spaces.
vapply:
Optional callable for batched forward application over ``dom``
batches.
rvapply:
Optional callable for batched adjoint application over ``cod``
batches.
Returns
-------
None
The initializer stores the callables and converted spaces on
``self``.
"""
if not callable(apply):
raise TypeError(f"apply must be callable, got {type(apply).__name__}.")
if not callable(rapply):
raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.")
if vapply is not None and not callable(vapply):
raise TypeError(f"vapply must be callable, got {type(vapply).__name__}.")
if rvapply is not None and not callable(rvapply):
raise TypeError(f"rvapply must be callable, got {type(rvapply).__name__}.")
super().__init__(dom, cod, ctx)
self.apply_fn = apply
self.rapply_fn = rapply
self.vapply_fn = vapply
self.rvapply_fn = rvapply
if self._checks_at_least("strict"):
self._check_adjoint_consistency()
def _check_adjoint_consistency(self) -> None:
"""Probe the declared adjoint identity on deterministic space elements."""
if not all(hasattr(space, "inner") for space in (self.domain, self.codomain)):
return
def probe(space: Any) -> Any:
if hasattr(space, "ones"):
return space.ones()
if hasattr(space, "unflatten") and hasattr(space, "shape"):
flat = self.ops.ones((prod(space.shape),), dtype=self.dtype)
return space.unflatten(flat)
raise TypeError(f"{type(space).__name__} cannot build a strict probe element.")
try:
x = probe(self.domain)
y = probe(self.codomain)
ax = self.apply_fn(x)
ahy = self.rapply_fn(y)
self.codomain._check_member(ax)
self.domain._check_member(ahy)
lhs = cast(Any, self.codomain).inner(ax, y)
rhs = cast(Any, self.domain).inner(x, ahy)
consistent = bool(self.ops.allclose(lhs, rhs))
except Exception as exc:
raise ValueError(
"Strict matrix-free adjoint consistency check could not be completed."
) from exc
if not consistent:
raise ValueError(
"Strict matrix-free adjoint consistency check failed: "
"<A x, y> != <x, A* y> for the deterministic probe."
)
[docs]
@classmethod
def from_coordinate_adjoint(
    cls,
    apply: Callable[[Any], Any],
    coordinate_rapply: Callable[[Any], Any],
    dom: Domain,
    cod: Codomain,
    ctx: Context | str | None = None,
    vapply: Callable[[Any], Any] | None = None,
    coordinate_rvapply: Callable[[Any], Any] | None = None,
) -> MatrixFreeLinOp:
    r"""
    Build a matrix-free operator whose adjoint is derived from a Euclidean one.

    ``coordinate_rapply`` is read as the Euclidean (coordinate) adjoint
    ``A^dagger`` of ``apply``. The callable stored as ``rapply`` is the metric
    adjoint with respect to the declared domain ``X`` and codomain ``Y``:

    ``A^sharp(y) = R_X^-1 A^dagger R_Y y``

    Euclidean spaces have identity Riesz maps, so there the stored adjoint
    reduces to ``coordinate_rapply``. For any other geometry both spaces must
    offer usable Riesz maps when this factory runs; otherwise the operator is
    refused instead of being built with an incoherent adjoint.

    A supplied ``coordinate_rvapply`` is treated as the batched Euclidean
    adjoint and wrapped with batched Riesz maps; when batched Riesz
    application is unavailable, the metric-adjoint helper falls back to the
    vectorized scalar ``rapply``. If ``coordinate_rvapply`` is omitted,
    ``rvapply_fn`` stays ``None`` and the generic ``rvapply`` fallback
    vectorizes the wrapped scalar adjoint.

    The design rationale lives in ``docs/dev/adr/009_metric_adjoint.md``.

    Parameters
    ----------
    apply : callable
        Forward coordinate action from ``dom`` to ``cod``.
    coordinate_rapply : callable
        Euclidean coordinate adjoint, mapping ``cod`` coordinates to ``dom``
        coordinates.
    dom, cod : Space
        Domain and codomain spaces; both are converted to the resolved
        context before use.
    ctx : Context, str, or None, optional
        Optional context specification, resolved together with ``dom`` and
        ``cod``.
    vapply : callable or None, optional
        Optional batched forward application, passed through unchanged.
    coordinate_rvapply : callable or None, optional
        Optional batched Euclidean coordinate adjoint. Without it, batched
        adjoints go through the vectorized wrapped scalar ``rapply``.

    Returns
    -------
    MatrixFreeLinOp
        A ``cls`` instance on the resolved context whose ``rapply`` (and,
        when ``coordinate_rvapply`` is given, ``rvapply``) applies the metric
        adjoint.

    Raises
    ------
    TypeError
        If ``apply`` or ``coordinate_rapply`` is not callable, or if
        ``coordinate_rvapply`` is given but not callable.
    ValueError
        If the spaces are neither Euclidean nor equipped with Riesz maps
        (re-raised from the ``TypeError`` of the metric requirement check).
    """
    # Validate the user callables up front, in a fixed order.
    if not callable(apply):
        raise TypeError(f"apply must be callable, got {type(apply).__name__}.")
    if not callable(coordinate_rapply):
        raise TypeError(
            f"coordinate_rapply must be callable, got {type(coordinate_rapply).__name__}."
        )
    if coordinate_rvapply is not None and not callable(coordinate_rvapply):
        raise TypeError(
            f"coordinate_rvapply must be callable, got {type(coordinate_rvapply).__name__}."
        )

    resolved_ctx = resolve_context_priority(ctx, dom, cod)
    dom = dom.convert(resolved_ctx)
    cod = cod.convert(resolved_ctx)

    opname = "MatrixFreeLinOp.from_coordinate_adjoint"
    try:
        _requires_euclidean_or_riesz(dom, cod, opname)
    except TypeError as exc:
        # Incompatible geometry is surfaced as a value problem to callers.
        raise ValueError(str(exc)) from exc

    def wrapped_rapply(y: Any) -> Any:
        return metric_rapply(dom, cod, coordinate_rapply, y)

    def _wrapped_rvapply(ys: Any) -> Any:
        return metric_rvapply(
            dom,
            cod,
            coordinate_rapply,
            coordinate_rvapply,
            ys,
            opname=opname,
            ops=resolved_ctx.ops,
        )

    # Only install a batched adjoint when the caller supplied a batched
    # coordinate adjoint; otherwise leave the generic fallback in charge.
    batched_adjoint = None if coordinate_rvapply is None else _wrapped_rvapply
    return cls(apply, wrapped_rapply, dom, cod, resolved_ctx, vapply, batched_adjoint)
[docs]
@checked_method(in_space="domain", out_space="codomain")
def apply(self, x: Any) -> Any:
    """
    Evaluate the forward map ``A x``.

    Parameters
    ----------
    x:
        Element of ``self.domain``; handed to the forward callable
        ``apply_fn`` through the ``_apply_core`` kernel.

    Returns
    -------
    Any
        Element of ``self.codomain`` produced by ``apply_fn``.
    """
    forward = self._apply_core
    return forward(x)
[docs]
@checked_method(in_space="codomain", out_space="domain")
def rapply(self, y: Any) -> Any:
    """
    Evaluate the adjoint map ``A* y``.

    Parameters
    ----------
    y:
        Element of ``self.codomain``; handed to the adjoint callable
        ``rapply_fn`` through the ``_rapply_core`` kernel.

    Returns
    -------
    Any
        Element of ``self.domain`` produced by ``rapply_fn``.
    """
    adjoint = self._rapply_core
    return adjoint(y)
[docs]
@checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
def vapply(self, xs: Any) -> Any:
    """
    Apply the forward map to a batch of domain elements.

    Parameters
    ----------
    xs:
        Batched element of ``self.domain``.

    Returns
    -------
    Any
        Batched element of ``self.codomain``: the result of the user's
        ``vapply_fn`` when one was supplied, otherwise of the inherited
        ``LinOp.vapply`` fallback.
    """
    batched_forward = self.vapply_fn
    if batched_forward is not None:
        return batched_forward(xs)
    # No dedicated batched callable: let the base class batch it.
    return super().vapply(xs)
[docs]
@checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
def rvapply(self, ys: Any) -> Any:
    """
    Apply the adjoint map to a batch of codomain elements.

    Parameters
    ----------
    ys:
        Batched element of ``self.codomain``.

    Returns
    -------
    Any
        Batched element of ``self.domain``: the result of the user's
        ``rvapply_fn`` when one was supplied, otherwise of the inherited
        ``LinOp.rvapply`` fallback.
    """
    batched_adjoint = self.rvapply_fn
    if batched_adjoint is not None:
        return batched_adjoint(ys)
    # No dedicated batched adjoint: let the base class batch it.
    return super().rvapply(ys)
[docs]
def fuse(self, *, materialize: bool = False) -> LinOp:
    """Keep the operator matrix-free unless densification is requested (ADR-021/ADR-008).

    With the default ``materialize=False`` the operator is returned as-is;
    a matrix-free operator is never densified implicitly. Passing
    ``materialize=True`` opts into densification: ``to_dense`` probes the
    operator into a dense tensor (a basis sweep), which is wrapped in a
    :class:`DenseLinOp` over the same spaces and context so an enclosing
    expression can collapse into a single dense operator.
    """
    if materialize:
        from ._dense import DenseLinOp

        dense = self.to_dense()
        return DenseLinOp(dense, self.domain, self.codomain, self.ctx)
    return self
def __eq__(self, other: Any) -> bool:
    """Structural equality: identical spaces and the very same callables.

    Returns ``NotImplemented`` when ``_eq_backend_compatible`` rejects
    ``other``, so Python can fall back to the reflected comparison.
    """
    # Tier 1: backend compatibility.
    if not self._eq_backend_compatible(other):
        return NotImplemented
    # Tier 2a: both spaces must agree.
    if self.domain != other.domain:
        return False
    if self.codomain != other.codomain:
        return False
    # Tier 2b: extensional equality of callables is undecidable, so object
    # identity is the only sound comparison.
    slots = ("apply_fn", "vapply_fn", "rapply_fn", "rvapply_fn")
    return all(getattr(self, slot) is getattr(other, slot) for slot in slots)
def tree_flatten(self):
    """Split the operator into (array leaves, static aux data) for pytrees.

    There are no array leaves: the four user callables, both spaces and the
    context are all carried as static aux data. The aux order must match the
    unpacking in ``tree_unflatten``.
    """
    field_names = (
        "apply_fn",
        "rapply_fn",
        "domain",
        "codomain",
        "ctx",
        "vapply_fn",
        "rvapply_fn",
    )
    static = tuple(getattr(self, name) for name in field_names)
    return (), static
@classmethod
def tree_unflatten(cls, aux, children):
    """Rebuild an operator from the static aux data of ``tree_flatten``.

    ``children`` is unused because ``tree_flatten`` emits no leaves. The
    rebuild goes through ``cls(...)``, so ``__init__`` runs again, including
    its callable validation and, at the "strict" check level, the adjoint
    consistency probe.
    """
    # NOTE(review): re-running the strict probe on every unflatten may be
    # costly, or may not be trace-safe if unflatten happens under a JAX
    # transformation -- confirm this is intended.
    (apply_cb, rapply_cb, dom, cod, context, vapply_cb, rvapply_cb) = aux
    return cls(apply_cb, rapply_cb, dom, cod, context, vapply_cb, rvapply_cb)
def _convert(self, new_ctx: Context) -> MatrixFreeLinOp:
    """
    Re-home this matrix-free operator on ``new_ctx``.

    Parameters
    ----------
    new_ctx:
        Concrete target context; both spaces are converted to it.

    Returns
    -------
    MatrixFreeLinOp
        A new ``MatrixFreeLinOp`` over the converted spaces that reuses the
        very same user-supplied callables.
    """
    # NOTE(review): callables built by ``from_coordinate_adjoint`` close over
    # the spaces they were created with, so after conversion they still use
    # those (pre-conversion) spaces' Riesz maps -- confirm this is intended.
    converted_dom = self.domain.convert(new_ctx)
    converted_cod = self.codomain.convert(new_ctx)
    return MatrixFreeLinOp(
        self.apply_fn,
        self.rapply_fn,
        converted_dom,
        converted_cod,
        new_ctx,
        self.vapply_fn,
        self.rvapply_fn,
    )
@core_kernels("adjoint")
@jax_pytree_class
class _AdjointViewLinOp(LinOp[Codomain, Domain]):
    """
    Hermitian-adjoint view ``A*`` of a linear operator ``A``.

    The view shares ``A.ctx`` exactly and swaps the spaces: its domain is
    ``A.codomain`` and its codomain is ``A.domain``. Taking ``.H`` of the
    view returns ``A`` itself instead of nesting another wrapper.

    Action mapping:

    * ``apply(y)`` is ``A.rapply(y)`` for ``y`` in ``A.codomain``;
    * ``rapply(x)`` is ``A.apply(x)`` for ``x`` in ``A.domain``;
    * ``vapply`` / ``rvapply`` forward to ``A.rvapply`` / ``A.vapply``.
    """

    def __init__(self, op: LinOp[Domain, Codomain]) -> None:
        checked = _require_linop(op, "op")
        # The adjoint maps the wrapped codomain back onto the wrapped domain.
        super().__init__(checked.codomain, checked.domain, checked.ctx)
        self.op = checked

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, y: Any) -> Any:
        """Evaluate ``op.rapply(y)`` through the adjoint core kernel."""
        kernel = self._apply_core
        return kernel(y)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, x: Any) -> Any:
        """Evaluate ``op.apply(x)`` through the adjoint core kernel."""
        kernel = self._rapply_core
        return kernel(x)

    def vapply(self, ys: Any) -> Any:
        """Batched forward action of the view, i.e. ``op.rvapply(ys)``."""
        wrapped = self.op
        return wrapped.rvapply(ys)

    def rvapply(self, xs: Any) -> Any:
        """Batched adjoint action of the view, i.e. ``op.vapply(xs)``."""
        wrapped = self.op
        return wrapped.vapply(xs)

    def fuse(self, *, materialize: bool = False) -> LinOp:
        """Fuse the wrapped operand, then take the adjoint of the result (ADR-021).

        ``A.H.fuse()`` equals ``A.fuse().H``: the inner expression is fused
        first (a composition may collapse to one dense operator) and the
        adjoint view wraps whatever comes back, which keeps the metric
        adjoint correct on any geometry. A matrix-free operand therefore
        stays matrix-free under its adjoint unless ``materialize=True``
        densifies it.
        """
        fused_inner = self.op.fuse(materialize=materialize)
        return fused_inner.H

    @property
    def H(self) -> LinOp[Domain, Codomain]:
        """The wrapped operator: the adjoint of an adjoint view is the original."""
        return self.op

    def __eq__(self, other: Any) -> bool:
        # Tier 1: backend compatibility; otherwise let Python try the
        # reflected comparison.
        if self._eq_backend_compatible(other):
            # Tier 2: two adjoint views are equal exactly when their operands are.
            return self.op == other.op
        return NotImplemented

    def _repr_class_name(self) -> str:
        """Public-facing label shown instead of the private class name."""
        return "AdjointLinOp"

    def _repr_body(self) -> str:
        inner = self.op._short_repr()
        return f"{inner}.H"

    def tree_flatten(self):
        # The wrapped operator is the single child; there is no static aux data.
        return (self.op,), ()

    @classmethod
    def tree_unflatten(cls, aux, children):
        wrapped = children[0]
        return cls(wrapped)

    def _convert(self, new_ctx: Context) -> _AdjointViewLinOp:
        converted = self.op.convert(new_ctx)
        return _AdjointViewLinOp(converted)