Source code for spacecore.linop._dense

from __future__ import annotations

from math import prod
from typing import Any

from ._base import LinOp, Domain, Codomain
from ..space import VectorSpace
from ..types import DenseArray
from ..backend import jax_pytree_class, Context
from .._contextual.manager import ctx_manager


@jax_pytree_class
class DenseLinOp(LinOp[VectorSpace, VectorSpace]):
    """
    Dense linear operator defined by an array A with shape:

        A.shape == cod.shape + dom.shape

    apply:  y = A ⋅ x    (contract over dom axes)
    rapply: x = A^* ⋅ y  (contract over cod axes)

    A 2-D matrix view of A with shape (prod(cod.shape), prod(dom.shape)) is
    built once at construction. Its (conjugate) transpose is also built once.
    Both actions then reduce to a single matrix-vector product.
    """

    def __init__(self,
                 A: DenseArray,
                 dom: Domain,
                 cod: Codomain | None = None,
                 ctx: Context | str | None = None
                 ) -> None:
        """
        Build the operator from a dense array.

        Parameters
        ----------
        A : DenseArray
            Dense array of the resolved context with
            ``A.shape == cod.shape + dom.shape``. It is stored as-is, with no
            dtype conversion.
        dom : Domain
            Domain space; ``x`` in ``apply`` lives here.
        cod : Codomain or None
            Codomain space. If None, a ``VectorSpace`` is created. Its shape is
            the leading ``len(A.shape) - len(dom.shape)`` axes of A.
        ctx : Context, str or None
            Context, resolved together with ``dom``/``cod`` by
            ``ctx_manager.resolve_context_priority``.

        Raises
        ------
        TypeError
            If ``cod`` must be inferred but A has fewer axes than ``dom``.
            Also if ``A.shape != cod.shape + dom.shape``.
        """
        ctx = ctx_manager.resolve_context_priority(ctx, dom, cod)
        ctx.assert_dense(A)  # Check if A is ndarray of ctx

        if cod is None:
            cod_shape_len = len(A.shape) - len(dom.shape)
            if cod_shape_len < 0:
                # A negative length would slice from the end of A.shape and
                # silently build a meaningless codomain; fail clearly instead.
                raise TypeError(
                    f"Cannot infer cod: A.shape {tuple(A.shape)} has fewer axes "
                    f"than dom.shape {tuple(dom.shape)}"
                )
            cod = VectorSpace(A.shape[:cod_shape_len], ctx)

        super(DenseLinOp, self).__init__(dom, cod, ctx)

        expected = tuple(self.cod.shape) + tuple(self.dom.shape)
        if tuple(A.shape) != expected:
            raise TypeError(f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}")

        self.A = A  # No dtype conversion

        # Flattened sizes and the 2-D matrix view used by apply/rapply.
        self._cod_size = prod(self.cod.shape)
        self._dom_size = prod(self.dom.shape)
        self._matrix_shape = (self._cod_size, self._dom_size)
        self._A2 = self.A.reshape(self._matrix_shape)

        dtype = self.ops.get_dtype(self.A)
        # Numpy-style dtypes expose `.kind` ('c' == complex). Torch dtypes do
        # not, so fall back to their string form ("torch.complex64", ...).
        is_complex = (getattr(dtype, "kind", None) == "c"
                      or str(dtype).startswith("torch.complex"))
        # Adjoint matrix, precomputed once: conjugate transpose for complex A,
        # plain transpose otherwise.
        self._A2H = self._A2.T.conj() if is_complex else self._A2.T

        # When the space is already 1-D, the reshape can be skipped.
        self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,)
        # An exact VectorSpace can be unflattened with a plain reshape.
        # Subclasses go through their own `unflatten`.
        self._dom_vector_fast_path = type(self.dom) is VectorSpace
        self._cod_vector_fast_path = type(self.cod) is VectorSpace

        # `_enable_checks` comes from the LinOp base (not visible here).
        # When checks are off, bind the unchecked implementations directly on
        # the instance to skip the per-call branch.
        if not self._enable_checks:
            self.apply = self._apply_unchecked
            self.rapply = self._rapply_unchecked

    def apply(self, x: DenseArray) -> DenseArray:
        """
        Forward action: y = A ⋅ x with y in cod.shape.
        """
        if self._enable_checks:
            self.dom._check_member(x)
        return self._apply_unchecked(x)

    def _apply_unchecked(self, x: DenseArray) -> DenseArray:
        """Compute A ⋅ x without membership checks (x assumed in dom)."""
        x1 = x if self._dom_is_flat else x.reshape((self._dom_size,))
        y1 = self._A2 @ x1
        if self._cod_vector_fast_path:
            return y1 if self._cod_is_flat else y1.reshape(self.cod.shape)
        return self.cod.unflatten(y1)

    def rapply(self, y: DenseArray) -> DenseArray:
        """
        Adjoint action: x = A^* ⋅ y with x in dom.shape.

        For complex A, uses conjugate-transpose of the 2D reshaped matrix.
        """
        if self._enable_checks:
            self.cod._check_member(y)
        return self._rapply_unchecked(y)

    def _rapply_unchecked(self, y: DenseArray) -> DenseArray:
        """Compute A^* ⋅ y without membership checks (y assumed in cod)."""
        y1 = y if self._cod_is_flat else y.reshape((self._cod_size,))
        x1 = self._A2H @ y1
        if self._dom_vector_fast_path:
            return x1 if self._dom_is_flat else x1.reshape(self.dom.shape)
        return self.dom.unflatten(x1)

    def __eq__(self, x: Any) -> bool:
        """
        Return True when x is the same type, dom and cod are equal, and A is
        numerically close according to ``self.ops.allclose``.
        """
        if type(x) is not type(self):
            return False
        # Some backends (e.g. jax.numpy.allclose) return a 0-d array instead of
        # a Python bool; coerce so the result honours the `-> bool` contract.
        return bool(self.dom == x.dom
                    and self.cod == x.cod
                    and self.ops.allclose(self.A, x.A))

    def tree_flatten(self):
        """Pytree flatten: A is the only leaf; spaces and ctx are static aux data."""
        aux = (self.dom, self.cod, self.ctx)
        children = (self.A,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Pytree unflatten: rebuild the operator from (dom, cod, ctx) and A."""
        dom, cod, ctx = aux
        A = children[0]
        # NOTE(review): this re-runs __init__ (assert_dense, shape check,
        # reshape) on the leaf. JAX may unflatten with non-array placeholders
        # in some transformations — confirm this is tolerated in practice.
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> DenseLinOp:
        """Return a copy with dom, cod and A converted to `new_ctx`."""
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.asarray(self.A)
        return DenseLinOp(new_A, new_dom, new_cod, new_ctx)