from __future__ import annotations
from functools import cached_property
from math import prod
from typing import Any, cast
from ._base import LinOp
from ._metric import _metric_is_hermitian_by_basis, _requires_euclidean_or_riesz
from .._checks import checked_method
from ..space import (
CoordinateSpace,
DenseCoordinateSpace,
DenseVectorSpace,
ElementwiseJordanSpace,
WeightedInnerProduct,
)
from ..types import DenseArray, SparseArray
from ..backend import jax_pytree_class, Context
from .._contextual import resolve_context_priority
from ..kernels import core_kernels
from ..kernels.core.sparse import _SparseMode
# Message of the TypeError raised by SparseLinOp.__init__ when either the
# domain or the codomain is not a CoordinateSpace instance. It points callers
# at MatrixFreeLinOp as the alternative for such spaces.
_VECTOR_SPACE_ONLY = (
"SparseLinOp is only for coordinate sparse matrices acting between "
"CoordinateSpace objects. Non-vector or exotic spaces should use "
"MatrixFreeLinOp with explicit forward and adjoint callbacks."
)
@core_kernels("sparse")
@jax_pytree_class
class SparseLinOp(LinOp[CoordinateSpace, CoordinateSpace]):
    r"""
    Represent a sparse coordinate matrix-backed linear operator.

    ``SparseLinOp(A, dom, cod)`` represents a sparse coordinate matrix between
    two :class:`CoordinateSpace` objects. Subclasses of ``CoordinateSpace``
    are supported, but product spaces and other non-coordinate spaces are
    intentionally rejected (use ``MatrixFreeLinOp`` with explicit forward and
    adjoint callbacks for those). The conceptual operator tensor has shape
    ``cod.shape + dom.shape`` while storage uses a two-dimensional sparse
    matrix with shape ``(prod(cod.shape), prod(dom.shape))``.

    Forward application is the raw coordinate matrix action. Adjoint
    application is metric-aware: Euclidean spaces use the conjugate transpose
    fast path, while non-Euclidean spaces use their Riesz maps as
    ``R_X^{-1} A^dagger R_Y``.

    Parameters
    ----------
    A : SparseArray
        Sparse backend matrix with shape ``(prod(cod.shape), prod(dom.shape))``.
    dom : CoordinateSpace
        Domain space; must be a ``CoordinateSpace`` instance.
    cod : CoordinateSpace
        Codomain space; must be a ``CoordinateSpace`` instance.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from the spaces.

    Raises
    ------
    TypeError
        If ``dom`` or ``cod`` is not a ``CoordinateSpace`` instance, or if
        ``A.shape`` differs from ``(prod(cod.shape), prod(dom.shape))``.

    Attributes
    ----------
    A : SparseArray
        Stored sparse matrix representation. The constructor keeps this object
        without sparse conversion or copying; explicit conversion happens only
        through :meth:`_convert`.

    Examples
    --------
    >>> import numpy as np
    >>> import scipy.sparse as sps
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> A = sc.SparseLinOp(ctx.assparse(sps.eye(2)), X, X, ctx)
    >>> A.apply(ctx.asarray([1.0, 2.0]))
    array([1., 2.])
    """

    def __init__(
        self,
        A: SparseArray,
        dom: CoordinateSpace,
        cod: CoordinateSpace,
        ctx: Context | str | None = None,
    ) -> None:
        ctx = resolve_context_priority(ctx, dom, cod)
        ctx.assert_sparse(A)  # Check if A is sparse array of ctx
        if not isinstance(dom, CoordinateSpace) or not isinstance(cod, CoordinateSpace):
            raise TypeError(_VECTOR_SPACE_ONLY)
        _requires_euclidean_or_riesz(dom, cod, "SparseLinOp")
        super(SparseLinOp, self).__init__(dom, cod, ctx)

        # The stored matrix is always 2-D: flattened codomain x flattened domain.
        expected = (prod(self.cod.shape), prod(self.dom.shape))
        if tuple(A.shape) != expected:
            raise TypeError(
                f"Expected A.shape == (prod(cod.shape), prod(dom.shape)) == {expected}, got {A.shape}"
            )
        self._A = A  # No dtype conversion
        self._cod_size = expected[0]
        self._dom_size = expected[1]

        dtype = self.ops.get_dtype(self.A)
        self._A_is_complex = self.ops.is_complex_dtype(dtype)
        # Precompute the transpose once; for real matrices the conjugate
        # transpose is the very same object, so no conjugation is performed.
        self._AT = self.A.T
        self._AH = self._sparse_conj(self._AT) if self._A_is_complex else self._AT

        # Exact type match (not isinstance): only these concrete space classes
        # are treated as plain dense-array spaces here.
        self._dom_dense_array = type(self.dom) in (
            DenseCoordinateSpace,
            DenseVectorSpace,
            ElementwiseJordanSpace,
        )
        self._cod_dense_array = type(self.cod) in (
            DenseCoordinateSpace,
            DenseVectorSpace,
            ElementwiseJordanSpace,
        )
        # "Flat" means the space is already one-dimensional with the matrix's
        # row/column length.
        self._dom_is_flat = self._dom_dense_array and tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = self._cod_dense_array and tuple(self.cod.shape) == (self._cod_size,)

        self._mode = self._select_mode()
        if self._mode is _SparseMode.WEIGHTED_FUSED:
            # Only set in the fused weighted mode; other modes never define
            # these attributes.
            self._dom_weights = self.dom.geometry.weights
            self._cod_weights = self.cod.geometry.weights

    def _select_mode(self) -> _SparseMode:
        """Select the sparse computation mode once for this operator.

        Returns
        -------
        _SparseMode
            ``WEIGHTED_FUSED`` when both spaces are flat and their ``geometry``
            is exactly a ``WeightedInnerProduct``; ``EUCLIDEAN_FLAT`` or
            ``EUCLIDEAN_TENSOR`` when both spaces are Euclidean dense-array
            spaces (flat or not, respectively); ``GENERAL_METRIC`` otherwise.
        """
        if (
            self._dom_is_flat
            and self._cod_is_flat
            and type(getattr(self.dom, "geometry", None)) is WeightedInnerProduct
            and type(getattr(self.cod, "geometry", None)) is WeightedInnerProduct
        ):
            return _SparseMode.WEIGHTED_FUSED
        if (
            cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
            and self._dom_dense_array
            and self._cod_dense_array
        ):
            if self._dom_is_flat and self._cod_is_flat:
                return _SparseMode.EUCLIDEAN_FLAT
            return _SparseMode.EUCLIDEAN_TENSOR
        return _SparseMode.GENERAL_METRIC

    def _sparse_conj(self, A: SparseArray) -> SparseArray:
        """Return the complex conjugate of a backend sparse array.

        Tries ``A.conj()``, then ``A.conjugate()``, and finally rebuilds an
        array of ``type(A)`` from conjugated ``data`` and the original
        ``indices``.

        Raises
        ------
        TypeError
            If ``A`` offers none of the supported conjugation routes.
        """
        if hasattr(A, "conj"):
            return A.conj()
        if hasattr(A, "conjugate"):
            return cast(Any, A).conjugate()
        if hasattr(A, "data") and hasattr(A, "indices"):
            A_any = cast(Any, A)
            # Carry the structural flags over when the array exposes them.
            kwargs = {
                "shape": A.shape,
                "indices_sorted": getattr(A, "indices_sorted", False),
                "unique_indices": getattr(A, "unique_indices", False),
            }
            return cast(Any, type(A))((self.ops.conj(A_any.data), A_any.indices), **kwargs)
        raise TypeError(f"Cannot conjugate sparse array of type {type(A).__name__}.")

    @cached_property
    def A(self) -> SparseArray:
        """
        Stored sparse matrix representation of this operator.

        The returned sparse matrix has shape
        ``(prod(self.codomain.shape), prod(self.domain.shape))`` and is the
        same object supplied at construction.
        """
        return self._A

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: DenseArray) -> DenseArray:
        """
        Forward action ``y = A @ x`` in Euclidean coordinates.

        x must have shape dom.shape (dense).
        """
        # NOTE(review): ``_apply_core`` is not defined in this class body;
        # presumably attached by ``@core_kernels("sparse")`` - confirm.
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: DenseArray) -> DenseArray:
        """
        Metric-aware adjoint action.

        y must have shape cod.shape (dense).
        """
        return self._rapply_core(y)

    # Output validation mirrors ``rvapply`` (and the unbatched ``apply``):
    # a batched domain input must yield a batched codomain output.
    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: DenseArray) -> DenseArray:
        """
        Batched forward action.

        ``xs`` is validated as a batch of domain elements and the result as a
        batch of codomain elements.
        """
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
    def rvapply(self, ys: DenseArray) -> DenseArray:
        """
        Batched metric-aware adjoint action.

        ``ys`` is validated as a batch of codomain elements and the result as
        a batch of domain elements.
        """
        return self._rvapply_core(ys)

    def to_sparse(self) -> SparseArray:
        """
        Return the stored sparse matrix representation without copying.

        The returned object is exactly the sparse array supplied at construction.
        """
        return self.A

    def to_matrix(self) -> DenseArray:
        """
        Materialize the stored sparse matrix as a dense 2D coordinate matrix.

        Use :meth:`to_sparse` when sparse storage should be preserved.

        Returns
        -------
        DenseArray
            Array of shape ``(prod(codomain.shape), prod(domain.shape))``.
        """
        # Try the common densification spellings before falling back to the
        # base-class implementation.
        if hasattr(self.A, "toarray"):
            dense = cast(Any, self.A).toarray()
        elif hasattr(self.A, "todense"):
            dense = cast(Any, self.A).todense()
        elif hasattr(self.A, "to_dense"):
            dense = cast(Any, self.A).to_dense()
        else:
            dense = super().to_matrix()
        return self.ops.reshape(self.ctx.asarray(dense), (self._cod_size, self._dom_size))

    def to_dense(self) -> DenseArray:
        """
        Materialize the stored sparse matrix as a dense operator tensor.

        The returned array has shape ``self.codomain.shape + self.domain.shape``.
        """
        return self.ops.reshape(
            self.to_matrix(), tuple(self.codomain.shape) + tuple(self.domain.shape)
        )

    def is_hermitian(self) -> bool | None:
        """
        Return whether this sparse operator is structurally self-adjoint.

        Returns
        -------
        bool or None
            ``True`` or ``False`` when the structure can be checked, otherwise
            ``None``.
        """
        if self.dom != self.cod:
            return False
        if not (
            cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
        ):
            # Non-Euclidean spaces: defer to the metric-aware helper.
            return _metric_is_hermitian_by_basis(self)
        try:
            return bool(self.ops.allclose_sparse(self.A, self._AH))
        except Exception:
            # Deliberate best effort: a backend that cannot compare the two
            # sparse arrays yields "unknown" rather than an error.
            return None

    def __eq__(self, other: Any) -> bool:
        """Compare backend compatibility, then spaces, then stored values."""
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        if self.dom != other.dom or self.cod != other.cod:  # Tier 2: spaces before allclose
            return False
        # Tier 3: sparse values. Note: allclose_sparse has no equal_nan, so a
        # structural NaN compares unequal to itself (rare for sparse storage).
        return bool(self.ops.allclose_sparse(self.A, other.A))

    def _repr_body(self) -> str:
        """Lead with the arrow; add the (cheap) stored non-zero count when known."""
        nnz = self._stored_nnz()
        return self._arrow() if nnz is None else f"{self._arrow()}, nnz={nnz}"

    def _stored_nnz(self) -> int | None:
        """Return the stored non-zero count across backends, or ``None``.

        Tries the common spellings: ``nnz`` (scipy/cupy), ``nse`` (jax BCOO),
        and ``_nnz()`` (torch). Returns ``None`` when none is cheaply available.
        Any integral attribute value is accepted and normalised to ``int``.
        """
        from numbers import Integral

        A = self._A
        for attr in ("nnz", "nse"):
            value = getattr(A, attr, None)
            # Integral (rather than int) also covers integer scalar types that
            # register with the numbers ABCs, e.g. NumPy integers.
            if isinstance(value, Integral):
                return int(value)
        method = getattr(A, "_nnz", None)
        if callable(method):
            try:
                return int(cast(Any, method()))
            except Exception:
                # The count only decorates the repr; never let it raise.
                return None
        return None

    def tree_flatten(self):
        """Split into pytree children ``(A,)`` and aux data ``(dom, cod, ctx)``."""
        aux = (self.dom, self.cod, self.ctx)
        children = (self.A,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild the operator from :meth:`tree_flatten` output.

        Goes through ``cls(...)``, so the constructor's validation runs again
        on the supplied child.
        """
        dom, cod, ctx = aux
        A = children[0]
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> SparseLinOp:
        """Return an equivalent operator whose spaces and matrix live in ``new_ctx``.

        The matrix is converted with ``new_ctx.assparse`` and, when its dtype
        still differs from ``new_ctx.dtype``, cast via ``astype`` or ``to`` if
        available, otherwise rebuilt from the dense matrix.
        """
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.assparse(self.A)
        if new_ctx.ops.get_dtype(new_A) != new_ctx.dtype:
            if hasattr(new_A, "astype"):
                new_A = cast(Any, new_A).astype(new_ctx.dtype)
            elif hasattr(new_A, "to"):
                new_A = cast(Any, new_A).to(dtype=new_ctx.dtype)
            else:
                # Last resort: densify, then let the new context re-sparsify.
                new_A = new_ctx.assparse(self.to_matrix())
        return SparseLinOp(new_A, new_dom, new_cod, new_ctx)