Source code for spacecore.linalg._cg

from __future__ import annotations

from typing import Any, NamedTuple

from ..linop import LinOp
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter
from ._utils import SpaceCoreOps, is_converged, resolve_apply
from ._utils import require_linop, require_square
from ._utils import (
    require_strict_cg_preconditions,
    result_repr,
    safe_inverse_nonneg,
    should_check_iteration,
    threshold,
)


class CGResult(NamedTuple):
    """
    Outcome of a :func:`cg` solve, bundled as a named tuple.

    Parameters
    ----------
    x : array-like
        Approximate solution, an element of ``A.domain``.
    converged : bool-like
        True-like when the final residual norm met the requested tolerance.
    num_iters : int-like
        How many conjugate-gradient iterations were executed.
    residual_norm : scalar
        Norm of the final residual, measured in ``A.codomain``.
    """

    x: Any
    converged: Any
    num_iters: Any
    residual_norm: Any

    def __repr__(self) -> str:
        """Summarize the result compactly instead of dumping the solution array."""
        # The diagnostics come first and the (potentially large) solution last;
        # the mapping is built in exactly this order before delegating.
        summary_order = ("converged", "num_iters", "residual_norm", "x")
        summary = {name: getattr(self, name) for name in summary_order}
        return result_repr("CGResult", summary)
