Source code for spacecore.functional.tools._norms

"""Norm functionals: squared L2 energy and coordinate ``p``-norms (ADR-019)."""
from __future__ import annotations

import math
from typing import Any, cast

from .._base import Domain
from ...backend import Context, jax_pytree_class
from ..._checks import checked_method
from ._coordinate import _CoordinateFunctional, _inner_core, lp_coordinate_grad, lp_value


@jax_pytree_class
class SquaredL2NormFunctional(_CoordinateFunctional[Domain]):
    r"""
    Half the squared space norm ``F(x) = 1/2 ||x||_X^2 = 1/2 <x, x>_X``.

    This is the smooth quadratic energy built from the domain's own inner
    product. Its Riesz gradient is simply ``x``, and its proximal operator is
    the closed-form shrinkage ``v / (1 + t)`` (see
    :func:`~spacecore.functional.prox_l2sq`).

    It is deliberately a different object from ``LpNormFunctional(X, 2)``,
    which is the *un-squared* coordinate 2-norm ``(sum_i |x_i|^2)^{1/2}``.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Examples
    --------
    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> f = sc.SquaredL2NormFunctional(X)
    >>> float(f.value(ctx.asarray([3.0, 4.0])))
    12.5
    >>> np.asarray(f.grad(ctx.asarray([3.0, 4.0])))
    array([3., 4.])
    """

    def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None:
        super().__init__(dom, ctx)

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """Evaluate ``1/2 <x, x>_X`` and return it as a real scalar."""
        # Only the real part is kept, so the energy stays real even if the
        # backend's inner product is complex-typed (assumption: complex spaces
        # are possible -- confirm against ``_inner_core``).
        squared_norm = _inner_core(self.domain, x, x)
        return 0.5 * self.ops.real(squared_norm)

    @checked_method(in_space="domain", out_space="domain")
    def grad(self, x: Any) -> Any:
        """Return the Riesz gradient of ``F``, which is ``x`` itself (no copy is made)."""
        return x

    def _coordinate_grad(self, x: Any) -> Any:
        """Return the Euclidean coordinate gradient ``G x``, computed via the domain's ``riesz`` map."""
        # ``riesz`` is not part of the statically declared ``Domain`` type, hence the cast.
        domain = cast(Any, self.domain)
        return domain.riesz(x)

    def tree_flatten(self):
        """Split this functional into (no dynamic leaves, static aux data) for pytree registration."""
        leaves = ()
        aux = (self.domain, self.ctx)
        return leaves, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Reconstruct a functional from the aux data produced by :meth:`tree_flatten`."""
        dom, ctx = aux
        return cls(dom, ctx)

    def _convert(self, new_ctx: Context) -> "SquaredL2NormFunctional":
        """Return an equivalent functional whose domain has been converted to ``new_ctx``."""
        converted_domain = self.domain.convert(new_ctx)
        return SquaredL2NormFunctional(converted_domain, new_ctx)
@jax_pytree_class
class LpNormFunctional(_CoordinateFunctional[Domain]):
    r"""
    Coordinate ``p``-norm ``F(x) = (sum_i |x_i|^p)^{1/p}`` for finite ``p >= 1``.

    The value is computed directly from the coordinates of ``x``; the
    domain's inner product is not involved in it.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    p : float
        Norm order. It is coerced with :func:`float`, and the result must be
        finite and ``>= 1``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Attributes
    ----------
    p : float
        The validated norm order.

    Raises
    ------
    ValueError
        If ``float(p)`` is NaN, infinite, or smaller than ``1``.

    Notes
    -----
    The gradient at ``x != 0`` is
    ``d/dx_i ||x||_p = sign(x_i) |x_i|^{p-1} / ||x||_p^{p-1}`` (Riesz-corrected).
    At the origin the function is not differentiable; this returns the zero
    subgradient there. For ``p = 1`` the gradient is ``sign(x)``.

    Examples
    --------
    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((3,), ctx)
    >>> f = sc.LpNormFunctional(X, 1)
    >>> float(f.value(ctx.asarray([1.0, -2.0, 3.0])))
    6.0
    """

    def __init__(self, dom: Domain, p: Any, ctx: Context | str | None = None) -> None:
        super().__init__(dom, ctx)
        p = float(p)
        # ``isfinite`` also rejects NaN. The sum formula only defines a norm
        # for finite p >= 1.
        if math.isfinite(p) and p >= 1.0:
            self.p = p
        else:
            raise ValueError(f"LpNormFunctional requires a finite p >= 1, got {p}.")

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """Evaluate the coordinate ``p``-norm ``(sum_i |x_i|^p)^{1/p}`` of ``x``."""
        return lp_value(self.ops, x, self.p)

    def _coordinate_grad(self, x: Any) -> Any:
        """Return the Euclidean coordinate gradient of the ``p``-norm (zero at the origin)."""
        return lp_coordinate_grad(self.ops, x, self.p)

    def tree_flatten(self):
        """Split this functional into (no dynamic leaves, static aux data) for pytree registration."""
        static = (self.domain, self.p, self.ctx)
        return (), static

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Reconstruct a functional from the aux data produced by :meth:`tree_flatten`."""
        dom, order, ctx = aux
        return cls(dom, order, ctx)

    def _convert(self, new_ctx: Context) -> "LpNormFunctional":
        """Return an equivalent functional, with the same ``p``, whose domain is converted to ``new_ctx``."""
        converted_domain = self.domain.convert(new_ctx)
        return LpNormFunctional(converted_domain, self.p, new_ctx)
def L1NormFunctional(
    dom: Domain, ctx: Context | str | None = None
) -> "LpNormFunctional[Domain]":
    r"""
    Build the coordinate 1-norm ``||x||_1 = sum_i |x_i|``.

    This is a convenience factory. It returns a plain :class:`LpNormFunctional`
    whose order ``p`` is ``1.0``; no separate class is involved.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Returns
    -------
    LpNormFunctional
        The ``p = 1`` instance of :class:`LpNormFunctional`.
    """
    unit_order = 1.0
    return LpNormFunctional(dom, unit_order, ctx)