Source code for spacecore.linop.product._block

from __future__ import annotations

from typing import Any, Tuple

from ._base import ProductLinOp
from .._base import LinOp
from ... import Context
from ...space import ProductSpace
from ...backend import jax_pytree_class


[docs] @jax_pytree_class class BlockDiagonalLinOp(ProductLinOp[ProductSpace, ProductSpace]): """ Block-diagonal operator between product spaces. dom = X1 × ... × Xk cod = Y1 × ... × Yk ops[i] : Xi -> Yi """ def _check_layout(self) -> None: if not isinstance(self.dom, ProductSpace) or not isinstance(self.cod, ProductSpace): raise TypeError("BlockDiagonalLinOp expects dom and cod to be ProductSpace.") if len(self.parts) != len(self.dom.spaces) or len(self.parts) != len(self.cod.spaces): raise ValueError("Number of component ops must match product arity.") for i, A in enumerate(self.parts): if A.dom == self.dom.spaces[i] and A.cod == self.cod.spaces[i]: continue else: raise TypeError(f"Component op {i} has incompatible dom/cod spaces.")
[docs] def apply(self, x: Any) -> Any: if self._enable_checks: self.dom._check_member(x) return self._apply_unchecked(x)
def _apply_unchecked(self, x: Any) -> Any: if self._num_parts == 2: return self._apply_parts[0](x[0]), self._apply_parts[1](x[1]) return tuple(apply(xi) for apply, xi in zip(self._apply_parts, x))
[docs] def rapply(self, y: Any) -> Any: if self._enable_checks: self.cod._check_member(y) return self._rapply_unchecked(y)
def _rapply_unchecked(self, y: Any) -> Any: if self._num_parts == 2: return self._rapply_parts[0](y[0]), self._rapply_parts[1](y[1]) return tuple(rapply(yi) for rapply, yi in zip(self._rapply_parts, y))
[docs] @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: if not parts: raise ValueError("Parts must be non-empty.") dom = ProductSpace(tuple(op.dom for op in parts)) cod = ProductSpace(tuple(op.cod for op in parts)) return cls(dom, cod, parts)
def _convert(self, new_ctx: Context) -> BlockDiagonalLinOp: new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] return BlockDiagonalLinOp(new_dom, new_cod, new_parts)