from __future__ import annotations
from typing import Any, Sequence, Tuple, cast
from ._base import TreeLinOp
from .._base import LinOp, Domain
from ..._checks import checked_method
from ...kernels import CachedStackParts, dispatch, should_consult_dispatch
from ...space import DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace, TreeSpace
from ...backend import jax_pytree_class, Context
# ADR-016 dispatch call site: a StackedLinOp applies one shared input through
# every component. The per-component loop is the ``generic`` fallback; the
# ``stacked-uniform-dense-batched-apply`` spec routes here when dispatch is on
# and the components are uniform flat-dense. ``parts[i]._apply_core`` is exactly
# the bound core in ``self._apply_parts[i]``, so the generic stays byte-identical.
_STACKED_APPLY_KEY = "linop.stacked.apply"
def _stacked_apply(parts: Any, x: Any) -> tuple[Any, ...]:
"""Apply the shared input through each component core (generic stacked apply)."""
return tuple(p._apply_core(x) for p in parts)
[docs]
@jax_pytree_class
class StackedLinOp(TreeLinOp[Domain, TreeSpace]):
r"""
Represent operators from one domain as a tree-valued map.
If ``dom = X`` and ``cod = Y1 x ... x Yk``, component ``parts[i]`` maps
``X`` to ``Yi``. Forward application returns a value with ``cod.treedef``;
adjoint application sums component adjoints in ``X``.
Parameters
----------
dom : Space
Shared component domain.
cod : TreeSpace
Tree-structured codomain.
parts : sequence of LinOp
Operators from ``dom`` to each component of ``cod``.
ctx : Context, str, or None, optional
Backend context specification.
"""
def __init__(
self,
dom: Domain,
cod: TreeSpace,
parts: Sequence[LinOp],
ctx: Context | str | None = None,
) -> None:
super().__init__(dom, cod, parts, ctx)
# ADR-022: memoize the stacked component matrices for the stacked.apply
# broadcast fold (built once on first optimized use, NumPy-only).
self.parts = CachedStackParts(self.parts)
self._flat_dense_rapply_mats = self._make_flat_dense_rapply_mats()
def _make_flat_dense_rapply_mats(self):
"""Return dense adjoint matrices for the exact flat-vector fast path."""
if (
type(self.dom) not in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace)
or not self.dom.is_euclidean
):
return None
if tuple(self.dom.shape) != (self.dom._size,):
return None
mats = []
for op in self.parts:
if (
type(op.cod) not in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace)
or not op.cod.is_euclidean
or tuple(op.cod.shape) != (op.cod._size,)
or not hasattr(op, "_A2H")
):
return None
mats.append(cast(Any, op)._A2H)
return tuple(mats)
def _check_layout(self) -> None:
"""Check that every component maps the shared domain to one codomain part."""
if not isinstance(self.cod, TreeSpace):
raise TypeError("StackedLinOp expects cod to be TreeSpace.")
if len(self.parts) != self.cod.arity:
raise ValueError("Number of ops must match codomain tree arity.")
for i, A in enumerate(self.parts):
if A.dom == self.dom and A.cod == self.cod.leaf_spaces[i]:
continue
else:
raise TypeError(f"Component op {i} must map dom -> cod.leaf_spaces[{i}].")
[docs]
@checked_method(in_space="domain", out_space="codomain")
def apply(self, x: Any) -> Any:
"""Apply each component operator and return a codomain product element."""
return self._apply_unchecked(x)
def _apply_unchecked(self, x: Any) -> Any:
"""Apply component operators without checks and rebuild codomain representation."""
if should_consult_dispatch(self.ctx):
y_parts = dispatch(
_STACKED_APPLY_KEY,
self.parts,
x,
generic=_stacked_apply,
ctx=self.ctx,
)
elif self._num_parts == 2:
y_parts = (self._apply_parts[0](x), self._apply_parts[1](x))
else:
y_parts = tuple(apply(x) for apply in self._apply_parts)
return self.cod._from_components(y_parts)
[docs]
@checked_method(in_space="codomain", out_space="domain")
def rapply(self, y: Any) -> Any:
"""Apply component adjoints from a codomain product element and sum them."""
return self._rapply_unchecked(y)
def _rapply_unchecked(self, y: Any) -> Any:
"""Apply component adjoints without membership checks."""
y_parts = self.cod._components(y)
mats = self._flat_dense_rapply_mats
if mats is not None:
if self._num_parts == 2:
return mats[0] @ y_parts[0] + mats[1] @ y_parts[1]
acc = mats[0] @ y_parts[0]
for mat, yi in zip(mats[1:], y_parts[1:]):
acc = acc + mat @ yi
return acc
if self._num_parts == 2:
x0 = self._rapply_parts[0](y_parts[0])
x1 = self._rapply_parts[1](y_parts[1])
if type(self.dom) in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace):
return x0 + x1
return self.dom.add(x0, x1)
acc = None
for rapply, yi in zip(self._rapply_parts, y_parts):
xi = rapply(yi)
if acc is None:
acc = xi
elif type(self.dom) in (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace):
acc = acc + xi
else:
acc = self.dom.add(acc, xi)
return acc
[docs]
@checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
def vapply(self, x: Any) -> Any:
"""Apply this stacked operator over a batch and preserve codomain structure."""
return self._vapply_unchecked(x)
def _vapply_unchecked(self, x: Any) -> Any:
"""Apply over a batch without checks and rebuild codomain representation."""
y_parts = tuple(op.vapply(x) for op in self.parts)
return self.cod._from_components(y_parts)
[docs]
@checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
def rvapply(self, y: Any) -> Any:
"""Apply the adjoint stacked operator over a structured product batch."""
return self._rvapply_unchecked(y)
def _rvapply_unchecked(self, y: Any) -> Any:
"""Apply the adjoint over a product batch without membership checks."""
y_parts = self.cod._components(y)
mats = self._flat_dense_rapply_mats
if mats is not None:
if self._num_parts == 2:
acc = y_parts[0] @ mats[0].T + y_parts[1] @ mats[1].T
else:
acc = y_parts[0] @ mats[0].T
for mat, yi in zip(mats[1:], y_parts[1:]):
acc = acc + yi @ mat.T
return acc
acc = None
for op, yi in zip(self.parts, y_parts):
xi = op.rvapply(yi)
if acc is None:
acc = xi
elif type(self.domain) in (
DenseCoordinateSpace,
DenseVectorSpace,
ElementwiseJordanSpace,
):
acc = acc + xi
else:
acc = self.domain.add_batch(acc, xi)
return acc
[docs]
def fuse(self, *, materialize: bool = False) -> StackedLinOp:
"""Fuse each component operator (ADR-021), preserving dom/cod and context."""
return StackedLinOp(
self.dom,
self.cod,
tuple(op.fuse(materialize=materialize) for op in self.parts),
self.ctx,
)
[docs]
@classmethod
def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp:
"""Build a stacked operator from component operators."""
if not parts:
raise ValueError("Parts must be non-empty.")
cod = TreeSpace(tuple(range(len(parts))), tuple(op.cod for op in parts))
dom = parts[0].dom
return cls(dom, cod, parts)
def _convert(self, new_ctx: Context) -> StackedLinOp:
"""Convert spaces and component operators to ``new_ctx``."""
new_dom = self.dom.convert(new_ctx)
new_cod = self.cod.convert(new_ctx)
new_parts = [op.convert(new_ctx) for op in self.parts]
return StackedLinOp(new_dom, new_cod, new_parts, new_ctx)