def cg(
    A: LinOp,
    b: Any,
    *,
    x0: Any | None = None,
    tol: float = 1e-6,
    atol: float = 0.0,
    maxiter: int | None = None,
    check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL,
) -> CGResult:
    r"""
    Run conjugate gradients on the linear system :math:`A x = b`.

    ``A`` has to be square in the SpaceCore sense (``A.domain == A.codomain``)
    and Hermitian positive-definite with respect to ``A.domain.inner``. Only
    :meth:`LinOp.apply` and the space inner product are used; no dense matrix
    is ever formed.

    Parameters
    ----------
    A : LinOp
        Operator to invert. ``A.domain`` must equal ``A.codomain``, including
        the underlying space type and inner-product geometry. If
        ``A.is_hermitian()`` returns ``False`` the operator is provably
        non-self-adjoint in this geometry and is rejected with a
        ``ValueError``. If it returns ``None`` (Hermiticity unknown, e.g. a
        matrix-free operator) the operator is accepted without a check.
        Positive-definiteness is never validated, so an indefinite or otherwise
        unsuitable operator may diverge or yield NaN outputs without an
        explicit error.
    b : array-like
        Right-hand side, an element of ``A.codomain``.
    x0 : array-like or None, optional
        Starting iterate in ``A.domain``. ``None`` (the default) starts from
        the zero vector.
    tol : float, optional
        Relative residual tolerance. ``result.converged`` is ``True`` when the
        residual norm is below ``atol + tol * norm(b)``. Default is 1e-6.
    atol : float, optional
        Absolute residual tolerance. Default is 0.0.
    maxiter : int or None, optional
        Iteration cap. Default is ``prod(A.domain.shape)``.
    check_every : int, optional
        Convergence diagnostics are refreshed every this many iterations and
        always on the final iteration. Default is
        ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``.

    Returns
    -------
    CGResult
        Named tuple with fields:

        - ``x``: approximate solution in ``A.domain``
        - ``converged``: whether the requested tolerance was met
        - ``num_iters``: number of iterations executed
        - ``residual_norm``: final residual norm

    Raises
    ------
    TypeError
        If ``A`` is not a :class:`LinOp`.
    ValueError
        If ``A`` is not square, if ``A.is_hermitian()`` is ``False``, or if
        iteration parameters are invalid.

    See Also
    --------
    lsqr : Solve least-squares systems for rectangular operators.
    lanczos_smallest : Approximate the smallest eigenpair of a Hermitian operator.

    Notes
    -----
    The test against :math:`\text{atol} + \text{tol} \| b \|` is performed only
    every ``check_every`` iterations, plus on the final iteration. That keeps
    convergence checks out of the hot loop while staying compatible with JAX
    JIT control flow. Treat ``maxiter`` and ``check_every`` as static JAX
    arguments.

    The loop also ends once no numerically useful CG update is left: the
    squared residual has reached machine-precision scale, or the curvature
    ``inner(p, A p)`` is nonpositive/tiny relative to the residual scale. The
    residual norm is refreshed before that early exit, so ``converged`` still
    describes the returned iterate.

    For complex operators, residual norms and step sizes come from the real
    part of ``A.domain.inner(x, y)``. SpaceCore's complex inner-product
    convention conjugates the first argument; custom :class:`Space` subclasses
    must follow it for CG to converge correctly.

    Inner products and norms use ``A.domain.inner`` and ``A.domain.norm``. The
    method is correct on non-Euclidean geometries when the space supplies
    Riesz maps and ``A`` is Hermitian positive-definite in that geometry.

    References
    ----------
    Hestenes, M. R. and Stiefel, E., "Methods of Conjugate Gradients for
    Solving Linear Systems," J. Res. Natl. Bur. Stand., 49 (1952), 409-436.

    Examples
    --------
    Solve a small positive-definite system.

    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((3,), ctx)
    >>> M = ctx.asarray([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
    >>> A = sc.DenseLinOp(M, X, X, ctx)
    >>> b = ctx.asarray([1.0, 2.0, 3.0])
    >>> result = sc.cg(A, b, tol=1e-10)
    >>> bool(result.converged)
    True
    >>> np.allclose(A.apply(result.x), b)
    True
    """
    # --- Entry validation (order determines which error surfaces first) ---
    A = require_linop(A)
    require_square(A, "cg")
    hermitian_status = A.is_hermitian()
    # Only a definite ``False`` is rejected; ``None`` (unknown) passes through.
    if hermitian_status is False:
        raise ValueError(
            "cg requires A to be Hermitian/self-adjoint with respect to "
            "A.domain.inner, but A was determined to be provably "
            "non-self-adjoint in this geometry. For example, a symmetric "
            "matrix wrapped as a DenseLinOp on a weighted (non-Euclidean) "
            "inner-product space is not self-adjoint under the weighted inner "
            "product. Build a valid operator instead, e.g. a normal operator "
            "A.H @ A + lam * Identity (which is self-adjoint in any geometry), "
            "or use a Euclidean space where the symmetric matrix is "
            "self-adjoint."
        )
    require_strict_cg_preconditions(A)
    A.codomain.check_member(b)
    maxiter = check_maxiter(maxiter, A)
    check_every = check_interval(check_every)

    # Look up the check-free cores a single time so the loop body avoids
    # per-iteration validation yet still honors any custom space geometry.
    apply_op = resolve_apply(A)
    domain_ops = SpaceCoreOps(A.domain)
    codomain_ops = SpaceCoreOps(A.codomain)

    if x0 is None:
        x_init = A.domain.zeros()
    else:
        x_init = x0
    A.domain.check_member(x_init)

    # Initial residual r0 = b - A x0; the first search direction is r0 itself.
    negated_image = codomain_ops.scale(-1.0, apply_op(x_init))
    r_init = codomain_ops.add(b, negated_image)
    rs_init = domain_ops.real_inner(r_init, r_init)
    norm_init = domain_ops.norm(r_init)
    stop_level = threshold(codomain_ops.norm(b), tol, atol)

    ops = A.ops
    eps = ops.asarray(ops.eps(A.dtype), dtype=A.dtype)
    eps_squared = eps * eps

    def cond_fun(carry: tuple[Any, Any, Any, Any, Any, int, Any]) -> Any:
        """Keep looping while under budget, above threshold, and still active."""
        _, _, _, _, current_norm, step, still_active = carry
        within_budget = step < maxiter
        above_threshold = current_norm > stop_level
        return within_budget & above_threshold & still_active

    def body_fun(
        carry: tuple[Any, Any, Any, Any, Any, int, Any],
    ) -> tuple[Any, Any, Any, Any, Any, int, Any]:
        """Advance the iterate, residual and search direction by one CG step."""
        x_k, r_k, p_k, rs_k, stale_norm, step, _ = carry

        Ap = apply_op(p_k)
        curvature = domain_ops.real_inner(p_k, Ap)
        # A step is only useful while the squared residual is above eps**2 and
        # the curvature is positive relative to the residual scale.
        still_active = (rs_k > eps_squared) & (curvature > eps * rs_k)

        # An inactive step uses alpha = 0, which leaves x and r unchanged.
        step_length = rs_k * safe_inverse_nonneg(ops, curvature)
        alpha = ops.where(still_active, step_length, ops.zeros_like(rs_k))
        x_new = domain_ops.axpy(alpha, p_k, x_k)
        r_new = codomain_ops.axpy(-alpha, Ap, r_k)
        rs_new = domain_ops.real_inner(r_new, r_new)

        direction_ratio = rs_new * safe_inverse_nonneg(ops, rs_k)
        beta = ops.where(still_active, direction_ratio, ops.zeros_like(rs_new))
        p_new = domain_ops.axpy(beta, p_k, r_new)

        next_step = step + 1
        # Refresh the reported norm on scheduled checks and whenever the
        # iteration has just gone inactive (i.e. right before an early exit).
        refresh_norm = should_check_iteration(next_step, maxiter, check_every) | (~still_active)

        def fresh_norm(_operand: Any) -> Any:
            return ops.sqrt(rs_new)

        def keep_norm(_operand: Any) -> Any:
            return stale_norm

        norm_new = ops.cond(
            refresh_norm,
            fresh_norm,
            keep_norm,
            ops.asarray(0.0, dtype=A.dtype),
        )
        return x_new, r_new, p_new, rs_new, norm_new, next_step, still_active

    initial_carry = (x_init, r_init, r_init, rs_init, norm_init, 0, ops.asarray(True))
    x_final, _, _, _, norm_final, num_iters, _ = ops.while_loop(
        cond_fun,
        body_fun,
        initial_carry,
    )
    return CGResult(
        x=x_final,
        converged=is_converged(norm_final, stop_level),
        num_iters=num_iters,
        residual_norm=norm_final,
    )