Source code for spacecore.functional._base

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar

from .._batching import _leading_batch_size, _warn_vmap_fallback_once

# Re-exported for backward compatibility; these helpers now live in
# :mod:`spacecore._batching` and are shared with :class:`~spacecore.linop.LinOp`.
from .._batching import (  # noqa: F401
    _VMAP_FALLBACK_WARN_BATCH,
    _VMAP_FALLBACK_WARNED,
    _check_scalar_shape,
)
from .._checks import checked_method
from .._repr import describe_space, field_symbol
from .._contextual import ContextBound
from ..backend import Context
from ..space import CoordinateSpace

if TYPE_CHECKING:
    from ..linop import LinOp


# Type variable for the domain space of a scalar-valued map. It is bound to
# ``CoordinateSpace``, so type checkers accept only ``CoordinateSpace`` (or a
# subclass) as the domain type parameter, e.g. ``Functional[SomeSpace]``.
Domain = TypeVar("Domain", bound=CoordinateSpace)


class Functional(ContextBound, Generic[Domain]):
    r"""
    Scalar-valued map on a space.

    ``Functional`` represents a map ``F : X -> K`` without assuming any
    storage model. It mirrors the minimal ``LinOp`` contract: the domain is
    converted into the resolved context, value checks follow
    ``ctx.check_level``, and batched evaluation is implemented by a backend
    ``vmap`` fallback.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Attributes
    ----------
    dom : Space
        Domain space converted to ``ctx``.
    ctx : Context
        Resolved backend context.

    Notes
    -----
    The base class provides no structural equality: ``==`` falls back to
    identity, and instances keep the default identity hash so they remain
    usable as dictionary keys and set members. Subclasses that implement
    structural ``__eq__`` must also define a matching ``__hash__``.
    """

    def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None:
        # ``_bind_context`` yields exactly one entry here: the domain
        # converted into the resolved context.
        (self.dom,) = self._bind_context(ctx, dom)

    @property
    def domain(self) -> Domain:
        """Domain space of this scalar-valued map (alias of ``dom``)."""
        return self.dom

    @abstractmethod
    def value(self, x: Any) -> Any:
        """Evaluate this functional at an element of ``self.domain``."""

    def grad(self, x: Any) -> Any:
        """Gradient at an element of ``self.domain``.

        Override in subclasses that support differentiation; the base raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement grad().")

    def vgrad(self, xs: Any) -> Any:
        """Gradient over a leading batch axis.

        Override in subclasses that support differentiation; the base raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement vgrad().")

    def _value_core(self, x: Any) -> Any:
        """Check-free value core; the base falls back to the checked ``value``."""
        return self.value(x)

    def _grad_core(self, x: Any) -> Any:
        """Check-free gradient core; the base falls back to the checked ``grad``."""
        return self.grad(x)

    def _vvalue_core(self, xs: Any) -> Any:
        """Check-free batched-value core; the base falls back to ``vvalue``."""
        return self.vvalue(xs)

    def _vgrad_core(self, xs: Any) -> Any:
        """Check-free batched-gradient core; the base falls back to ``vgrad``."""
        return self.vgrad(xs)

    def __call__(self, x: Any) -> Any:
        """Evaluate this functional at ``x`` (same as :meth:`value`)."""
        return self.value(x)

    def compose(self, A: "LinOp") -> "Functional":
        """
        Return the pull-back ``self o A``.

        Parameters
        ----------
        A : LinOp
            Linear operator whose codomain matches this functional's domain.

        Returns
        -------
        Functional
            Functional on ``A.domain`` evaluating ``self.value(A.apply(x))``.
        """
        # Local import, presumably to avoid an import cycle with ``._composed``.
        from ._composed import make_functional_composed

        return make_functional_composed(self, A)

    @checked_method(in_space="domain", in_batched=True)
    def vvalue(self, xs: Any) -> Any:
        """Evaluate over a leading batch axis.

        Input must have shape ``(N,) + domain.shape``; use ``moveaxis`` for
        other layouts. The base implementation maps ``value`` over axis 0
        with the backend ``vmap`` and emits a one-time fallback warning.
        """
        _warn_vmap_fallback_once(self, "vvalue", _leading_batch_size(self.domain, xs))
        return self.ops.vmap(self.value, in_axes=0, out_axes=0)(xs)

    def assert_domain(self, x: Any) -> None:
        """Raise if ``x`` is not in the domain (delegates to ``dom.check_member``)."""
        self.dom.check_member(x)

    def __eq__(self, other: Any) -> bool:
        """Return structural equality when implemented by a subclass.

        Mirrors :class:`~spacecore.linop.LinOp`: the base provides no
        algebraic equality and returns ``NotImplemented`` so Python tries the
        reflected comparison and falls back to identity symmetrically.
        """
        return NotImplemented

    # Defining ``__eq__`` in a class body makes Python implicitly set
    # ``__hash__ = None``, which would render every functional unhashable.
    # Since equality falls back to identity, the identity hash is the
    # consistent choice. Subclasses overriding ``__eq__`` get ``__hash__``
    # reset to ``None`` again and must define their own if needed.
    __hash__ = object.__hash__

    def _arrow(self) -> str:
        """Return the ``domain → scalar-field`` descriptor for this functional."""
        # repr helpers must never raise; an unrenderable field degrades to "?".
        try:
            codomain = field_symbol(self.dom.field)
        except Exception:
            codomain = "?"
        return f"{describe_space(self.dom)} → {codomain}"

    def _repr_body(self) -> str:
        """Return the body used inside this functional's ``repr``."""
        return self._arrow()

    def _short_repr(self) -> str:
        """Return a bounded ``ClassName(domain → field)`` form for nesting."""
        return f"{type(self).__name__}({self._arrow()})"

    @abstractmethod
    def tree_flatten(self) -> tuple[tuple[Any, ...], Any]:
        """Flatten this functional for pytree registration."""
        ...

    @classmethod
    @abstractmethod
    def tree_unflatten(cls, aux: Any, children: Any) -> Self:
        """Rebuild this functional from pytree data."""
        ...