Source code for spacecore.functional._composed
from __future__ import annotations
from typing import Any, cast
from ._base import Functional
from ._linear import InnerProductFunctional
from ._quadratic import LinOpQuadraticForm
from .._checks import checked_method
from ..backend import Context, jax_pytree_class
from ..kernels import core_kernels
from ..linop import LinOp
def _require_composable(F: Functional, A: LinOp) -> None:
"""Raise unless ``F`` can be composed with ``A``."""
if not isinstance(F, Functional):
raise TypeError(f"F must be a Functional, got {type(F).__name__}.")
if not isinstance(A, LinOp):
raise TypeError(f"A must be a LinOp, got {type(A).__name__}.")
if A.codomain != F.domain:
raise ValueError(
"Functional composition requires A.codomain == F.domain; "
f"got {A.codomain!r} and {F.domain!r}."
)
[docs]
def make_functional_composed(F: Functional, A: LinOp) -> Functional:
"""
Return the pull-back ``F o A`` with local specializations.
Parameters
----------
F : Functional
Functional defined on ``A.codomain``.
A : LinOp
Linear operator whose codomain is ``F.domain``.
Returns
-------
Functional
Specialized pull-back when available, otherwise
:class:`ComposedFunctional`.
"""
_require_composable(F, A)
if isinstance(F, InnerProductFunctional):
return InnerProductFunctional(A.H.apply(F.representer), A.domain, A.ctx)
if isinstance(F, LinOpQuadraticForm):
Q = A.H @ F.Q @ A
linear = None if F.linear is None else cast(Any, F.linear.compose(A))
return LinOpQuadraticForm(Q, linear, F.a, A.ctx)
return ComposedFunctional(F, A)
[docs]
@core_kernels("composed-functional")
@jax_pytree_class
class ComposedFunctional(Functional):
"""
Generic pull-back of a functional through a linear operator.
``ComposedFunctional(F, A)`` represents ``x -> F(A x)`` on ``A.domain``.
Parameters
----------
F : Functional
Functional defined on ``A.codomain``.
A : LinOp
Linear operator whose codomain is ``F.domain``.
"""
def __init__(self, F: Functional, A: LinOp) -> None:
_require_composable(F, A)
super().__init__(A.domain, A.ctx)
self.F = F.convert(A.ctx)
self.A = A
[docs]
@checked_method(in_space="domain")
def value(self, x: Any) -> Any:
"""
Evaluate ``F(A x)``.
Parameters
----------
x:
Element of ``A.domain``.
Returns
-------
Any
Scalar-like value returned by the composed functional.
"""
return self._value_core(x)
def __eq__(self, other: Any) -> bool:
"""Return whether another composed functional has the same operands."""
if not self._eq_backend_compatible(other): # Tier 1: backend
return NotImplemented
return self.F == other.F and self.A == other.A
def _repr_body(self) -> str:
return f"{self.F._short_repr()} ∘ {self.A._short_repr()}"
def tree_flatten(self):
"""Flatten this functional for pytree registration."""
children = (self.F, self.A)
aux = ()
return children, aux
@classmethod
def tree_unflatten(cls, aux, children):
"""Rebuild this functional from pytree data."""
F, A = children
return cls(F, A)
def _convert(self, new_ctx: Context) -> ComposedFunctional:
"""Convert the composed functional and operator to ``new_ctx``."""
return ComposedFunctional(self.F.convert(new_ctx), self.A.convert(new_ctx))