# Source code for spacecore.linop.tree._to_single

from __future__ import annotations

from typing import Any, Sequence, Tuple, cast

from ._base import TreeLinOp
from .._base import LinOp, Codomain
from ..._checks import checked_method
from ...kernels import CachedStackParts, dispatch, should_consult_dispatch
from ...space import DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace, TreeSpace
from ...backend import jax_pytree_class, Context

# ADR-016 dispatch call site: a SumToSingleLinOp applies one shared input through
# every component adjoint. The per-component loop is the ``generic`` fallback;
# the ``sum-to-single-uniform-dense-batched-rapply`` spec routes here when
# dispatch is on and the components are uniform flat-dense. ``parts[i]._rapply_core``
# is exactly ``self._rapply_parts[i]``, so the generic stays byte-identical.
_SUM_TO_SINGLE_RAPPLY_KEY = "linop.sum_to_single.rapply"


def _sum_to_single_rapply(parts: Any, y: Any) -> tuple[Any, ...]:
    """Generic rapply: pass the shared input ``y`` to every part's adjoint core.

    Each element of ``parts`` is asked for ``_rapply_core(y)`` in iteration
    order, and the results are returned as a tuple in that same order.
    """
    adjoint_outputs = [component._rapply_core(y) for component in parts]
    return tuple(adjoint_outputs)


