from __future__ import annotations
from typing import Any, NamedTuple
from ..linop import LinOp
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, SpaceCoreOps, check_interval
from ._utils import check_maxiter, is_converged, require_linop, resolve_apply, resolve_rapply
from ._utils import safe_inverse_nonneg, should_check_iteration, result_repr, threshold
[docs]
class LSQRResult(NamedTuple):
    """
    Result record produced by :func:`lsqr`.

    Parameters
    ----------
    x : array-like
        Approximate least-squares solution, an element of ``A.domain``.
    converged : bool-like
        ``True`` when the normal-equation residual met the requested
        tolerance.
    num_iters : int-like
        How many LSQR iterations were run.
    residual_norm : scalar
        In exact mode, the norm of ``A x - b`` measured in ``A.codomain``;
        in recurrence mode, the LSQR recurrence estimate of that norm.
    normal_residual_norm : scalar
        In exact mode, the norm of ``A.H @ (A x - b)`` measured in
        ``A.domain``; in recurrence mode, the LSQR recurrence estimate.
    """

    x: Any
    converged: Any
    num_iters: Any
    residual_norm: Any
    normal_residual_norm: Any

    def __repr__(self) -> str:
        """Return the summary string that ``result_repr`` builds for this result."""
        # Display order differs from field order on purpose: the diagnostic
        # fields come first and the solution ``x`` is listed last.
        display_order = (
            "converged",
            "num_iters",
            "residual_norm",
            "normal_residual_norm",
            "x",
        )
        summary = {field: getattr(self, field) for field in display_order}
        return result_repr("LSQRResult", summary)
