# Source code for spacecore.linop.product._base

from __future__ import annotations

from abc import abstractmethod
from typing import Tuple, Sequence, Any

from .._base import LinOp, Domain, Codomain
from ...backend import jax_pytree_class, Context


@jax_pytree_class
class ProductLinOp(LinOp[Domain, Codomain]):
    """
    Base class for linear operators assembled from component operators.

    Concrete subclasses must implement ``_check_layout`` (validation of the
    components against ``dom``/``cod``) and ``from_operators`` (alternate
    constructor from the components alone).

    The class exposes ``tree_flatten``/``tree_unflatten`` with the component
    operators as children and ``(dom, cod, ctx)`` as auxiliary data.
    NOTE(review): presumably ``jax_pytree_class`` registers these two methods
    with JAX -- confirm in ``backend``.
    """
    # Component operators, stored after conversion to this operator's context.
    parts: Tuple[LinOp, ...]

    def __init__(self,
                 dom: Domain,
                 cod: Codomain,
                 parts: Sequence[LinOp],
                 ctx: Context | str | None = None
                 ) -> None:
        """
        Build the operator from its components.

        Parameters
        ----------
        dom, cod:
            Domain and codomain, forwarded to ``LinOp.__init__``.
        parts:
            Non-empty collection of component operators.
        ctx:
            Context (or context name), forwarded to ``LinOp.__init__``.

        Raises
        ------
        ValueError
            If ``parts`` is empty (or otherwise falsy, e.g. ``None``).
        """
        # Materialize once before validating: a one-shot iterable (e.g. an
        # empty generator) is always truthy and would otherwise slip past the
        # non-empty check. Falsy inputs keep raising the same ValueError.
        components = tuple(parts) if parts else ()
        if not components:
            raise ValueError("Parts must be non-empty.")

        super().__init__(dom, cod, ctx)

        # self.ctx is only read after the base initializer has run.
        self.parts = tuple(op.convert(self.ctx) for op in components)
        self._num_parts = len(self.parts)

        # Cache one callable per component, preferring its unchecked variant
        # when the component provides one.
        self._apply_parts = tuple(
            getattr(op, "_apply_unchecked", op.apply) for op in self.parts
        )
        self._rapply_parts = tuple(
            getattr(op, "_rapply_unchecked", op.rapply) for op in self.parts
        )

        self._check_layout()

        # When checks are disabled and this operator defines both unchecked
        # entry points, shadow the public methods on the instance with them.
        # NOTE(review): _enable_checks is not set here; presumably it comes
        # from LinOp -- confirm.
        unchecked_apply = getattr(self, "_apply_unchecked", None)
        unchecked_rapply = getattr(self, "_rapply_unchecked", None)
        if (
            not self._enable_checks
            and unchecked_apply is not None
            and unchecked_rapply is not None
        ):
            self.apply = unchecked_apply
            self.rapply = unchecked_rapply

    @abstractmethod
    def _check_layout(self) -> None:
        """
        Check incidence compatibility between self.parts and self.dom/self.cod.

        Called by ``__init__`` once ``self.parts`` has been populated.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_operators(cls, parts: Tuple[LinOp, ...]) -> ProductLinOp:
        """
        Alternate constructor taking only the component operators.
        """
        ...

    def __eq__(self, x: Any) -> bool:
        """
        Return True when ``x`` has exactly the same type, equal ``dom`` and
        ``cod``, and pairwise-equal ``parts``.

        Because ``__eq__`` is defined here without ``__hash__``, Python sets
        ``__hash__`` to ``None`` on this class, so instances are unhashable
        unless ``__hash__`` is defined explicitly elsewhere.
        """
        if type(x) is not type(self):
            return False
        return (
            self.dom == x.dom
            and self.cod == x.cod
            and len(self.parts) == len(x.parts)
            # Generator instead of a list: stops at the first mismatch.
            and all(op1 == op2 for op1, op2 in zip(self.parts, x.parts))
        )

    def tree_flatten(self):
        """
        Return ``(children, aux)``: the component operators as children and
        ``(dom, cod, ctx)`` as auxiliary data.
        """
        children = self.parts
        aux = (self.dom, self.cod, self.ctx)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild an instance from the output of ``tree_flatten``.

        This calls ``cls(dom, cod, parts, ctx)``, so the full ``__init__``
        (conversion and layout check included) runs again.
        """
        dom, cod, ctx = aux
        return cls(dom, cod, tuple(children), ctx)