Source code for spacecore.space.concrete._stacked

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

from ..base import (
    CoordinateSpace,
    EuclideanJordanAlgebraSpace,
    InnerProductSpace,
    JordanAlgebraSpace,
    Space,
    StarSpace,
)
from ._tree_space import TreeSpace, _space_capabilities
from ..._checks import checked_method
from ..._contextual import resolve_context_priority
from ...backend import BackendOps, Context, jax_pytree_class
from ...types import DenseArray
from ..checks import BackendCheck, DTypeCheck, FieldCheck, ShapeCheck

# Exceptions that make a whole-stack call into the base space fall back to a
# per-copy ``vmap`` over the leading axis.
_STACKED_FALLBACK_ERRORS = (TypeError, ValueError, AttributeError, IndexError)

# Short aliases for the capability classes used as stacked-registry keys.
_CAP_INNER, _CAP_STAR, _CAP_JORDAN, _CAP_EUCLIDEAN_JORDAN = (
    InnerProductSpace,
    StarSpace,
    JordanAlgebraSpace,
    EuclideanJordanAlgebraSpace,
)

# Capability set -> concrete stacked class; filled in once every private
# stacked subclass has been defined.
_STACKED_REGISTRY: dict[frozenset[type], type[StackedSpace]] = {}


def _validate_stacked_base(base: Any, owner: str = "StackedSpace") -> CoordinateSpace:
    """Check that ``base`` may be stacked and hand it back unchanged.

    Parameters
    ----------
    base : Any
        Candidate leaf space.
    owner : str, optional
        Class name used as the prefix of the first error message.

    Returns
    -------
    CoordinateSpace
        ``base`` itself, once validated.

    Raises
    ------
    TypeError
        If ``base`` is not a ``CoordinateSpace``, or if it is a ``TreeSpace``
        (tree spaces must be stacked leaf-by-leaf instead).
    """
    is_coordinate = isinstance(base, CoordinateSpace)
    if not is_coordinate:
        raise TypeError(
            f"{owner} requires base to be a CoordinateSpace; base is {type(base).__name__}."
        )
    if not isinstance(base, TreeSpace):
        return base
    # Tree spaces are rejected: their own ``stacked`` method stacks each leaf.
    raise TypeError(
        "StackedSpace cannot wrap TreeSpace directly; use "
        "tree_space.stacked(count), which stacks each leaf."
    )


def _validate_count(count: int, owner: str = "StackedSpace") -> int:
    """Validate and normalize stacked copy count."""
    count = int(count)
    if count < 0:
        raise ValueError(f"{owner} count must be nonnegative.")
    return count


def _stacked_capabilities(base: Space) -> frozenset[type]:
    """Capability set a stack of ``base`` inherits: exactly the base's own.

    Delegates to the tree-space helper ``_space_capabilities`` so stacked and
    tree spaces classify bases identically.
    """
    inherited = _space_capabilities(base)
    return inherited


def _stacked_class_for(capabilities: frozenset[type]) -> type[StackedSpace]:
    """Look up the concrete stacked class registered for ``capabilities``.

    Capability sets with no registered class map to the baseline
    ``StackedSpace``, which exposes only coordinate-space operations.
    """
    try:
        return _STACKED_REGISTRY[capabilities]
    except KeyError:
        return StackedSpace


def _require_base(base: Space, capability: type, owner: str) -> None:
    """Raise if ``base`` lacks ``capability``."""
    if not isinstance(base, capability):
        raise TypeError(
            f"{owner} requires base to be a {capability.__name__}; base is {type(base).__name__}."
        )


