from __future__ import annotations
from functools import cached_property
from math import prod
from typing import Any, cast
from ._base import Codomain, Domain, LinOp
from .._checks import checked_method
from ._metric import _metric_is_hermitian_by_basis, _requires_euclidean_or_riesz
from ..space import (
DenseCoordinateSpace,
DenseVectorSpace,
ElementwiseJordanSpace,
WeightedInnerProduct,
)
from ..types import DenseArray
from ..backend import jax_pytree_class, Context
from .._contextual import resolve_context_priority
from ..kernels import core_kernels
from ..kernels.core.dense import _DenseMode
@core_kernels("dense")
@jax_pytree_class
class DenseLinOp(LinOp[Domain, Codomain]):
    r"""
    Represent a dense coordinate tensor-backed linear operator.

    ``DenseLinOp(A, dom, cod)`` represents a linear map
    :math:`A \colon X \to Y` where the stored dense array has shape
    ``cod.shape + dom.shape``. Forward application is the raw coordinate
    matrix action. Adjoint application is metric-aware: Euclidean spaces use
    the conjugate transpose fast path, while non-Euclidean spaces use their
    Riesz maps as ``R_X^{-1} A^dagger R_Y``.

    DenseLinOp does not copy or cast the input array. The caller is responsible
    for passing an array compatible with `ctx`. This avoids duplicating large dense
    operators in memory.

    Parameters
    ----------
    A : DenseArray
        Dense backend array with shape ``cod.shape + dom.shape``.
    dom : Space
        Domain space.
    cod : Space or None, optional
        Codomain space. If omitted, a ``DenseCoordinateSpace`` is inferred
        from the leading axes of ``A`` (the axes not covered by ``dom.shape``).
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from the spaces.

    Attributes
    ----------
    A : DenseArray
        Stored dense operator tensor.

    Raises
    ------
    TypeError
        If ``A.shape`` differs from ``cod.shape + dom.shape``, or if ``cod`` is
        omitted and ``A`` has fewer axes than ``dom.shape``.

    Examples
    --------
    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 3.0]]), X, X, ctx)
    >>> A.apply(ctx.asarray([1.0, 2.0]))
    array([2., 6.])
    """

    def __init__(
        self,
        A: DenseArray,
        dom: Domain,
        cod: Codomain | None = None,
        ctx: Context | str | None = None,
    ) -> None:
        ctx = resolve_context_priority(ctx, dom, cod)
        ctx.assert_dense(A)  # Check if A is ndarray of ctx
        if cod is None:
            # The leading axes of A not accounted for by dom.shape form the codomain.
            cod_shape_len = len(A.shape) - len(dom.shape)
            if cod_shape_len < 0:
                # A negative length would make the slice below silently drop
                # axes and build a meaningless codomain; fail clearly instead.
                raise TypeError(
                    f"Cannot infer cod: A.shape {tuple(A.shape)} has fewer axes "
                    f"than dom.shape {tuple(dom.shape)}"
                )
            cod = cast(Codomain, DenseCoordinateSpace(tuple(A.shape[:cod_shape_len]), ctx))
        _requires_euclidean_or_riesz(dom, cod, "DenseLinOp")
        super(DenseLinOp, self).__init__(dom, cod, ctx)
        expected = tuple(self.cod.shape) + tuple(self.dom.shape)
        if tuple(A.shape) != expected:
            raise TypeError(
                f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}"
            )
        self._A = A  # Intentionally no dtype conversion to avoid extra memory use.
        # Flattened (prod(cod.shape), prod(dom.shape)) matrix views of A.
        self._cod_size = prod(self.cod.shape)
        self._dom_size = prod(self.dom.shape)
        self._matrix_shape = (self._cod_size, self._dom_size)
        self._A2 = self.A.reshape(self._matrix_shape)
        dtype = self.ops.get_dtype(self.A)
        is_complex = self.ops.is_complex_dtype(dtype)
        self._A2T = self._A2.T
        # Conjugate only for complex dtypes; real operators reuse the plain transpose.
        self._A2H = self._A2.T.conj() if is_complex else self._A2.T
        self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,)
        self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,)
        self._mode = self._select_mode()
        if self._mode is _DenseMode.WEIGHTED_FUSED:
            self._dom_weights = self.dom.geometry.weights
            self._cod_weights = self.cod.geometry.weights
            # Precompute diag(1 / w_dom) @ A^H @ diag(w_cod): columns of A^H are
            # scaled by the codomain weights and rows divided by the domain
            # weights. This is the R_X^{-1} A^dagger R_Y form with diagonal Riesz
            # maps -- assumes WeightedInnerProduct's Riesz map is elementwise
            # multiplication by `weights` (TODO confirm against geometry module).
            self._weighted_A2H = (
                self._A2H * self._cod_weights.reshape((1, self._cod_size))
            ) / self._dom_weights.reshape((self._dom_size, 1))

    def _select_mode(self) -> _DenseMode:
        """Select the dense computation mode once for this operator.

        Returns
        -------
        _DenseMode
            ``WEIGHTED_FUSED`` for flat vector spaces that both use exactly
            ``WeightedInnerProduct``; ``EUCLIDEAN_FLAT`` / ``EUCLIDEAN_TENSOR``
            for Euclidean vector spaces (flat or not); otherwise
            ``GENERAL_METRIC``.
        """
        # Exact type checks (not isinstance): presumably so subclasses, which
        # may customise behaviour, take the general metric-aware path.
        vector_types = (
            DenseCoordinateSpace,
            DenseVectorSpace,
            ElementwiseJordanSpace,
        )
        vector_dom = type(self.dom) in vector_types
        vector_cod = type(self.cod) in vector_types
        if (
            vector_dom
            and vector_cod
            and self._dom_is_flat
            and self._cod_is_flat
            and type(self.dom.geometry) is WeightedInnerProduct
            and type(self.cod.geometry) is WeightedInnerProduct
        ):
            return _DenseMode.WEIGHTED_FUSED
        if (
            vector_dom
            and vector_cod
            and cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
        ):
            if self._dom_is_flat and self._cod_is_flat:
                return _DenseMode.EUCLIDEAN_FLAT
            return _DenseMode.EUCLIDEAN_TENSOR
        return _DenseMode.GENERAL_METRIC

    @cached_property
    def A(self) -> DenseArray:
        """
        Stored dense tensor representation of this operator.

        The returned array has shape ``self.codomain.shape + self.domain.shape``
        and is the same object supplied at construction.
        """
        return self._A

    @checked_method(in_space="domain", out_space="codomain")
    def apply(self, x: DenseArray) -> DenseArray:
        """Apply the dense operator to ``x``."""
        return self._apply_core(x)

    @checked_method(in_space="codomain", out_space="domain")
    def rapply(self, y: DenseArray) -> DenseArray:
        r"""Apply the adjoint dense operator to ``y``.

        Euclidean spaces use the conjugate transpose of the flattened matrix.
        Non-Euclidean spaces apply the codomain and domain Riesz maps around
        that Euclidean adjoint.
        """
        return self._rapply_core(y)

    # Output checking mirrors ``rvapply``: the result must be a batch of
    # codomain elements with the same leading batch axis.
    @checked_method(in_space="domain", out_space="codomain", in_batched=True, out_batched=True)
    def vapply(self, xs: DenseArray) -> DenseArray:
        """Apply over a leading batch axis.

        Input must have shape ``(N,) + domain.shape``; use ``moveaxis`` for
        other layouts.
        """
        return self._vapply_core(xs)

    @checked_method(in_space="codomain", out_space="domain", in_batched=True, out_batched=True)
    def rvapply(self, ys: DenseArray) -> DenseArray:
        """Apply the adjoint over a leading batch axis.

        Input must have shape ``(N,) + codomain.shape``; use ``moveaxis`` for
        other layouts.
        """
        return self._rvapply_core(ys)

    def to_dense(self) -> DenseArray:
        """
        Return the stored dense tensor representation of this operator.

        The returned array has shape ``self.codomain.shape + self.domain.shape``.
        """
        return self.A

    def to_matrix(self) -> DenseArray:
        """
        Return the flattened dense matrix representation.

        The returned array has shape
        ``(prod(self.codomain.shape), prod(self.domain.shape))``.
        It is a reshape/view of the stored dense tensor when the backend permits.
        """
        return self._A2

    def is_hermitian(self) -> bool | None:
        """
        Return whether this dense operator is structurally self-adjoint.

        Returns
        -------
        bool or None
            ``True`` or ``False`` when the structure can be checked, otherwise
            ``None``.
        """
        if self.dom != self.cod:
            return False
        if not (
            cast(Any, self.domain).is_euclidean
            and cast(Any, self.codomain).is_euclidean
        ):
            return _metric_is_hermitian_by_basis(self)
        try:
            return bool(self.ops.allclose(self._A2, self._A2H))
        except Exception:
            # Deliberate best-effort: if the comparison cannot be evaluated
            # (presumably e.g. non-concrete traced arrays), report "unknown".
            return None

    def __eq__(self, other: Any) -> bool:
        """Return whether another dense operator has the same spaces and values.

        Returns ``NotImplemented`` when ``other`` is not a backend-compatible
        ``DenseLinOp`` so that Python can try the reflected comparison.
        """
        if not self._eq_backend_compatible(other):  # Tier 1: backend
            return NotImplemented
        if not isinstance(other, DenseLinOp):
            # Other operator kinds may share backend and spaces but have no ``A``.
            return NotImplemented
        if self.dom != other.dom or self.cod != other.cod:  # Tier 2: spaces before allclose
            return False
        return bool(self.ops.allclose(self.A, other.A, equal_nan=True))  # Tier 3: values

    def tree_flatten(self):
        """Flatten this operator for pytree registration.

        The dense tensor is the only leaf; spaces and context are static aux data.
        """
        aux = (self.dom, self.cod, self.ctx)
        children = (self.A,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this operator from pytree data.

        Reconstruction goes through ``__init__``, so the backend/shape checks
        and the derived matrix views are recomputed from ``children[0]``.
        """
        # NOTE(review): JAX may unflatten pytrees with non-array placeholder
        # leaves (e.g. ``None`` or ``object()``); the validation in ``__init__``
        # would reject those. Confirm this operator is never unflattened that way.
        dom, cod, ctx = aux
        A = children[0]
        return cls(A, dom, cod, ctx)

    def _convert(self, new_ctx: Context) -> DenseLinOp:
        """Convert spaces and stored dense tensor to ``new_ctx``."""
        new_dom = self.dom.convert(new_ctx)
        new_cod = self.cod.convert(new_ctx)
        new_A = new_ctx.asarray(self.A)
        return DenseLinOp(new_A, new_dom, new_cod, new_ctx)