Source code for spacecore.functional.tools._entropy

"""Entropy objectives: negative entropy and KL divergence (ADR-019)."""
from __future__ import annotations

from typing import Any, cast

from .._base import Domain
from .._linear import _convert_space_element
from ...backend import Context, jax_pytree_class
from ..._checks import checked_method
from ._coordinate import _CoordinateFunctional


[docs] @jax_pytree_class class NegativeEntropyFunctional(_CoordinateFunctional[Domain]): r""" Negative (Shannon) entropy ``F(x) = sum_i x_i log x_i``. The natural domain is the positive orthant ``x_i > 0``; the value uses the convention ``0 log 0 = 0`` so the origin and zero coordinates evaluate cleanly, but the gradient ``log x_i + 1`` is only defined for ``x_i > 0``. Parameters ---------- dom : Space Domain space ``X``. ctx : Context, str, or None, optional Backend context specification. Default is resolved from ``dom``. Examples -------- >>> import numpy as np >>> import spacecore as sc >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) >>> X = sc.DenseCoordinateSpace((2,), ctx) >>> f = sc.NegativeEntropyFunctional(X) >>> float(f.value(ctx.asarray([1.0, 1.0]))) 0.0 >>> np.asarray(f.grad(ctx.asarray([1.0, 1.0]))) # log(x) + 1 array([1., 1.]) """ def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: super().__init__(dom, ctx)
[docs] @checked_method(in_space="domain") def value(self, x: Any) -> Any: """Return ``sum_i x_i log x_i`` with ``0 log 0 = 0``.""" o = self.ops positive = x > 0 safe = o.where(positive, x, cast(Any, 1.0)) return o.sum(o.where(positive, x * o.log(safe), cast(Any, 0.0)))
def _coordinate_grad(self, x: Any) -> Any: """Euclidean coordinate gradient ``log x_i + 1`` (defined for ``x_i > 0``).""" return self.ops.log(x) + 1.0 def tree_flatten(self): """Flatten this functional for pytree registration.""" return (), (self.domain, self.ctx) @classmethod def tree_unflatten(cls, aux, children): """Rebuild this functional from pytree data.""" domain, ctx = aux return cls(domain, ctx) def _convert(self, new_ctx: Context) -> "NegativeEntropyFunctional": """Convert this functional to ``new_ctx``.""" return NegativeEntropyFunctional(self.domain.convert(new_ctx), new_ctx)
[docs] @jax_pytree_class class KLDivergenceFunctional(_CoordinateFunctional[Domain]): r""" Kullback--Leibler divergence to a fixed positive ``target``. ``F(x) = sum_i x_i log(x_i / t_i)`` against a strictly positive target ``t``. The natural domain is ``x_i >= 0`` (with ``0 log 0 = 0``) and ``t_i > 0``. With ``target`` equal to the all-ones element this reduces exactly to :class:`NegativeEntropyFunctional`, and the gradient ``log(x_i / t_i) + 1`` reduces accordingly. Parameters ---------- target : array-like Reference element ``t`` in ``dom`` with strictly positive coordinates. dom : Space Domain space ``X``. ADR-019 writes ``KLDivergenceFunctional(target)``; the explicit space follows the ``(data, dom, ctx)`` constructor shape used by :class:`~spacecore.functional.InnerProductFunctional`, because a bare element does not carry its space. ctx : Context, str, or None, optional Backend context specification. Default is resolved from ``dom``. Examples -------- >>> import numpy as np >>> import spacecore as sc >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) >>> X = sc.DenseCoordinateSpace((2,), ctx) >>> f = sc.KLDivergenceFunctional(ctx.asarray([1.0, 1.0]), X) >>> float(f.value(ctx.asarray([1.0, 1.0]))) # zero divergence at x == target 0.0 >>> np.asarray(f.grad(ctx.asarray([1.0, 1.0]))) # log(x / t) + 1 array([1., 1.]) """ def __init__( self, target: Any, dom: Domain, ctx: Context | str | None = None, ) -> None: super().__init__(dom, ctx) self._target = _convert_space_element(self.domain, target) if self._checks_at_least("standard"): self.domain._check_member(self._target) @property def target(self) -> Any: """Stored reference element ``t``.""" return self._target
[docs] @checked_method(in_space="domain") def value(self, x: Any) -> Any: """Return ``sum_i x_i log(x_i / t_i)`` with ``0 log 0 = 0``.""" o = self.ops positive = x > 0 ratio = o.where(positive, x / self._target, cast(Any, 1.0)) return o.sum(o.where(positive, x * o.log(ratio), cast(Any, 0.0)))
def _coordinate_grad(self, x: Any) -> Any: """Euclidean coordinate gradient ``log(x_i / t_i) + 1`` (for ``x_i > 0``).""" return self.ops.log(x / self._target) + 1.0 def tree_flatten(self): """Flatten this functional for pytree registration.""" return (self._target,), (self.domain, self.ctx) @classmethod def tree_unflatten(cls, aux, children): """Rebuild this functional from pytree data.""" domain, ctx = aux return cls(children[0], domain, ctx) def _convert(self, new_ctx: Context) -> "KLDivergenceFunctional": """Convert this functional to ``new_ctx``.""" return KLDivergenceFunctional(self._target, self.domain.convert(new_ctx), new_ctx)