@jax_pytree_class
class StackedSpace(CoordinateSpace):
    """
    Leading-axis copies of a coordinate leaf space.

    Baseline ``StackedSpace`` exposes only coordinate-space operations.
    Direct construction dispatches to a more specific class that preserves
    the base space's inner-product, star, Jordan, and Euclidean-Jordan
    capabilities.

    Parameters
    ----------
    base : Space
        Coordinate leaf space to repeat along a leading axis.
    count : int
        Number of leading-axis copies.
    ctx : Context, str, or None, optional
        Context specification. If omitted, the context is resolved from
        ``base``.
    """

    def __new__(cls, base: Space, count: int, ctx: Context | str | None = None):
        # Capability dispatch happens only for the public class; private
        # subclasses (including pytree reconstruction through
        # ``tree_unflatten``) are instantiated exactly as requested.
        if cls is StackedSpace:
            base = _validate_stacked_base(base)
            _validate_count(count)
            resolved_ctx = resolve_context_priority(ctx, base)
            # Capabilities are read from the base converted to the resolved
            # context, i.e. the same base ``__init__`` will store.
            cls = _stacked_class_for(_stacked_capabilities(base.convert(resolved_ctx)))
        return super(StackedSpace, cls).__new__(cls)

    def __init__(self, base: Space, count: int, ctx: Context | str | None = None) -> None:
        base = _validate_stacked_base(base, type(self).__name__)
        count = _validate_count(count, type(self).__name__)
        ctx = resolve_context_priority(ctx, base)
        self.base = base.convert(ctx)
        self.count = count
        # Element shape is ``(count, *base.shape)``.
        super().__init__((self.count,) + tuple(self.base.shape), ctx)

    def _eq_algebra(self, other: Any) -> bool:
        # Tier 2: count + base. ``base == other.base`` is load-bearing — the
        # __new__ capability dispatch can map different bases onto the same
        # private subclass, so the type-identity gate alone is not sufficient.
        return super()._eq_algebra(other) and self.count == other.count and self.base == other.base

    def _repr_class_name(self) -> str:
        """Present the public ``StackedSpace`` label, not the private dispatch subclass."""
        return "StackedSpace"

    def _space_descriptor(self) -> str:
        """Return ``count×<base descriptor>`` (e.g. ``8×ℝ^3``)."""
        from ..._repr import describe_space

        return f"{self.count}×{describe_space(self.base)}"

    def _local_checks(self):
        """Return membership checks local to stacked dense coordinate spaces."""
        return BackendCheck(), ShapeCheck(), FieldCheck(), DTypeCheck()

    def zeros(self) -> DenseArray:
        """Return the stacked zero element."""
        return self.ops.zeros(self.shape, dtype=self.dtype)

    def ones(self) -> DenseArray:
        """Return the stacked all-ones element."""
        return self.ops.ones(self.shape, dtype=self.dtype)

    @checked_method(in_space="self", arg_positions=(0, 1))
    def add(self, x: Any, y: Any) -> DenseArray:
        """Return the stacked sum ``x + y``."""
        return x + y

    def add_batch(self, x: Any, y: Any) -> DenseArray:
        """Return the leading-axis batch sum of stacked elements."""
        return x + y

    @checked_method(in_space="self", arg_positions=(1,))
    def scale(self, a: Any, x: Any) -> DenseArray:
        """Return the stacked scalar product ``a * x``."""
        return a * x

    def scale_batch(self, a: Any, x: Any) -> DenseArray:
        """Return the leading-axis batch scalar product of stacked elements."""
        # NOTE(review): ``a * x`` relies on plain broadcasting; a per-batch
        # ``a`` of shape ``(N,)`` would align with the trailing axis of ``x``,
        # not the batch axis — confirm callers pass a scalar or pre-shaped ``a``.
        return a * x

    @checked_method(in_space="self")
    def flatten(self, x: Any) -> DenseArray:
        """Flatten the whole stacked element to one coordinate vector."""
        return x.reshape((-1,))

    def unflatten(self, v: DenseArray) -> DenseArray:
        """Unflatten one coordinate vector to the stacked element shape."""
        v = self._coerce_dense(v)
        return v.reshape(self.shape)

    def flatten_batch(self, xs: DenseArray) -> DenseArray:
        """Flatten a batch of stacked elements to ``(N, count * base.size)``."""
        xs = self._coerce_dense(xs)
        return xs.reshape((xs.shape[0], -1))

    def unflatten_batch(self, vs: DenseArray) -> DenseArray:
        """Unflatten rows to a batch of stacked elements."""
        vs = self._coerce_dense(vs)
        return vs.reshape((vs.shape[0],) + self.shape)

    def _convert(self, new_ctx: Context) -> StackedSpace:
        """Convert the base space and rebuild the stacked space."""
        return StackedSpace(self.base.convert(new_ctx), self.count, new_ctx)

    def stacked(self, count: int) -> StackedSpace:
        """Return a flattened stack of this stack.

        Stacking ``count`` copies of this ``self.count``-copy stack gives
        ``StackedSpace(base, self.count * count)``. The new copies are merged
        into the existing leading axis rather than nested.

        Raises
        ------
        ValueError
            If ``count`` is negative.
        """
        # Validate ``count`` on its own. When ``self.count == 0`` the product
        # is 0 regardless of sign, so a negative ``count`` would otherwise slip
        # past the constructor's nonnegativity check.
        count = _validate_count(count)
        return StackedSpace(self.base, self.count * count, self.ctx)

    def tree_flatten(self):
        """Flatten this space for JAX pytree registration."""
        return (), (self.base, self.count, self.ctx)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this space from pytree aux data."""
        base, count, ctx = aux
        return cls(base, count, ctx)
def _apply_per_copy(space: Any, fn: Any, args: tuple[Any, ...], in_axes: Any) -> Any:
    """Run ``fn(*args)`` on whole stacked arrays, else map it over the copy axis.

    The direct call is attempted first. If it raises one of
    ``_STACKED_FALLBACK_ERRORS``, the same callable is mapped with
    ``space.ops.vmap(fn, in_axes=in_axes, out_axes=0)``.
    """
    try:
        return fn(*args)
    except _STACKED_FALLBACK_ERRORS:
        return space.ops.vmap(fn, in_axes=in_axes, out_axes=0)(*args)


class _StackedInnerProductMixin:
    """Inner-product operations for stacks whose base supports them."""

    if TYPE_CHECKING:
        # Provided by the StackedSpace host this mixin is combined with;
        # narrowed to the relevant capability per method via ``cast``.
        @property
        def base(self) -> Space: ...

        @property
        def ops(self) -> BackendOps: ...

    @checked_method(in_space="self", arg_positions=(0, 1))
    def inner(self, x: Any, y: Any) -> Any:
        """Return ``sum_i base.inner(x[i], y[i])`` as a scalar."""
        geometry = cast(InnerProductSpace, self.base)
        if geometry.is_euclidean:
            return self.ops.vdot(x, y)
        try:
            dual_y = geometry.riesz(y)
            return self.ops.vdot(x, dual_y)
        except _STACKED_FALLBACK_ERRORS:
            # Whole-stack Riesz map failed: sum per-copy base inner products.
            per_copy = self.ops.vmap(geometry.inner, in_axes=(0, 0), out_axes=0)(x, y)
            return self.ops.sum(per_copy)

    def riesz(self, x: Any) -> Any:
        """Apply the base Riesz map to every stacked copy."""
        geometry = cast(InnerProductSpace, self.base)
        if geometry.is_euclidean:
            return x
        return _apply_per_copy(self, geometry.riesz, (x,), 0)

    def riesz_inverse(self, x: Any) -> Any:
        """Apply the base inverse Riesz map to every stacked copy."""
        geometry = cast(InnerProductSpace, self.base)
        if geometry.is_euclidean:
            return x
        return _apply_per_copy(self, geometry.riesz_inverse, (x,), 0)

    @property
    def is_euclidean(self) -> bool:
        """Return whether the base geometry is Euclidean."""
        return cast(InnerProductSpace, self.base).is_euclidean


class _StackedStarMixin:
    """Star operation for stacks whose base supports it."""

    if TYPE_CHECKING:
        # Provided by the StackedSpace host this mixin is combined with.
        @property
        def base(self) -> Space: ...

        @property
        def ops(self) -> BackendOps: ...

    def star(self, x: Any) -> Any:
        """Return the base-space star operation for each stacked copy."""
        involution = cast(StarSpace, self.base)
        return _apply_per_copy(self, involution.star, (x,), 0)


class _StackedJordanMixin:
    """Jordan operations for stacks whose base supports them."""

    if TYPE_CHECKING:
        # Provided by the StackedSpace host this mixin is combined with.
        @property
        def base(self) -> Space: ...

        @property
        def ops(self) -> BackendOps: ...

    @checked_method(in_space="self", arg_positions=(0, 1))
    def jordan(self, x: Any, y: Any) -> Any:
        """Return the base-space Jordan product for each stacked copy."""
        algebra = cast(JordanAlgebraSpace, self.base)
        return _apply_per_copy(self, algebra.jordan, (x, y), (0, 0))

    def spectrum(self, x: Any) -> Any:
        """Return spectra for each leading-axis copy of the base space."""
        algebra = cast(JordanAlgebraSpace, self.base)
        return _apply_per_copy(self, algebra.spectrum, (x,), 0)

    def spectral_decompose(self, x: Any) -> Any:
        """Return spectral decompositions for each leading-axis copy."""
        algebra = cast(JordanAlgebraSpace, self.base)
        return _apply_per_copy(self, algebra.spectral_decompose, (x,), 0)

    def from_spectrum(self, eigvals: Any, frame: Any) -> Any:
        """Reconstruct stacked elements from base spectral data."""
        algebra = cast(JordanAlgebraSpace, self.base)
        return _apply_per_copy(self, algebra.from_spectrum, (eigvals, frame), (0, 0))

    def spectral_apply(self, x: Any, f: Any) -> Any:
        """Apply the base-space spectral calculus to each stacked copy."""
        algebra = cast(JordanAlgebraSpace, self.base)

        # ``f`` is closed over so only ``x`` is mapped in the vmap fallback.
        def apply_one(xi: Any) -> Any:
            return algebra.spectral_apply(xi, f)

        return _apply_per_copy(self, apply_one, (x,), 0)

    def trace(self, x: Any) -> Any:
        """Return the direct-sum trace: the sum of the per-copy base traces."""
        algebra = cast(JordanAlgebraSpace, self.base)
        per_copy = _apply_per_copy(self, algebra.trace, (x,), 0)
        # Reduce only the copy axis; leading batch axes are preserved.
        return self.ops.sum(per_copy, axis=-1)

    def determinant(self, x: Any) -> Any:
        """Return the direct-sum determinant: the product of the per-copy base determinants."""
        algebra = cast(JordanAlgebraSpace, self.base)
        per_copy = _apply_per_copy(self, algebra.determinant, (x,), 0)
        # Reduce only the copy axis; leading batch axes are preserved.
        return self.ops.prod(per_copy, axis=-1)

    def unit(self) -> Any:
        """Return the stacked identity: ``base.unit()`` replicated across every copy."""
        identity = cast(JordanAlgebraSpace, self.base).unit()
        return self.ops.broadcast_to(identity, self.shape)


@jax_pytree_class
class _StackedInnerProductSpace(_StackedInnerProductMixin, StackedSpace, InnerProductSpace):
    """Stacked space whose base supports an inner product."""

    def __init__(self, base, count, ctx=None):
        owner = type(self).__name__
        base = _validate_stacked_base(base, owner)
        _require_base(base, InnerProductSpace, owner)
        super().__init__(base, count, ctx)


@jax_pytree_class
class _StackedStarSpace(_StackedStarMixin, StackedSpace, StarSpace):
    """Stacked space whose base supports a star operation."""

    def __init__(self, base, count, ctx=None):
        owner = type(self).__name__
        base = _validate_stacked_base(base, owner)
        _require_base(base, StarSpace, owner)
        super().__init__(base, count, ctx)


@jax_pytree_class
class _StackedJordanAlgebraSpace(_StackedJordanMixin, StackedSpace, JordanAlgebraSpace):
    """Stacked space whose base supports Jordan algebra operations."""

    def __init__(self, base, count, ctx=None):
        owner = type(self).__name__
        base = _validate_stacked_base(base, owner)
        _require_base(base, JordanAlgebraSpace, owner)
        super().__init__(base, count, ctx)


@jax_pytree_class
class _StackedEuclideanJordanAlgebraSpace(
    _StackedInnerProductMixin,
    _StackedJordanMixin,
    StackedSpace,
    EuclideanJordanAlgebraSpace,
):
    """Stacked space whose base supports Euclidean Jordan algebra operations."""

    def __init__(self, base, count, ctx=None):
        owner = type(self).__name__
        base = _validate_stacked_base(base, owner)
        _require_base(base, EuclideanJordanAlgebraSpace, owner)
        super().__init__(base, count, ctx)
        # Re-checked on the stored (context-converted) base; presumably guards
        # against ``convert`` dropping the capability — confirm.
        _require_base(self.base, EuclideanJordanAlgebraSpace, owner)


@jax_pytree_class
class _StackedInnerProductStarSpace(
    _StackedInnerProductMixin,
    _StackedStarMixin,
    StackedSpace,
    InnerProductSpace,
    StarSpace,
):
    """Stacked implementation for inner-product plus star capability."""


@jax_pytree_class
class _StackedInnerProductJordanSpace(
    _StackedInnerProductMixin,
    _StackedJordanMixin,
    StackedSpace,
    InnerProductSpace,
    JordanAlgebraSpace,
):
    """Stacked implementation for inner-product plus Jordan capability."""


@jax_pytree_class
class _StackedStarJordanSpace(
    _StackedStarMixin,
    _StackedJordanMixin,
    StackedSpace,
    StarSpace,
    JordanAlgebraSpace,
):
    """Stacked implementation for star plus Jordan capability."""


@jax_pytree_class
class _StackedInnerProductStarJordanSpace(
    _StackedInnerProductMixin,
    _StackedStarMixin,
    _StackedJordanMixin,
    StackedSpace,
    InnerProductSpace,
    StarSpace,
    JordanAlgebraSpace,
):
    """Stacked implementation for inner-product, star, and Jordan capability."""


@jax_pytree_class
class _StackedEuclideanJordanStarSpace(
    _StackedStarMixin,
    _StackedEuclideanJordanAlgebraSpace,
    StarSpace,
):
    """Stacked implementation for Euclidean-Jordan plus star capability."""


# Capability set -> concrete class used by ``StackedSpace.__new__`` dispatch.
_STACKED_REGISTRY.update(
    {
        frozenset(): StackedSpace,
        frozenset({_CAP_INNER}): _StackedInnerProductSpace,
        frozenset({_CAP_STAR}): _StackedStarSpace,
        frozenset({_CAP_JORDAN}): _StackedJordanAlgebraSpace,
        frozenset({_CAP_INNER, _CAP_STAR}): _StackedInnerProductStarSpace,
        frozenset({_CAP_INNER, _CAP_JORDAN}): _StackedInnerProductJordanSpace,
        frozenset({_CAP_STAR, _CAP_JORDAN}): _StackedStarJordanSpace,
        frozenset({_CAP_INNER, _CAP_STAR, _CAP_JORDAN}): _StackedInnerProductStarJordanSpace,
        frozenset(
            {_CAP_INNER, _CAP_JORDAN, _CAP_EUCLIDEAN_JORDAN}
        ): _StackedEuclideanJordanAlgebraSpace,
        frozenset(
            {_CAP_INNER, _CAP_STAR, _CAP_JORDAN, _CAP_EUCLIDEAN_JORDAN}
        ): _StackedEuclideanJordanStarSpace,
    }
)

__all__ = [
    "StackedSpace",
    "_stacked_capabilities",
]