@jax_pytree_class
class SumToSingleLinOp(TreeLinOp[TreeSpace, Codomain]):
    r"""
    Sum of per-leaf operators mapping a tree domain into one shared codomain.

    For ``dom = X1 x ... x Xk`` and ``cod = Y``, the component ``parts[i]``
    maps ``Xi`` to ``Y``. The forward map adds the component outputs together
    in ``Y``; the adjoint map returns a value laid out with ``dom.treedef``.

    Parameters
    ----------
    dom : TreeSpace
        Tree-structured domain.
    cod : Space
        Shared codomain.
    parts : sequence of LinOp
        Operators from each product component to ``cod``.
    ctx : Context, str, or None, optional
        Backend context specification.
    """

    def __init__(
        self,
        dom: TreeSpace,
        cod: Codomain,
        parts: Sequence[LinOp],
        ctx: Context | str | None = None,
    ) -> None:
        super().__init__(dom, cod, parts, ctx)
        # ADR-022: wrap the parts so the stacked adjoint matrices used by the
        # sum_to_single.rapply broadcast fold are memoized (built once on
        # first optimized use, NumPy-only).
        self.parts = CachedStackParts(self.parts)
        # Must run after the wrapping above: the helper iterates ``self.parts``.
        self._flat_dense_apply_mats = self._make_flat_dense_apply_mats()

    def _make_flat_dense_apply_mats(self):
        """Collect the parts' ``_A2`` matrices for the exact flat-vector fast path.

        Returns a tuple with one ``_A2`` per part when the codomain and every
        part domain are exactly one of the plain dense space types, are
        euclidean, and have shape ``(_size,)``, and every part exposes
        ``_A2``. Returns ``None`` as soon as any of these conditions fails.
        """
        dense_kinds = (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace)

        def is_flat_euclidean(space: Any):
            # Exact type match on purpose: subclasses do not qualify.
            return (
                type(space) in dense_kinds
                and space.is_euclidean
                and tuple(space.shape) == (space._size,)
            )

        if not is_flat_euclidean(self.cod):
            return None

        collected = []
        for part in self.parts:
            if not (is_flat_euclidean(part.dom) and hasattr(part, "_A2")):
                return None
            collected.append(cast(Any, part)._A2)
        return tuple(collected)

    def _check_layout(self) -> None:
        """Validate that part ``i`` maps ``dom.leaf_spaces[i]`` into the shared codomain.

        Raises
        ------
        TypeError
            If ``dom`` is not a ``TreeSpace`` or a part has the wrong
            domain/codomain.
        ValueError
            If the number of parts differs from ``dom.arity``.
        """
        if not isinstance(self.dom, TreeSpace):
            raise TypeError("SumToSingleLinOp expects dom to be TreeSpace.")
        if len(self.parts) != self.dom.arity:
            raise ValueError("Number of ops must match domain tree arity.")

        for idx, component in enumerate(self.parts):
            maps_leaf_to_cod = (
                component.dom == self.dom.leaf_spaces[idx] and component.cod == self.cod
            )
            if not maps_leaf_to_cod:
                raise TypeError(f"Component op {idx} must map dom.leaf_spaces[{idx}] -> cod.")

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: Any) -> Any:
        """Apply each part to its domain component and sum the results."""
        return self._apply_unchecked(x)

    def _apply_unchecked(self, x: Any) -> Any:
        """Forward application without membership checks."""
        pieces = self.dom._components(x)
        dense_mats = self._flat_dense_apply_mats

        if dense_mats is not None:
            # Exact flat-vector fast path: plain matrix products, summed with ``+``.
            if self._num_parts == 2:
                return dense_mats[0] @ pieces[0] + dense_mats[1] @ pieces[1]
            total = dense_mats[0] @ pieces[0]
            for matrix, piece in zip(dense_mats[1:], pieces[1:]):
                total = total + matrix @ piece
            return total

        # Plain dense codomains are summed with ``+``; anything else goes through ``cod.add``.
        plain_sum = type(self.cod) in (
            DenseCoordinateSpace,
            DenseVectorSpace,
            ElementwiseJordanSpace,
        )

        if self._num_parts == 2:
            left = self._apply_parts[0](pieces[0])
            right = self._apply_parts[1](pieces[1])
            if plain_sum:
                return left + right
            return self.cod.add(left, right)

        total = None
        for component_apply, piece in zip(self._apply_parts, pieces):
            out = component_apply(piece)
            if total is None:
                total = out
            elif plain_sum:
                total = total + out
            else:
                total = self.cod.add(total, out)
        return total

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: Any) -> Any:
        """Apply every part's adjoint to ``y`` and return a domain product element."""
        return self._rapply_unchecked(y)

    def _rapply_unchecked(self, y: Any) -> Any:
        """Adjoint application without checks; the result is rebuilt via ``dom._from_components``."""
        if should_consult_dispatch(self.ctx):
            # ADR-016 call site: the dispatcher may pick an optimized kernel,
            # with the per-part loop as the generic fallback.
            dispatched = dispatch(
                _SUM_TO_SINGLE_RAPPLY_KEY,
                self.parts,
                y,
                generic=_sum_to_single_rapply,
                ctx=self.ctx,
            )
            return self.dom._from_components(dispatched)

        if self._num_parts == 2:
            pair = (self._rapply_parts[0](y), self._rapply_parts[1](y))
            return self.dom._from_components(pair)

        adjoint_outputs = tuple(part_rapply(y) for part_rapply in self._rapply_parts)
        return self.dom._from_components(adjoint_outputs)

    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, x: Any) -> Any:
        """Apply this sum-to-single operator to a structured product batch."""
        return self._vapply_unchecked(x)

    def _vapply_unchecked(self, x: Any) -> Any:
        """Batched forward application without membership checks."""
        pieces = self.dom._components(x)
        dense_mats = self._flat_dense_apply_mats

        if dense_mats is not None:
            # Batched form of the flat-dense fast path: right-multiply by each transposed matrix.
            if self._num_parts == 2:
                return pieces[0] @ dense_mats[0].T + pieces[1] @ dense_mats[1].T
            total = pieces[0] @ dense_mats[0].T
            for matrix, piece in zip(dense_mats[1:], pieces[1:]):
                total = total + piece @ matrix.T
            return total

        total = None
        for component, piece in zip(self.parts, pieces):
            batch_out = component.vapply(piece)
            if total is None:
                total = batch_out
                continue
            if type(self.codomain) in (
                DenseCoordinateSpace,
                DenseVectorSpace,
                ElementwiseJordanSpace,
            ):
                total = total + batch_out
            else:
                total = self.codomain.add_batch(total, batch_out)
        return total

    @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
    def rvapply(self, y: Any) -> Any:
        """Apply the adjoint to a codomain batch, keeping the domain structure."""
        return self._rvapply_unchecked(y)

    def _rvapply_unchecked(self, y: Any) -> Any:
        """Batched adjoint without checks; the result is rebuilt via ``dom._from_components``."""
        batched_outputs = tuple(component.rvapply(y) for component in self.parts)
        return self.dom._from_components(batched_outputs)

    def fuse(self, *, materialize: bool = False) -> SumToSingleLinOp:
        """Fuse every part (ADR-021) and rebuild with the same dom, cod and context."""
        fused_parts = tuple(component.fuse(materialize=materialize) for component in self.parts)
        return SumToSingleLinOp(self.dom, self.cod, fused_parts, self.ctx)

    @classmethod
    def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp:
        """Build a sum-to-single operator directly from its component operators.

        The domain is a ``TreeSpace`` built from ``tuple(range(len(parts)))``
        and the parts' domains; the codomain is taken from the first part.

        Raises
        ------
        ValueError
            If ``parts`` is empty.
        """
        if not parts:
            raise ValueError("Parts must be non-empty.")
        leaf_domains = tuple(component.dom for component in parts)
        tree_domain = TreeSpace(tuple(range(len(parts))), leaf_domains)
        shared_cod = parts[0].cod
        return cls(tree_domain, shared_cod, parts)

    def _convert(self, new_ctx: Context) -> SumToSingleLinOp:
        """Convert the domain, codomain and every part to ``new_ctx`` and rebuild."""
        converted_dom = self.dom.convert(new_ctx)
        converted_cod = self.cod.convert(new_ctx)
        converted_parts = [component.convert(new_ctx) for component in self.parts]
        return SumToSingleLinOp(converted_dom, converted_cod, converted_parts, new_ctx)