Source code for spacecore.optimize._optax

"""optax optimizer adapter (ADR-018).

Drives an `optax <https://optax.readthedocs.io>`_ gradient-transformation
optimizer from a SpaceCore :class:`~spacecore.Functional`. optax operates on JAX
pytrees, and SpaceCore elements already *are* pytrees (a dense array is a leaf; a
``TreeSpace`` element is a registered pytree), so the adapter passes the element
representation straight through with no flatten/unflatten marshalling. As with
the SciPy adapters, the gradient handed to optax is the coordinate gradient
``X.riesz(F.grad(x))`` -- optax applies coordinate updates
(``params - lr * grad``), so the metric-to-coordinate conversion is mandatory on
a weighted space and the identity on a Euclidean one.

``optax`` is an optional dependency; it is imported lazily so that
``import spacecore`` (and doctest collection) does not require it.

Information lost at the optax boundary
--------------------------------------
* **Geometry.** optax updates are taken in the flat coordinate metric; the
  domain geometry survives only through the ``X.riesz`` gradient conversion.
* **Representation.** ``x0`` may be a raw array or a tuple/list/dict of arrays;
  a bound ``TreeElement`` is normalized to its raw pytree at entry so that
  ``optax.apply_updates`` matches it against the converted gradient's structure.
* **Field.** Unlike the SciPy adapters, ``minimize_optax`` does not reject a
  complex domain -- optax's updates are complex-capable -- so the correctness of
  complex-valued descent is the caller's responsibility.
* **Tracing.** Eager evaluation is correct but unfused. To ``jax.jit`` the step,
  wrap it yourself; jitted external solves may require a domain context built with
  ``check_level="none"`` (ADR-018, ergonomics run-02 S14).
"""
from __future__ import annotations

from typing import Any, Callable

from ..space import TreeSpace
from ._common import coordinate_gradient, domain_with_geometry, require_functional


[docs] def minimize_optax( F: Any, x0: Any, optimizer: Any, *, steps: int, callback: Callable[[int, Any], None] | None = None, ) -> Any: r""" Run an optax optimizer on a SpaceCore functional with pytree pass-through. Executes the canonical optax update loop for ``steps`` iterations:: grad = X.riesz(F.grad(params)) # metric -> coordinate gradient updates, opt_state = optimizer.update(grad, opt_state, params) params = optax.apply_updates(params, updates) The optax optimizer owns the update rule (learning rate, momentum, preconditioning); this adapter only supplies the coordinate gradient and runs the fixed-length loop. There is no convergence test -- optax is a gradient-transformation library, so iteration count is the caller's contract. Parameters ---------- F : Functional Objective with an inner-product domain ``X = F.domain`` and an implemented ``F.grad``. x0 : pytree Initial parameters, an element of ``F.domain``. A raw array or tuple/list/dict of arrays is used as-is; a bound ``TreeElement`` is normalized to its raw pytree. optimizer : optax.GradientTransformation An optax optimizer such as ``optax.adam(1e-2)``. steps : int Number of update steps to run. Must be non-negative. callback : callable, optional Called as ``callback(step, params)`` after each update, e.g. to record a loss trajectory. Returns ------- pytree The final parameters, an element of ``F.domain``. Raises ------ TypeError If ``F`` is not a :class:`~spacecore.Functional`, its domain has no inner-product geometry, or its domain is not JAX-backed (optax produces JAX arrays). ValueError If ``steps`` is negative. ImportError If ``optax`` is not installed (``pip install spacecore[optax]``). See Also -------- spacecore.optimize.minimize_scipy : SciPy ``minimize`` with the same handoff. Notes ----- On a Euclidean space ``X.riesz`` is the identity. On a weighted space it multiplies the metric gradient by the diagonal weights, which is exactly the coordinate gradient optax's Euclidean update expects (ADR-018). Examples -------- .. code-block:: python import numpy as np import optax import spacecore as sc ctx = sc.Context(sc.JaxOps(), dtype=np.float32, check_level="none") X = sc.DenseCoordinateSpace((2,), ctx) Q = sc.DenseLinOp(ctx.asarray([[3.0, 0.0], [0.0, 1.0]]), X, X, ctx) linear = sc.InnerProductFunctional(ctx.asarray([-3.0, -2.0]), X) F = sc.LinOpQuadraticForm(Q, linear) x_star = sc.minimize_optax(F, X.zeros(), optax.adam(1e-1), steps=500) # x_star is approximately (1.0, 2.0) """ F = require_functional(F, "minimize_optax") X = domain_with_geometry(F, "minimize_optax") if getattr(X.ops, "family", None) != "jax": raise TypeError( "minimize_optax requires a JAX-backed domain (optax produces JAX " f"arrays via optax.apply_updates); F.domain uses the " f"{getattr(X.ops, 'family', '?')!r} backend. Build the domain with a " "JaxOps context, or use minimize_scipy for the NumPy backend." ) steps = int(steps) if steps < 0: raise ValueError(f"steps must be non-negative, got {steps}.") try: import optax except ImportError as exc: # pragma: no cover - exercised only without optax raise ImportError( "minimize_optax requires the optional 'optax' dependency; " "install it with `pip install spacecore[optax]`." ) from exc # Normalize a tree element to the raw pytree representation that # ``X.riesz`` (and thus the gradient) returns, so ``optax.apply_updates`` # sees one consistent pytree structure across params, grad, and updates. A # bound ``TreeElement`` is its own registered pytree and would otherwise # collide with the raw tuple/dict the gradient carries. params = X.unflatten_tree(X.flatten_tree(x0)) if isinstance(X, TreeSpace) else x0 opt_state = optimizer.init(params) for step in range(steps): grad = coordinate_gradient(F, X, params) updates, opt_state = optimizer.update(grad, opt_state, params) params = optax.apply_updates(params, updates) if callback is not None: callback(step, params) return params