# Source code for spacecore.functional.tools._huber
"""The separable Huber loss functional (ADR-019)."""
from __future__ import annotations
import math
from typing import Any
from .._base import Domain
from ...backend import Context, jax_pytree_class
from ..._checks import checked_method
from ._coordinate import _CoordinateFunctional
# [docs]
@jax_pytree_class
class HuberFunctional(_CoordinateFunctional[Domain]):
    r"""
    Separable Huber loss ``F(x) = sum_i h_delta(x_i)``.

    The per-coordinate loss is quadratic near the origin and linear in the tails:
    ``h_delta(r) = 1/2 r^2`` for ``|r| <= delta`` and
    ``delta (|r| - delta/2)`` otherwise. It is everywhere differentiable, with
    gradient ``r`` in the quadratic region and ``delta sign(r)`` in the tails.

    Parameters
    ----------
    dom : Space
        Domain space ``X``.
    delta : float
        Transition threshold; must be finite and ``> 0``.
    ctx : Context, str, or None, optional
        Backend context specification. Default is resolved from ``dom``.

    Raises
    ------
    ValueError
        If ``delta`` (after conversion with ``float``) is not finite or is
        not strictly positive.

    Examples
    --------
    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> f = sc.HuberFunctional(X, 1.0)
    >>> float(f.value(ctx.asarray([0.5, 3.0])))  # 0.125 + 1 * (3 - 0.5)
    2.625
    """

    def __init__(self, dom: Domain, delta: Any, ctx: Context | str | None = None) -> None:
        super().__init__(dom, ctx)
        # Normalise to a plain Python float: ``delta`` is stored as static
        # pytree aux data by ``tree_flatten`` rather than as an array leaf.
        delta = float(delta)
        if not math.isfinite(delta) or delta <= 0.0:
            raise ValueError(f"HuberFunctional requires a finite delta > 0, got {delta}.")
        self.delta = delta

    @checked_method(in_space="domain")
    def value(self, x: Any) -> Any:
        """Return ``sum_i h_delta(x_i)``.

        Evaluated as ``m * (|x_i| - m/2)`` with ``m = min(|x_i|, delta)``.
        This equals ``|x_i|^2 / 2`` in the quadratic region and
        ``delta (|x_i| - delta/2)`` in the tails.

        Unlike computing both branches and selecting with ``where``, this
        never squares a tail residual. Very large residuals therefore cannot
        overflow in a discarded branch, which under NumPy would emit a
        spurious overflow warning.
        """
        o = self.ops
        d = self.delta
        a = o.abs(x)
        # Clamp |x| at delta. On the tail a > d > 0, so sign(a) == 1 and
        # d * sign(a) == d; deriving it from ``a`` keeps the else-branch an
        # array shaped like ``a`` instead of a bare Python scalar.
        m = o.where(a <= d, a, d * o.sign(a))
        return o.sum(m * (a - 0.5 * m))

    def _coordinate_grad(self, x: Any) -> Any:
        """Euclidean coordinate gradient: ``x`` (quadratic) or ``delta sign(x)`` (tail)."""
        o = self.ops
        d = self.delta
        a = o.abs(x)
        return o.where(a <= d, x, d * o.sign(x))

    def tree_flatten(self):
        """Flatten this functional for pytree registration.

        There are no array leaves; the domain, ``delta`` and context are all
        carried as static aux data.
        """
        return (), (self.domain, self.delta, self.ctx)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild this functional from pytree data (re-runs ``__init__`` validation)."""
        domain, delta, ctx = aux
        return cls(domain, delta, ctx)

    def _convert(self, new_ctx: Context) -> "HuberFunctional":
        """Convert this functional to ``new_ctx``.

        Builds the result with ``type(self)``, matching ``tree_unflatten``'s
        use of ``cls``, so a subclass is rebuilt as its own type rather than
        as a plain ``HuberFunctional``.
        """
        return type(self)(self.domain.convert(new_ctx), self.delta, new_ctx)