[docs]
def lsqr(
    A: LinOp,
    b: Any,
    *,
    x0: Any | None = None,
    tol: float = 1e-6,
    atol: float = 0.0,
    maxiter: int | None = None,
    check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL,
    residual_mode: str = "exact",
) -> LSQRResult:
    r"""
    Run LSQR to solve :math:`\min_x \|A x - b\|`.

    ``A`` may map between different ``domain`` and ``codomain`` spaces.
    Forward products go through :meth:`LinOp.apply` and adjoint products
    through ``A.H.apply``; the normal equations are only ever represented
    implicitly and no dense matrix is assembled.

    Parameters
    ----------
    A : LinOp
        Linear operator whose ``domain`` and ``codomain`` may differ.
        When ``A`` is square (``A.domain == A.codomain``) and also Hermitian
        positive-definite, :func:`cg` is usually the better choice.
    b : array-like
        Right-hand side, an element of ``A.codomain``.
    x0 : array-like or None, optional
        Starting point in ``A.domain``. ``None`` (the default) starts from
        the zero vector.
    tol : float, optional
        Relative tolerance on the normal-equation residual
        ``norm(A.H @ (A @ x - b))``. ``result.converged`` is ``True`` when
        that residual is below ``atol + tol * norm(b)``. Default is 1e-6.
    atol : float, optional
        Absolute tolerance on the normal-equation residual. Default is 0.0.
    maxiter : int or None, optional
        Iteration cap. Default is ``prod(A.domain.shape)``.
    check_every : int, optional
        Residual diagnostics are refreshed every ``check_every`` iterations
        and always on the final iteration. Default is
        ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``.
    residual_mode : {"exact", "recurrence"}, optional
        How residual diagnostics are obtained.

        ``"exact"`` keeps the historical behavior: each refresh recomputes
        ``A @ x - b`` and ``A.H @ (A @ x - b)``. That is one extra forward
        and one extra adjoint application per check iteration, so small
        ``check_every`` values (``check_every=1`` in particular) can make
        expensive operators noticeably slower. Prefer larger intervals such
        as ``check_every=10`` or ``check_every=20`` when exact diagnostics
        are required for matrix-free, PDE, neural-network, GPU, or JAX
        workloads.

        ``"recurrence"`` takes both returned diagnostics from the LSQR
        scalar recurrences and skips those extra applications.

    Returns
    -------
    LSQRResult
        Named tuple with fields:

        - ``x``: approximate least-squares solution in ``A.domain``
        - ``converged``: whether the requested tolerance was met
        - ``num_iters``: number of iterations executed
        - ``residual_norm``: final residual norm or recurrence estimate
        - ``normal_residual_norm``: final normal-equation residual norm or
          recurrence estimate

    Raises
    ------
    TypeError
        If ``A`` is not a :class:`LinOp`.
    ValueError
        If iteration parameters are invalid or ``residual_mode`` is unknown.

    See Also
    --------
    cg : Solve square Hermitian positive-definite systems.
    power_iteration : Estimate a dominant eigenpair.

    Notes
    -----
    The loop keeps iterating while the iteration count is below ``maxiter``
    and the most recently refreshed normal-equation residual diagnostic is
    greater than :math:`\text{atol} + \text{tol}\|b\|`; that diagnostic is
    :math:`\|A^*(A x - b)\|` or its recurrence estimate.

    With ``residual_mode="exact"`` the exact diagnostics are refreshed only
    every ``check_every`` iterations and on the final iteration; each refresh
    costs one forward and one adjoint product on top of the LSQR recurrence.

    With ``residual_mode="recurrence"``, ``residual_norm`` is the standard
    LSQR estimate ``abs(phi_bar)`` and ``normal_residual_norm`` is the LSQR
    scalar estimate ``alpha * abs(tau)`` with ``tau = s * phi``. No extra
    operator applications are performed for checks, the final one included.
    The function is JIT-compatible on the JAX backend when ``maxiter``,
    ``check_every``, and ``residual_mode`` are static arguments.

    For ill-conditioned ``A`` the normal-equation residual can be much
    smaller than the solution error; use a tighter ``tol`` or inspect the
    residual and solution quality directly in that case.

    Real and complex operators are both supported. For complex operators,
    ``A.H`` is the conjugate adjoint.

    Domain-space quantities are measured with ``A.domain.inner`` /
    ``A.domain.norm`` and least-squares residuals with ``A.codomain.norm``,
    so the method remains correct on non-Euclidean geometries provided the
    spaces supply Riesz maps and ``A.rapply`` is the true metric adjoint.

    References
    ----------
    Paige, C. C. and Saunders, M. A., "LSQR: An Algorithm for Sparse
    Linear Equations and Sparse Least Squares," ACM Trans. Math. Soft.,
    8 (1982), 43-71.

    Examples
    --------
    Solve a small overdetermined least-squares problem.

    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> Y = sc.DenseCoordinateSpace((3,), ctx)
    >>> M = ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    >>> A = sc.DenseLinOp(M, X, Y, ctx)
    >>> b = ctx.asarray([1.0, 2.0, 3.0])
    >>> result = sc.lsqr(A, b, tol=1e-10)
    >>> np.allclose(result.x, [1.0, 2.0])
    True
    """
    # --- argument validation (order determines which error surfaces first) ---
    A = require_linop(A)
    A.codomain.check_member(b)
    maxiter = check_maxiter(maxiter, A)
    check_every = check_interval(check_every)
    if residual_mode not in {"exact", "recurrence"}:
        raise ValueError("residual_mode must be 'exact' or 'recurrence'.")
    exact_mode = residual_mode == "exact"

    # Look the operator/space kernels up once so the loop body does not pay
    # for per-iteration validation; custom space geometry is still honored.
    forward = resolve_apply(A)
    adjoint = resolve_rapply(A)
    domain_ops = SpaceCoreOps(A.domain)
    codomain_ops = SpaceCoreOps(A.codomain)

    if x0 is None:
        x = A.domain.zeros()
    else:
        x = x0
    A.domain.check_member(x)

    # --- start-up: beta * u = b - A x,  alpha * v = A^H u ---
    initial_residual = codomain_ops.add(b, codomain_ops.scale(-1.0, forward(x)))
    beta = codomain_ops.norm(initial_residual)
    backend = A.ops
    u = codomain_ops.scale(safe_inverse_nonneg(backend, beta), initial_residual)
    v = adjoint(u)
    alpha = domain_ops.norm(v)
    if exact_mode:
        normal_residual_norm = domain_ops.norm(adjoint(initial_residual))
    else:
        # ||A^H r0|| = beta * ||A^H u|| because r0 = beta * u.
        normal_residual_norm = beta * alpha
    v = domain_ops.scale(safe_inverse_nonneg(backend, alpha), v)

    stop_level = threshold(A.codomain.norm(b), tol, atol)

    def refresh_exact(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]:
        # Recompute A x - b and A^H (A x - b) from the candidate iterate.
        candidate = payload[0]
        misfit = codomain_ops.add(forward(candidate), codomain_ops.scale(-1.0, b))
        return codomain_ops.norm(misfit), domain_ops.norm(adjoint(misfit))

    def hold_previous(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]:
        # Not a check iteration: carry the last diagnostics forward unchanged.
        return payload[1], payload[2]

    def keep_going(state: tuple[Any, ...]) -> Any:
        # Only the last two carry entries matter here: the normal-equation
        # residual diagnostic and the iteration counter.
        normal_diag, step = state[-2], state[-1]
        return (step < maxiter) & (normal_diag > stop_level)

    def advance(state: tuple[Any, ...]) -> tuple[Any, ...]:
        (
            x_k,
            u_k,
            v_k,
            w_k,
            alpha_k,
            _beta_k,
            rho_bar_k,
            phi_bar_k,
            res_k,
            normal_k,
            step,
        ) = state

        # Bidiagonalization step: next left/right vectors and their norms.
        u_new = codomain_ops.axpy(-alpha_k, u_k, forward(v_k))
        beta_new = codomain_ops.norm(u_new)
        u_new = codomain_ops.scale(safe_inverse_nonneg(backend, beta_new), u_new)
        v_new = domain_ops.axpy(-beta_new, v_k, adjoint(u_new))
        alpha_new = domain_ops.norm(v_new)
        v_new = domain_ops.scale(safe_inverse_nonneg(backend, alpha_new), v_new)

        # Plane rotation that eliminates beta_new.
        rho = backend.sqrt(rho_bar_k * rho_bar_k + beta_new * beta_new)
        rho_inv = safe_inverse_nonneg(backend, rho)
        cos_k = rho_bar_k * rho_inv
        sin_k = beta_new * rho_inv
        theta = sin_k * alpha_new
        rho_bar_new = -cos_k * alpha_new
        phi = cos_k * phi_bar_k
        phi_bar_new = sin_k * phi_bar_k

        # Scalar-recurrence estimates of the two residual norms.
        estimated_residual = backend.abs(phi_bar_new)
        estimated_normal_residual = alpha_new * backend.abs(sin_k * phi)

        # Update the iterate and the search direction.
        x_new = domain_ops.axpy(phi * rho_inv, w_k, x_k)
        w_new = domain_ops.axpy(-(theta * rho_inv), w_k, v_new)
        step_new = step + 1

        if exact_mode:
            refresh = refresh_exact
        else:

            def refresh(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]:
                # Recurrence mode: the estimates are already in hand, so a
                # refresh needs no further operator applications.
                return estimated_residual, estimated_normal_residual

        res_new, normal_new = backend.cond(
            should_check_iteration(step_new, maxiter, check_every),
            refresh,
            hold_previous,
            (x_new, res_k, normal_k),
        )
        return (
            x_new,
            u_new,
            v_new,
            w_new,
            alpha_new,
            beta_new,
            rho_bar_new,
            phi_bar_new,
            res_new,
            normal_new,
            step_new,
        )

    # Carry layout: (x, u, v, w, alpha, beta, rho_bar, phi_bar,
    #                residual_norm, normal_residual_norm, k)
    # with w = v, rho_bar = alpha, and phi_bar = residual_norm = beta at start.
    start_state = (x, u, v, v, alpha, beta, alpha, beta, beta, normal_residual_norm, 0)
    x, *_unused, residual_norm, normal_residual_norm, num_iters = backend.while_loop(
        keep_going,
        advance,
        start_state,
    )

    return LSQRResult(
        x=x,
        converged=is_converged(normal_residual_norm, stop_level),
        num_iters=num_iters,
        residual_norm=residual_norm,
        normal_residual_norm=normal_residual_norm,
    )