from __future__ import annotations
from math import prod
from typing import Any
from ._base import LinOp, Domain, Codomain
from ..space import VectorSpace
from ..types import DenseArray, SparseArray
from ..backend import jax_pytree_class, Context
from .._contextual.manager import ctx_manager
@jax_pytree_class
class SparseLinOp(LinOp):
    """
    Sparse linear operator implementing the tensor map A : dom -> cod.

    Conceptually A has shape ``cod.shape + dom.shape``, but it is stored as a
    2D sparse matrix with ``A.shape == (prod(cod.shape), prod(dom.shape))``.

    apply:  y = A ⋅ x    (contract over dom axes)
    rapply: x = A^* ⋅ y  (contract over cod axes; A^* is the conjugate
                          transpose, which is just the transpose for real A)
    """

    def __init__(self,
                 A: SparseArray,
                 dom: Domain,
                 cod: Codomain,
                 ctx: Context | str | None = None
                 ) -> None:
        """
        Wrap the sparse matrix ``A`` as an operator from ``dom`` to ``cod``.

        Parameters
        ----------
        A:
            Sparse matrix of the resolved context with shape
            ``(prod(cod.shape), prod(dom.shape))``. Stored as-is (no dtype
            conversion).
        dom, cod:
            Input and output spaces of the operator.
        ctx:
            Optional explicit context (or context name); resolved together
            with ``dom`` and ``cod`` by
            ``ctx_manager.resolve_context_priority``.

        Raises
        ------
        TypeError
            If ``A.shape`` does not match the flattened (cod, dom) sizes.
        Whatever ``ctx.assert_sparse`` raises if ``A`` is not a sparse array
        of the resolved context.
        """
        ctx = ctx_manager.resolve_context_priority(ctx, dom, cod)
        ctx.assert_sparse(A)  # A must be a sparse array of the resolved ctx
        super().__init__(dom, cod, ctx)

        expected = (prod(self.cod.shape), prod(self.dom.shape))
        if tuple(A.shape) != expected:
            # TypeError (rather than ValueError) is preserved for callers
            # that already catch it.
            raise TypeError(f"Expected A.shape == (prod(cod.shape), prod(dom.shape)) == {expected}, got {A.shape}")
        self.A = A  # No dtype conversion
        self._cod_size, self._dom_size = expected

        # Detect complex storage: numpy-style dtypes expose ``kind == "c"``;
        # dtypes without ``kind`` fall back to a torch-style name check.
        dtype = self.ops.get_dtype(self.A)
        self._A_is_complex = (getattr(dtype, "kind", None) == "c"
                              or str(dtype).startswith("torch.complex"))

        # Build the transpose / conjugate transpose once here so rapply does
        # not recompute it on every call. For real A, conj() is skipped.
        self._AT = self.A.T
        self._AH = self._AT.conj() if self._A_is_complex else self._AT

        # Flags that let apply/rapply skip reshapes when a space is already
        # a flat vector, and pick plain reshape for exact VectorSpace types
        # (subclasses go through their own ``unflatten``).
        self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,)
        self._dom_vector_fast_path = type(self.dom) is VectorSpace
        self._cod_vector_fast_path = type(self.cod) is VectorSpace

        # With checks disabled, shadow the public methods on this instance
        # so calls go straight to the unchecked implementations.
        if not self._enable_checks:
            self.apply = self._apply_unchecked
            self.rapply = self._rapply_unchecked

    def apply(self, x: DenseArray) -> DenseArray:
        """
        Forward action: y = A ⋅ x with y in cod.shape.

        x must have shape dom.shape (dense). Membership in ``dom`` is checked
        when checks are enabled.
        """
        if self._enable_checks:
            self.dom._check_member(x)
        return self._apply_unchecked(x)

    def _apply_unchecked(self, x: DenseArray) -> DenseArray:
        """Forward action without the domain membership check."""
        # NOTE(review): the input is always flattened with a plain reshape,
        # whereas a non-VectorSpace codomain is rebuilt with
        # ``cod.unflatten``. If some space's flattening is not a plain
        # reshape, the two directions may disagree — confirm.
        x_flat = x if self._dom_is_flat else x.reshape((self._dom_size,))
        y_flat = self.A @ x_flat  # (m,)
        if not self._cod_vector_fast_path:
            return self.cod.unflatten(y_flat)
        return y_flat if self._cod_is_flat else y_flat.reshape(self.cod.shape)

    def rapply(self, y: DenseArray) -> DenseArray:
        """
        Adjoint action: x = A^* ⋅ y with x in dom.shape.

        y must have shape cod.shape (dense). Membership in ``cod`` is checked
        when checks are enabled.
        """
        if self._enable_checks:
            self.cod._check_member(y)
        return self._rapply_unchecked(y)

    def _rapply_unchecked(self, y: DenseArray) -> DenseArray:
        """Adjoint action without the codomain membership check."""
        # Same reshape-in / unflatten-out asymmetry as _apply_unchecked.
        y_flat = y if self._cod_is_flat else y.reshape((self._cod_size,))
        x_flat = self._AH @ y_flat  # (n,), uses the precomputed A^H
        if not self._dom_vector_fast_path:
            return self.dom.unflatten(x_flat)
        return x_flat if self._dom_is_flat else x_flat.reshape(self.dom.shape)

    # NOTE(review): defining __eq__ without __hash__ makes Python set
    # __hash__ to None for this class, so instances are unhashable even if
    # LinOp defines __hash__ — confirm this is intended.
    def __eq__(self, x: Any) -> bool:
        """
        Same-type operators compare equal when their domains and codomains
        are equal and their matrices are close per ``ops.allclose_sparse``.

        For any other type, return ``NotImplemented`` so Python can consult
        the other operand. If that operand also declines, Python falls back
        to identity and the result is False.
        """
        if type(x) is not type(self):
            return NotImplemented
        return (self.dom == x.dom
                and self.cod == x.cod
                and self.ops.allclose_sparse(self.A, x.A))

    def tree_flatten(self):
        """JAX pytree flattening: the sparse matrix is the only child; spaces and ctx are static aux data."""
        aux = (self.dom, self.cod, self.ctx)
        children = (self.A,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild the operator from ``tree_flatten`` output by re-running ``__init__``."""
        # NOTE(review): JAX may call tree_unflatten with non-array
        # placeholder leaves (e.g. None or object()). Because this goes
        # through __init__, assert_sparse and the shape check would reject
        # them — confirm that no transform used with this class relies on it.
        dom, cod, ctx = aux
        (A,) = children
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> SparseLinOp:
        """Return an equivalent operator whose spaces and matrix live in ``new_ctx``."""
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.assparse(self.A)
        return SparseLinOp(new_A, new_dom, new_cod, new_ctx)