Source code for spacecore.functional._quadratic
from __future__ import annotations
from typing import Any
from ._base import Domain, Functional
from .._batching import (
_check_scalar_shape,
_leading_batch_size,
_warn_vmap_fallback_once,
)
from ._linear import LinearFunctional
from .._checks import checked_method
from .._contextual import resolve_context_priority
from ..backend import Context, jax_pytree_class
from ..kernels import core_kernels
from ..linop import LinOp
[docs]
class QuadraticForm(Functional[Domain]):
    """
    Base class for scalar quadratic objectives defined on a space.

    Subclasses are expected to override ``grad`` and ``hess_apply``; the
    versions defined here only raise ``NotImplementedError``.

    Parameters
    ----------
    dom : Space
        Domain space.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.
    """

    def hess_apply(self, x: Any) -> Any:
        """
        Apply the Hessian action at ``x`` when available.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} does not define hess_apply.")

    def grad(self, x: Any) -> Any:
        """
        Return the gradient with respect to ``domain.inner`` when available.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} does not define grad.")

    @checked_method(
        in_space="domain",
        out_space="domain",
        in_batched=True,
        out_batched=True,
    )
    def vgrad(self, xs: Any) -> Any:
        """
        Evaluate ``grad`` independently over leading batch axes.

        This is the generic fallback: ``self.grad`` is mapped over axis 0 of
        ``xs`` through ``self.ops.vmap``. ``_warn_vmap_fallback_once`` is
        called first with the leading batch size of ``xs``.
        """
        batch_size = _leading_batch_size(self.domain, xs)
        _warn_vmap_fallback_once(self, "vgrad", batch_size)
        mapped_grad = self.ops.vmap(self.grad, in_axes=0, out_axes=0)
        return mapped_grad(xs)
