Source code for spacecore.functional._linear

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Callable

from ._base import Domain, Functional
from .._batching import _check_scalar_shape, _leading_batch_size
from .._checks import checked_method
from ..backend import Context, jax_pytree_class
from ..kernels import core_kernels
from ..space import Space, TreeElement, TreeSpace


def _convert_space_element(space: Space, value: Any) -> Any:
    """
    Convert ``value`` into an element of ``space``, leaf by leaf for tree spaces.

    Parameters
    ----------
    space : Space
        Target space. When it is a ``TreeSpace`` the conversion is applied per
        leaf; otherwise ``value`` is passed to ``space.ctx.asarray`` directly.
    value : Any
        Value to convert. A ``TreeElement`` is re-expressed leaf by leaf using
        its own leaf spaces; any other value is flattened with
        ``space.flatten_tree`` first.

    Returns
    -------
    Any
        ``space.ctx.asarray(value)`` for a non-tree space, or the result of
        ``space.unflatten_tree`` applied to the tuple of converted leaves.
    """
    # Non-tree target: a single backend conversion is all that is needed.
    if not isinstance(space, TreeSpace):
        return space.ctx.asarray(value)

    if isinstance(value, TreeElement):
        # The value already carries its own leaf spaces: flatten each leaf in
        # its source space, convert it, then reshape it in the target leaf space.
        origin_leaf_spaces = value.space.leaf_spaces
        origin_leaves = value.leaves
        # NOTE(review): zip stops at the shortest input, so a leaf-count
        # mismatch is not reported here — confirm unflatten_tree catches it.
        triples = zip(origin_leaf_spaces, space.leaf_spaces, origin_leaves)
        rebuilt = []
        for origin_space, dest_space, leaf in triples:
            rebuilt.append(
                dest_space.unflatten(space.ctx.asarray(origin_space.flatten(leaf)))
            )
    else:
        # Arbitrary tree-like input: let the target space split it into leaves
        # and convert each leaf with the context of its own leaf space.
        raw_leaves = space.flatten_tree(value)
        rebuilt = [
            dest_space.ctx.asarray(leaf)
            for dest_space, leaf in zip(space.leaf_spaces, raw_leaves)
        ]

    return space.unflatten_tree(tuple(rebuilt))


@core_kernels("functional-linear")
class LinearFunctional(Functional[Domain]):
    r"""
    Base class for linear scalar-valued maps on a domain space.

    Subclasses must provide :attr:`representer`. The public ``grad`` and
    ``vgrad`` methods forward to ``_grad_core`` and ``_vgrad_core``.

    Parameters
    ----------
    dom : Space
        Domain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.
    """

    # NOTE(review): ``_grad_core`` and ``_vgrad_core`` are not defined in this
    # class body; presumably they are inherited or attached by the
    # ``core_kernels`` decorator — confirm.

    @property
    @abstractmethod
    def representer(self) -> Any:
        """
        Riesz representer of this functional, when one is explicitly stored.

        Matrix-free functionals may have no stored representer; such
        subclasses should raise ``NotImplementedError`` here.
        """

    @checked_method(in_space="domain", out_space="domain")
    def grad(self, x: Any) -> Any:
        """
        Return the constant Riesz gradient of this linear functional.

        For ``ell(x) = <c, x>_X`` the gradient is the space element ``c``.
        Matrix-free functionals without a stored representer inherit the
        ``NotImplementedError`` raised by :attr:`representer`.

        Parameters
        ----------
        x : Any
            Element of the domain, forwarded to ``_grad_core``.

        Returns
        -------
        Any
            Whatever ``_grad_core(x)`` returns.
        """
        gradient = self._grad_core(x)
        return gradient

    @checked_method(
        in_space="domain",
        out_space="domain",
        in_batched=True,
        out_batched=True,
    )
    def vgrad(self, xs: Any) -> Any:
        """
        Return the constant Riesz gradient over a leading batch axis.

        Parameters
        ----------
        xs : Any
            Batched domain elements, forwarded to ``_vgrad_core``.

        Returns
        -------
        Any
            Whatever ``_vgrad_core(xs)`` returns.
        """
        gradients = self._vgrad_core(xs)
        return gradients
