Source code for spacecore.linop._dense
from __future__ import annotations
from math import prod
from typing import Any
from ._base import LinOp, Domain, Codomain
from ..space import VectorSpace
from ..types import DenseArray
from ..backend import jax_pytree_class, Context
from .._contextual.manager import ctx_manager
[docs]
@jax_pytree_class
class DenseLinOp(LinOp[VectorSpace, VectorSpace]):
    """
    Dense linear operator defined by an array A with shape:
        A.shape == cod.shape + dom.shape
    apply:  y = A ⋅ x   (contract over dom axes)
    rapply: x = A^* ⋅ y (contract over cod axes)

    Internally A is also kept as a 2-D matrix of shape
    (prod(cod.shape), prod(dom.shape)) together with its adjoint matrix
    (conjugate-transpose for complex dtypes, transpose otherwise), so both
    actions reduce to one matrix-vector product plus reshapes.
    """

    def __init__(self,
                 A: DenseArray,
                 dom: Domain,
                 cod: Codomain | None = None,
                 ctx: Context | str | None = None
                 ) -> None:
        """
        Build the operator from the dense array ``A``.

        Parameters
        ----------
        A:
            Dense array with ``A.shape == cod.shape + dom.shape``. It is
            passed to ``ctx.assert_dense`` and stored as-is (no dtype
            conversion).
        dom:
            Domain space; its shape must match the trailing axes of ``A``.
        cod:
            Codomain space. If ``None``, a ``VectorSpace`` is inferred from
            the leading ``len(A.shape) - len(dom.shape)`` axes of ``A``.
        ctx:
            Context, context name, or ``None``; resolved together with
            ``dom``/``cod`` via ``ctx_manager.resolve_context_priority``.

        Raises
        ------
        TypeError
            If ``cod`` must be inferred but ``A`` has fewer axes than
            ``dom.shape``, or if ``A.shape != cod.shape + dom.shape``.
            (``ctx.assert_dense`` may raise its own error first if ``A`` is
            not a dense array of the resolved context.)
        """
        ctx = ctx_manager.resolve_context_priority(ctx, dom, cod)
        ctx.assert_dense(A)  # Check that A is a dense array of ctx
        if cod is None:
            cod_ndim = len(A.shape) - len(dom.shape)
            if cod_ndim < 0:
                # A negative count would slice A.shape from the END and build
                # a meaningless codomain, after which the shape check below
                # fails with a misleading "expected" shape. Fail clearly here.
                raise TypeError(
                    f"Cannot infer cod: A.shape {tuple(A.shape)} has fewer axes "
                    f"than dom.shape {tuple(dom.shape)}"
                )
            cod = VectorSpace(A.shape[:cod_ndim], ctx)
        super(DenseLinOp, self).__init__(dom, cod, ctx)

        expected = tuple(self.cod.shape) + tuple(self.dom.shape)
        if tuple(A.shape) != expected:
            raise TypeError(f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}")
        self.A = A  # No dtype conversion

        # 2-D matrix form of A: rows index cod entries, columns index dom entries.
        self._cod_size = prod(self.cod.shape)
        self._dom_size = prod(self.dom.shape)
        self._matrix_shape = (self._cod_size, self._dom_size)
        self._A2 = self.A.reshape(self._matrix_shape)

        # Adjoint matrix, precomputed once. Complex detection covers dtypes
        # exposing numpy-style ``kind == "c"`` and torch complex dtypes
        # (recognised by their string name).
        dtype = self.ops.get_dtype(self.A)
        is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex")
        self._A2H = self._A2.T.conj() if is_complex else self._A2.T

        # Flags letting apply/rapply skip reshapes when a space is already
        # 1-D, and use plain reshape (instead of ``unflatten``) when the
        # space is exactly ``VectorSpace`` rather than a subclass.
        self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,)
        self._dom_vector_fast_path = type(self.dom) is VectorSpace
        self._cod_vector_fast_path = type(self.cod) is VectorSpace

        # With checks disabled, bind the unchecked kernels on the instance so
        # calls skip the membership-check wrapper methods entirely.
        if not self._enable_checks:
            self.apply = self._apply_unchecked
            self.rapply = self._rapply_unchecked

    def apply(self, x: DenseArray) -> DenseArray:
        """
        Forward action: y = A ⋅ x with y in cod.shape.

        When checks are enabled, ``x`` is first validated with
        ``dom._check_member``.
        """
        if self._enable_checks:
            self.dom._check_member(x)
        return self._apply_unchecked(x)

    def _apply_unchecked(self, x: DenseArray) -> DenseArray:
        """Forward action without the membership check."""
        # NOTE(review): x is always flattened with reshape, even when dom is
        # not exactly VectorSpace, while the output side switches to
        # cod.unflatten in that case — confirm this asymmetry is intended.
        x1 = x if self._dom_is_flat else x.reshape((self._dom_size,))
        y1 = self._A2 @ x1
        if self._cod_vector_fast_path:
            return y1 if self._cod_is_flat else y1.reshape(self.cod.shape)
        # Codomain is a VectorSpace subclass: let it rebuild its element.
        return self.cod.unflatten(y1)

    def rapply(self, y: DenseArray) -> DenseArray:
        """
        Adjoint action: x = A^* ⋅ y with x in dom.shape.
        For complex A, uses conjugate-transpose of the 2D reshaped matrix.

        When checks are enabled, ``y`` is first validated with
        ``cod._check_member``.
        """
        if self._enable_checks:
            self.cod._check_member(y)
        return self._rapply_unchecked(y)

    def _rapply_unchecked(self, y: DenseArray) -> DenseArray:
        """Adjoint action without the membership check."""
        # NOTE(review): mirror of _apply_unchecked — y is always flattened
        # with reshape regardless of the cod space type.
        y1 = y if self._cod_is_flat else y.reshape((self._cod_size,))
        x1 = self._A2H @ y1
        if self._dom_vector_fast_path:
            return x1 if self._dom_is_flat else x1.reshape(self.dom.shape)
        # Domain is a VectorSpace subclass: let it rebuild its element.
        return self.dom.unflatten(x1)

    def __eq__(self, x: Any) -> bool:
        """
        Equal when ``x`` has exactly the same type, equal ``dom`` and
        ``cod``, and ``ops.allclose`` arrays; ``False`` for any other type.
        """
        # NOTE(review): defining __eq__ without __hash__ makes instances of
        # this class unhashable (Python data model). Also, ops.allclose may
        # return a backend array rather than a Python bool — confirm callers
        # are fine with that.
        if type(x) is type(self):
            return (self.dom == x.dom
                    and self.cod == x.cod
                    and self.ops.allclose(self.A, x.A)
                    )
        return False

    def tree_flatten(self):
        """Pytree flatten: the array ``A`` is the only child; spaces and ctx are aux data."""
        aux = (self.dom, self.cod, self.ctx)
        children = (self.A,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Pytree unflatten: rebuild the operator through ``__init__``."""
        # NOTE(review): this re-runs __init__ (assert_dense, shape check,
        # reshape). JAX's pytree docs warn unflatten may be called with
        # non-array leaves (e.g. None / object()) — confirm that usage with
        # jax_pytree_class never hits that case.
        dom, cod, ctx = aux
        A = children[0]
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> DenseLinOp:
        """
        Return a new ``DenseLinOp`` in ``new_ctx``: dom/cod are converted
        with ``.convert(new_ctx)`` and ``A`` with ``new_ctx.asarray``.
        """
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.asarray(self.A)
        return DenseLinOp(new_A, new_dom, new_cod, new_ctx)