[docs]
@core_kernels("linop-quadratic-form")
@jax_pytree_class
class LinOpQuadraticForm(QuadraticForm[Domain]):
    r"""
    Quadratic form whose second-order term is a linear operator.

    The full objective is ``q(x) = 1/2 * <x, Qx> + linear(x) + a`` with
    ``Q : X -> X``.

    Assumption
        ``Q`` is Hermitian/self-adjoint. Under this assumption,
        ``grad f(x) = Q x``. Non-Hermitian operators are not supported here.
        If users need the Hermitian part, they must construct
        ``0.5 * (Q + Q.H)`` explicitly.

    Structurally available dense and diagonal operators are checked at
    construction. Matrix-free operators are not validated; correctness is
    the caller's responsibility.

    Parameters
    ----------
    Q : LinOp
        Hermitian operator from a space to itself.
    linear : LinearFunctional or None, optional
        Optional linear term on ``Q.domain``.
    a : scalar-like, optional
        Constant scalar offset. Default is 0.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``Q`` and
        ``linear``.

    Attributes
    ----------
    Q : LinOp
        Stored Hermitian operator (after conversion to the resolved context).
    linear : LinearFunctional or None
        Stored linear term (after conversion to the resolved context).
    a : scalar-like
        Stored scalar offset, passed through ``ctx.asarray``.
    """

    def __init__(
        self,
        Q: LinOp[Domain, Domain],
        linear: LinearFunctional[Domain] | None = None,
        a: Any = 0,
        ctx: Context | str | None = None,
    ) -> None:
        """
        Validate the terms, convert them to one context, and store them.

        Raises
        ------
        TypeError
            If ``Q`` is not a ``LinOp`` or ``linear`` is neither ``None`` nor
            a ``LinearFunctional``.
        ValueError
            If ``Q.domain != Q.codomain``, if ``Q`` reports itself as
            non-Hermitian, or if ``linear.domain`` differs from ``Q.domain``.
        """
        # Type checks come first so later attribute access is safe.
        if not isinstance(Q, LinOp):
            raise TypeError(f"Q must be a LinOp, got {type(Q).__name__}.")
        has_linear = linear is not None
        if has_linear and not isinstance(linear, LinearFunctional):
            raise TypeError(
                f"linear must be a LinearFunctional or None, got {type(linear).__name__}."
            )

        # One context is resolved for all terms before any conversion.
        target_ctx = resolve_context_priority(ctx, Q.domain, Q, linear)

        operator = Q.convert(target_ctx)
        if operator.domain != operator.codomain:
            raise ValueError("LinOpQuadraticForm requires Q.domain == Q.codomain.")
        self._check_hermitian_structure(operator)

        linear_term = linear
        if has_linear:
            linear_term = linear.convert(target_ctx)
            if linear_term.domain != operator.domain:
                raise ValueError("linear.domain must match Q.domain.")

        super().__init__(operator.domain, target_ctx)
        self.Q = operator
        self.linear = linear_term
        # The offset must be a true scalar (shape ``()``) in the backend.
        self.a = self.ctx.asarray(a)
        _check_scalar_shape(self.a, ())

    @staticmethod
    def _check_hermitian_structure(Q: LinOp[Domain, Domain]) -> None:
        """
        Raise when ``Q`` is structurally known to be non-Hermitian.

        Only an explicit ``False`` from ``Q.is_hermitian()`` is rejected; any
        other result is accepted without error.
        """
        if Q.is_hermitian() is False:
            raise ValueError("LinOpQuadraticForm requires Q to be Hermitian/self-adjoint.")

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """
        Return ``1/2 * <x, Qx> + linear(x) + a``.

        Delegates to ``self._value_core``, which is not defined in this class
        body (presumably attached by ``core_kernels`` -- confirm).
        """
        return self._value_core(x)

    @checked_method(in_space="domain", out_space="domain")
    def grad(self, x: Any) -> Any:
        """
        Return the gradient with respect to ``domain.inner``.

        This is the Riesz gradient: for Euclidean geometry it is the ordinary
        coordinate gradient, while for non-Euclidean geometry it is corrected
        by the domain inner product.

        ``LinOpQuadraticForm`` assumes ``Q`` is Hermitian/self-adjoint, so the
        quadratic contribution is exactly ``Q.apply(x)``. The computation is
        delegated to ``self._grad_core``.
        """
        return self._grad_core(x)

    @checked_method(in_space="domain", out_space="domain")
    def hess_apply(self, x: Any) -> Any:
        """Return the Hessian action ``Q x`` under the Hermitian assumption."""
        return self.Q.apply(x)

    @checked_method(in_space="domain", in_batched=True)
    def vvalue(self, xs: Any) -> Any:
        """
        Evaluate the quadratic objective over a leading batch axis.

        When checks are enabled at the ``"standard"`` level or above, the
        result is verified to have shape ``(batch_size,)``.
        """
        batch_values = self._vvalue_core(xs)
        if not self._checks_at_least("standard"):
            return batch_values
        expected_shape = (_leading_batch_size(self.domain, xs),)
        _check_scalar_shape(batch_values, expected_shape)
        return batch_values

    @checked_method(
        in_space="domain",
        out_space="domain",
        in_batched=True,
        out_batched=True,
    )
    def vgrad(self, xs: Any) -> Any:
        """
        Evaluate the Riesz gradient over a leading batch axis.

        Delegates to ``self._vgrad_core`` instead of the generic vmap
        fallback of the parent class.
        """
        return self._vgrad_core(xs)

    def __eq__(self, other: Any) -> bool:
        """
        Return whether another quadratic form has the same stored terms.

        Returns ``NotImplemented`` when ``_eq_backend_compatible`` rejects
        ``other``; otherwise compares ``Q``, ``linear`` and the offset ``a``
        (the latter with ``allclose`` and ``equal_nan=True``).
        """
        # NOTE(review): ``__eq__`` is defined without ``__hash__``, so Python
        # marks this class unhashable unless a decorator restores it.
        # Tier 1: backend compatibility gate.
        if not self._eq_backend_compatible(other):
            return NotImplemented
        # Tier 2/3: operator and linear term; ``linear`` may be ``None``.
        if self.Q != other.Q or self.linear != other.linear:
            return False
        # Tier 3: scalar offset, compared numerically.
        offsets_close = self.ops.allclose(self.a, other.a, equal_nan=True)
        return bool(offsets_close)

    def tree_flatten(self):
        """
        Flatten this quadratic form for pytree registration.

        Returns
        -------
        tuple
            ``((Q, linear, a), ())`` -- three children and empty aux data.
        """
        return (self.Q, self.linear, self.a), ()

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild this quadratic form from pytree data.

        ``aux`` is unused. The instance is rebuilt through the regular
        constructor using the operator's own context.

        NOTE(review): this re-runs all ``__init__`` validation and array
        conversion on the children; JAX may unflatten with placeholder
        leaves -- confirm this is safe for the transformations in use.
        """
        operator, linear_term, offset = children
        return cls(operator, linear_term, offset, operator.ctx)

    def _convert(self, new_ctx: Context) -> LinOpQuadraticForm:
        """
        Convert stored terms to ``new_ctx``.

        The offset ``a`` is passed through unchanged; the constructor applies
        ``asarray`` under the new context.
        """
        if self.linear is None:
            converted_linear = None
        else:
            converted_linear = self.linear.convert(new_ctx)
        converted_op = self.Q.convert(new_ctx)
        return LinOpQuadraticForm(converted_op, converted_linear, self.a, new_ctx)