@core_kernels("inner-product-functional")
@jax_pytree_class
class InnerProductFunctional(LinearFunctional[Domain]):
    r"""
    Linear functional defined by a stored element of its domain.

    ``InnerProductFunctional(c, X)`` evaluates
    :math:`\ell_c(x) = \langle c, x\rangle_X`.

    Parameters
    ----------
    c : array-like
        Riesz representer in ``dom``. It is converted into the functional's
        domain on construction.
    dom : Space
        Domain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Attributes
    ----------
    representer : array-like
        Stored domain element ``c``.
    """

    def __init__(
        self,
        c: Any,
        dom: Domain,
        ctx: Context | str | None = None,
    ) -> None:
        """Store ``c`` as a converted element of the functional's domain."""
        super().__init__(dom, ctx)
        # Convert against ``self.domain`` (set up by the base initializer),
        # not against the raw ``dom`` argument.
        element = _convert_space_element(self.domain, c)
        self._c = element
        # Membership is validated only at check level "standard" or stricter.
        if self._checks_at_least("standard"):
            self.domain._check_member(element)

    @property
    def representer(self) -> Any:
        """Stored domain element ``c`` defining ``ell_c(x) = <c, x>``."""
        return self._c

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """Return ``domain.inner(representer, x)`` via ``_value_core``."""
        result = self._value_core(x)
        return result

    @checked_method(in_space="domain", in_batched=True)
    def vvalue(self, xs: Any) -> Any:
        """
        Evaluate ``domain.inner(representer, xs[i])`` without a Python loop.

        At check level "standard" or stricter, the result is validated to have
        shape ``(batch,)`` where ``batch`` comes from ``_leading_batch_size``.
        """
        values = self._vvalue_core(xs)
        if not self._checks_at_least("standard"):
            return values
        batch = _leading_batch_size(self.domain, xs)
        _check_scalar_shape(values, (batch,))
        return values

    def __eq__(self, other: Any) -> bool:
        """
        Return whether ``other`` is an equal inner-product functional.

        The comparison is tiered: backend compatibility, then domain
        equality, then an ``allclose`` comparison of flattened representers
        that treats NaNs as equal.
        """
        # Tier 1: backend. Incompatible operands defer to Python's fallback.
        if not self._eq_backend_compatible(other):
            return NotImplemented
        # Tier 2: domain, checked before any array comparison is attempted.
        if self.domain != other.domain:
            return False
        # Tier 3: representer, each flattened by its own domain.
        mine = self.domain.flatten(self._c)
        theirs = other.domain.flatten(other._c)
        return bool(self.ops.allclose(mine, theirs, equal_nan=True))

    def tree_flatten(self):
        """
        Flatten this functional for pytree registration.

        The representer is the only child; domain and context are auxiliary.
        """
        return (self._c,), (self.domain, self.ctx)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild this functional from pytree data.

        NOTE(review): reconstruction goes through ``cls(...)``, so the
        conversion and optional membership check in ``__init__`` run again on
        the child — confirm that is acceptable wherever this is unflattened.
        """
        domain, ctx = aux
        return cls(children[0], domain, ctx)

    def _convert(self, new_ctx: Context) -> InnerProductFunctional:
        """Convert the domain and representer to ``new_ctx``."""
        target_domain = self.domain.convert(new_ctx)
        # The representer itself is re-converted inside the constructor.
        return InnerProductFunctional(self._c, target_domain, new_ctx)
@core_kernels("matrixfree-linear-functional")
@jax_pytree_class
class MatrixFreeLinearFunctional(LinearFunctional[Domain]):
    """
    Linear functional backed by user-supplied evaluation callables.

    ``MatrixFreeLinearFunctional(value, X)`` represents a linear scalar-valued
    map on ``X`` without storing or materializing a Riesz representer.

    Parameters
    ----------
    value : callable
        Callable ``value(x: Any) -> Any`` accepting an element of ``dom`` and
        returning a scalar-like backend value.
    dom : Space
        Domain space of the functional.
    ctx : Context, str, or None, optional
        Optional context specification, forwarded to the base initializer.
    vvalue : callable or None, optional
        Optional callable ``vvalue(xs: Any) -> Any`` for batched evaluation.
        If omitted, batched evaluation is delegated to the inherited
        ``vvalue`` implementation.

    Attributes
    ----------
    value_fn : callable
        The ``value`` callable as supplied.
    vvalue_fn : callable or None
        The ``vvalue`` callable as supplied, or ``None``.
    """

    def __init__(
        self,
        value: Callable[[Any], Any],
        dom: Domain,
        ctx: Context | str | None = None,
        vvalue: Callable[[Any], Any] | None = None,
    ) -> None:
        """
        Initialize a matrix-free linear functional.

        Parameters
        ----------
        value:
            Callable ``value(x)`` accepting an element of ``dom`` and
            returning a scalar-like value.
        dom:
            Domain space of the functional.
        ctx:
            Optional context specification passed to the base initializer.
        vvalue:
            Optional callable ``vvalue(xs)`` accepting a batch of domain
            elements and returning a batch of scalar-like values.

        Raises
        ------
        TypeError
            If ``value`` is not callable, or ``vvalue`` is neither ``None``
            nor callable.
        """
        # Validate both callables before touching any base-class state.
        if not callable(value):
            raise TypeError(f"value must be callable, got {type(value).__name__}.")
        if not (vvalue is None or callable(vvalue)):
            raise TypeError(f"vvalue must be callable, got {type(vvalue).__name__}.")
        super().__init__(dom, ctx)
        self.value_fn = value
        self.vvalue_fn = vvalue

    @property
    def representer(self) -> Any:
        """
        Raise because matrix-free functionals do not store a representer.

        Raises
        ------
        NotImplementedError
            Always.
        """
        class_name = type(self).__name__
        raise NotImplementedError(f"{class_name} does not store a Riesz representer.")

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """
        Evaluate the scalar functional at one domain element.

        Parameters
        ----------
        x:
            Element of ``self.domain``, forwarded to ``_value_core``.

        Returns
        -------
        Any
            Result of ``_value_core(x)``. At check level "standard" or
            stricter it is validated to have scalar shape ``()``.
        """
        y = self._value_core(x)
        if not self._checks_at_least("standard"):
            return y
        _check_scalar_shape(y, ())
        return y

    @checked_method(in_space="domain", in_batched=True)
    def vvalue(self, xs: Any) -> Any:
        """
        Evaluate the scalar functional over a batch of domain elements.

        Parameters
        ----------
        xs:
            Batched element of ``self.domain``.

        Returns
        -------
        Any
            Result of ``vvalue_fn(xs)`` when a batched callable was supplied,
            otherwise the result of the inherited ``vvalue``.
        """
        batched_fn = self.vvalue_fn
        if batched_fn is None:
            return super().vvalue(xs)

        values = batched_fn(xs)
        if not self._checks_at_least("standard"):
            return values

        # Expected output shape: the input shape with the trailing domain axes
        # removed. Inputs without a ``shape`` attribute are treated as ``()``.
        # NOTE(review): InnerProductFunctional.vvalue derives this from
        # ``_leading_batch_size`` instead — confirm the two are meant to differ.
        xs_shape = tuple(getattr(xs, "shape", ()))
        element_rank = len(tuple(self.domain.shape))
        leading = xs_shape[: len(xs_shape) - element_rank]
        _check_scalar_shape(values, leading)
        return values

    def __eq__(self, other: Any) -> bool:
        """Return whether another matrix-free functional uses the same callables."""
        # Tier 1: backend. Incompatible operands defer to Python's fallback.
        if not self._eq_backend_compatible(other):
            return NotImplemented
        # Tier 2: domain.
        if self.domain != other.domain:
            return False
        # Tier 3: callable identity, since extensional equality of callables
        # is undecidable.
        if self.value_fn is not other.value_fn:
            return False
        return self.vvalue_fn is other.vvalue_fn

    def tree_flatten(self):
        """
        Flatten this functional for pytree registration.

        There are no children; callables, domain and context are auxiliary.
        """
        return (), (self.value_fn, self.domain, self.ctx, self.vvalue_fn)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this functional from pytree data; ``children`` is unused."""
        value_fn, domain, ctx, vvalue_fn = aux
        return cls(value_fn, domain, ctx, vvalue_fn)

    def _convert(self, new_ctx: Context) -> MatrixFreeLinearFunctional:
        """
        Convert this functional to ``new_ctx``.

        Parameters
        ----------
        new_ctx:
            Target context for the converted domain.

        Returns
        -------
        MatrixFreeLinearFunctional
            Functional with the converted domain and the same user-supplied
            callables.
        """
        target_domain = self.domain.convert(new_ctx)
        return MatrixFreeLinearFunctional(
            self.value_fn,
            target_domain,
            new_ctx,
            self.vvalue_fn,
        )