# Source code for spacecore.linalg._lanczos

from __future__ import annotations

from typing import Any, NamedTuple


from ..linop import LinOp
from ..space import DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace
from ..types import DenseArray
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, SpaceCoreOps, check_interval
from ._utils import require_linop, require_square, resolve_apply, safe_inverse_nonneg
from ._utils import result_repr, should_check_iteration


class LanczosResult(NamedTuple):
    """
    Result bundle produced by :func:`lanczos_smallest`.

    Parameters
    ----------
    eigenvalue : scalar
        Ritz approximation to the smallest eigenvalue.
    eigenvector : array-like
        Ritz vector, given as an element of ``A.domain``.
    residual_norm : scalar
        Standard Ritz residual estimate.
    krylov_dim : int-like
        Krylov dimension reached before breakdown or ``max_iter``.
    converged : bool-like
        Whether ``residual_norm < tol``.
    """

    eigenvalue: Any
    eigenvector: Any
    residual_norm: Any
    krylov_dim: Any
    converged: Any

    def __repr__(self) -> str:
        """Summarize the result compactly instead of dumping the eigenvector."""
        # The tuple's field order is also the order shown in the summary.
        summary = dict(zip(self._fields, self))
        return result_repr("LanczosResult", summary)
class _LanczosBasisResult(NamedTuple):
    """Bundle the fixed-size Lanczos basis with its tridiagonal projection."""

    # Basis rows in flattened coordinates, allocated as (max_iter + 1, n).
    V: DenseArray
    # Dense (max_iter, max_iter) matrix produced by ``_build_tridiagonal``.
    T: DenseArray
    # Diagonal recurrence coefficients, length max_iter.
    alphas: DenseArray
    # Off-diagonal recurrence coefficients, length max_iter + 1; entry 0 stays zero.
    betas: DenseArray
    # Number of Lanczos steps treated as valid.
    krylov_dim: Any
    # Norm of the caller's starting vector as measured by ``A.domain.norm``.
    initial_norm: Any
    # The tolerance converted to an array of the real dtype.
    tol: Any
    # Normalized first coordinate vector, kept as a fallback direction.
    e0_unit: DenseArray


def _check_lanczos_max_iter(max_iter: int) -> int:
    """
    Coerce ``max_iter`` to ``int`` and reject values below one.

    Raises
    ------
    ValueError
        If the coerced value is smaller than 1.
    """
    count = int(max_iter)
    if count >= 1:
        return count
    raise ValueError("max_iter must be positive.")


def _build_tridiagonal(
    ops: Any,
    alphas: DenseArray,
    betas: DenseArray,
    max_iter: int,
    m: Any,
    real_dtype: Any,
) -> DenseArray:
    """
    Assemble the padded ``(max_iter, max_iter)`` tridiagonal Lanczos matrix.

    Parameters
    ----------
    ops : backend ops object
        Provides ``arange``, ``where``, ``zeros``, ``index_set`` and ``fori_loop``.
    alphas : DenseArray
        Diagonal coefficients, length ``max_iter``.
    betas : DenseArray
        Off-diagonal coefficients, length ``max_iter + 1``.
    max_iter : int
        Static size of the matrix.
    m : int-like
        Number of valid Lanczos steps; rows/columns from ``m`` on are padding.
    real_dtype : dtype
        Real dtype of the resulting matrix.

    Returns
    -------
    DenseArray
        Symmetric tridiagonal matrix whose leading ``m x m`` block carries the
        recurrence coefficients.
    """
    diag_ids = ops.arange(max_iter)
    beta_ids = ops.arange(max_iter + 1)

    # Padding value for diagonal slots at or beyond ``m``.  It exceeds every
    # |alpha| by at least 1 + 2 * max|beta|, presumably so the padded block
    # never supplies the smallest eigenvalue of the assembled matrix.
    alpha_bound = ops.max(ops.abs(alphas))
    beta_bound = ops.max(ops.abs(betas))
    padding_value = alpha_bound + 2.0 * beta_bound + ops.asarray(1.0, dtype=real_dtype)

    diagonal = ops.where(diag_ids < m, alphas, padding_value)
    # Zero the coupling at index ``m`` so the active block is decoupled from padding.
    off_diagonal = ops.where(beta_ids == m, ops.asarray(0.0, dtype=real_dtype), betas)

    def write_diagonal(k: int, mat: DenseArray) -> DenseArray:
        return ops.index_set(mat, (k, k), diagonal[k], copy=True)

    def write_off_diagonal(k: int, mat: DenseArray) -> DenseArray:
        # betas are stored shifted by one: betas[k + 1] couples rows k and k + 1.
        coupling = off_diagonal[k + 1]
        mat = ops.index_set(mat, (k, k + 1), coupling, copy=True)
        return ops.index_set(mat, (k + 1, k), coupling, copy=True)

    tri = ops.zeros((max_iter, max_iter), dtype=real_dtype)
    tri = ops.fori_loop(0, max_iter, write_diagonal, tri)
    return ops.fori_loop(0, max_iter - 1, write_off_diagonal, tri)


def _lanczos_basis_and_tridiag(
    A: LinOp,
    initial_vector: Any,
    max_iter: int,
    tol: float,
    real_dtype: Any,
    check_every: int,
) -> _LanczosBasisResult:
    """
    Run the Lanczos recurrence and assemble the padded tridiagonal projection.

    All arrays have sizes fixed by ``max_iter`` and the iteration is expressed
    through ``ops.while_loop`` / ``ops.cond`` / ``ops.fori_loop``.

    Parameters
    ----------
    A : LinOp
        Operator whose domain supplies flatten/unflatten, norm and scale.
    initial_vector : element of ``A.domain``
        Starting vector; replaced by a unit coordinate vector when its norm is
        not above ``1e-12``.
    max_iter : int
        Maximum number of Lanczos steps (static).
    tol : float
        Threshold on the new off-diagonal coefficient that marks breakdown.
    real_dtype : dtype
        Real dtype used for the recurrence coefficients.
    check_every : int
        Interval passed to ``should_check_iteration`` to decide when the
        stopping flag is refreshed.

    Returns
    -------
    _LanczosBasisResult
        Basis rows, tridiagonal matrix, coefficients, valid Krylov dimension,
        norm of the starting vector, tolerance array and fallback unit vector.
    """
    ops = A.ops
    ctx = A.ctx
    apply = resolve_apply(A)
    domain = A.domain
    dom = SpaceCoreOps(domain)

    # The vectorized projection is used only for these exact space types (no
    # subclasses) and only when the space reports a Euclidean geometry.
    euclidean_types = (DenseCoordinateSpace, DenseVectorSpace, ElementwiseJordanSpace)
    use_euclidean_reorth = type(domain) in euclidean_types and domain.is_euclidean

    start_flat = ctx.assert_dense(domain.flatten(initial_vector))
    n = start_flat.shape[0]

    basis0 = ops.zeros((max_iter + 1, n), dtype=ctx.dtype)
    alphas0 = ops.zeros((max_iter,), dtype=real_dtype)
    betas0 = ops.zeros((max_iter + 1,), dtype=real_dtype)
    tol_s = ops.asarray(tol, dtype=real_dtype)
    eps_s = ops.asarray(1e-12, dtype=real_dtype)
    start_norm = domain.norm(initial_vector)

    # Deterministic fallback direction: first coordinate vector, normalized in
    # the domain's own norm.
    e0 = ops.index_set(ops.zeros((n,), dtype=ctx.dtype), (0,), ctx.asarray(1.0), copy=True)
    e0_member = domain.unflatten(e0)
    e0_norm = domain.norm(e0_member)
    e0_unit = domain.flatten(domain.scale(safe_inverse_nonneg(ops, e0_norm), e0_member))

    def normalized_start(_: Any) -> Any:
        return domain.flatten(
            domain.scale(safe_inverse_nonneg(ops, start_norm), initial_vector)
        )

    def fallback_start(_: Any) -> Any:
        return e0_unit

    first_row = ops.cond(
        start_norm > eps_s,
        normalized_start,
        fallback_start,
        ops.asarray(0.0, dtype=real_dtype),
    )
    basis0 = ops.index_set(basis0, (0, slice(None)), first_row, copy=True)

    row_ids = ops.arange(max_iter + 1)
    zero_coeffs = ops.zeros((max_iter + 1,), dtype=ctx.dtype)

    # Loop state layout:
    #   (step, basis rows, alphas, betas, last beta, first breakdown step, keep-going flag)
    def cond_fun(state: tuple[Any, Any, Any, Any, Any, Any, Any]) -> Any:
        step, *_unused, keep_going = state
        return (step < max_iter) & keep_going

    def body_fun(
        state: tuple[Any, Any, Any, Any, Any, Any, Any],
    ) -> tuple[Any, Any, Any, Any, Any, Any, Any]:
        step, rows, alpha_arr, beta_arr, _prev_beta, first_breakdown, keep_going = state

        current = rows[step]
        current_member = domain.unflatten(current)
        image_member = apply(current_member)
        resid = ctx.assert_dense(A.codomain.flatten(image_member))

        alpha = dom.real_inner(current_member, image_member)
        alpha_arr = ops.index_set(alpha_arr, (step,), alpha, copy=True)

        # Three-term recurrence; the very first step has no previous basis row.
        def first_step(w_in: Any) -> Any:
            return w_in - alpha * current

        def later_step(w_in: Any) -> Any:
            return w_in - alpha * current - beta_arr[step] * rows[step - 1]

        resid = ops.cond(step == 0, first_step, later_step, resid)
        resid_member = domain.unflatten(resid)

        # Reorthogonalize against the rows written so far (indices 0..step);
        # the remaining rows are excluded through a 0/1 mask.
        active = ops.where(
            row_ids < (step + 1),
            ops.asarray(1.0, dtype=real_dtype),
            ops.asarray(0.0, dtype=real_dtype),
        )
        active = ops.astype(active, ctx.dtype)

        if use_euclidean_reorth:
            coeffs = ops.einsum("jn,n->j", ops.conj(rows), resid)
        else:

            def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray:
                row_member = domain.unflatten(rows[j])
                coeff = dom.inner(row_member, resid_member)
                return ops.index_set(coeffs_in, (j,), coeff, copy=True)

            coeffs = ops.fori_loop(0, max_iter + 1, fill_coeff, zero_coeffs)

        masked_coeffs = coeffs * active
        projection = ops.sum(masked_coeffs[:, None] * rows, axis=0)
        resid = resid - projection
        resid_member = domain.unflatten(resid)

        beta_new = dom.norm(resid_member)
        beta_arr = ops.index_set(beta_arr, (step + 1,), beta_new, copy=True)

        next_step = step + 1

        # Record only the first breakdown: ``max_iter`` doubles as the
        # "no breakdown seen yet" marker.
        breakdown = beta_new < tol_s
        first_breakdown_next = ops.where(
            breakdown & (first_breakdown == max_iter), next_step, first_breakdown
        )

        def append_row(rows_in: DenseArray) -> DenseArray:
            unit = domain.flatten(dom.scale(safe_inverse_nonneg(ops, beta_new), resid_member))
            return ops.index_set(rows_in, (next_step, slice(None)), unit, copy=True)

        def keep_rows(rows_in: DenseArray) -> DenseArray:
            return rows_in

        rows = ops.cond(beta_new >= tol_s, append_row, keep_rows, rows)

        # The stopping flag is refreshed only at check points; otherwise the
        # previous flag is carried forward unchanged.
        def refreshed_flag(_: Any) -> Any:
            return beta_new >= tol_s

        def carried_flag(_: Any) -> Any:
            return keep_going

        keep_going_next = ops.cond(
            should_check_iteration(next_step, max_iter, check_every),
            refreshed_flag,
            carried_flag,
            ops.asarray(0.0, dtype=real_dtype),
        )
        return (
            next_step,
            rows,
            alpha_arr,
            beta_arr,
            beta_new,
            first_breakdown_next,
            keep_going_next,
        )

    init_state = (
        0,
        basis0,
        alphas0,
        betas0,
        ops.asarray(1.0, dtype=real_dtype),
        ops.asarray(max_iter),
        ops.asarray(True),
    )
    final_state = ops.while_loop(cond_fun, body_fun, init_state)
    steps_done, basis_rows, alphas, betas, _last_beta, first_breakdown, _flag = final_state

    # Valid dimension: the earlier of the first breakdown and the steps actually run.
    krylov_dim = ops.minimum(first_breakdown, ops.asarray(steps_done))
    T = _build_tridiagonal(ops, alphas, betas, max_iter, krylov_dim, real_dtype)
    return _LanczosBasisResult(
        basis_rows, T, alphas, betas, krylov_dim, start_norm, tol_s, e0_unit
    )
def lanczos_smallest(
    A: LinOp,
    initial_vector: Any,
    *,
    max_iter: int = 100,
    tol: float = 1e-6,
    check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL,
) -> LanczosResult:
    r"""
    Estimate the smallest eigenpair of a Hermitian operator with Lanczos.

    ``A`` must be a square ``LinOp`` in the SpaceCore sense
    (``A.domain == A.codomain``) and ``initial_vector`` an element of
    ``A.domain``. Coordinate arrays have sizes fixed by ``max_iter`` for JAX
    compatibility, a numerically zero starting vector is replaced by a
    deterministic coordinate vector, and the reported eigenvalue is the
    Rayleigh quotient of the reconstructed Ritz vector in the original space.

    Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for
    ``span{v, A v, A^2 v, ...}`` together with the tridiagonal projection
    :math:`T_k = V^* A V`. The returned vector is the Ritz vector expressed in
    the original coordinates, and the returned scalar is
    :math:`\langle x, A x \rangle_X / \langle x, x \rangle_X`.

    Parameters
    ----------
    A : LinOp
        Linear operator that must be Hermitian/self-adjoint with respect to
        ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, including
        the underlying space type and inner-product geometry. When
        ``A.is_hermitian()`` returns ``None`` (structure unknown) the operator
        is accepted on trust and the caller is responsible for Hermiticity;
        non-Hermitian inputs give undefined results.
    initial_vector : array-like
        Starting vector in ``A.domain``. A numerically zero vector triggers
        the fallback to a deterministic coordinate vector.
    max_iter : int, optional
        Maximum Krylov dimension. Must be a Python ``int`` rather than a
        traced JAX scalar; under ``jax.jit`` it is a static argument and
        changing it triggers retracing. Default is 100.
    tol : float, optional
        Tolerance with two roles: iteration stops at a check point once the
        off-diagonal Lanczos coefficient drops below ``tol``, and
        ``converged`` is ``True`` when the Ritz residual estimate is below
        ``tol``. Default is 1e-6.
    check_every : int, optional
        The breakdown-based stopping decision is refreshed every this many
        iterations and always on the final iteration. Default is
        ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``.

    Returns
    -------
    LanczosResult
        Named tuple with fields:

        - ``eigenvalue``: smallest Ritz eigenvalue estimate
        - ``eigenvector``: associated Ritz vector in ``A.domain``
        - ``residual_norm``: standard Ritz residual estimate
        - ``krylov_dim``: actual Krylov dimension reached
        - ``converged``: whether ``residual_norm < tol``

    Raises
    ------
    TypeError
        If ``A`` is not a :class:`LinOp`.
    ValueError
        If ``A`` is not square, is known to be non-Hermitian, or if
        ``max_iter`` is invalid.

    See Also
    --------
    power_iteration : Estimate the dominant eigenpair.
    expm_multiply : Apply a matrix exponential using the Lanczos basis.

    Notes
    -----
    The residual estimate comes from the tridiagonal recurrence as
    :math:`\beta_m |y_{m-1}|`. Callers that need the true residual can
    evaluate ``A.apply(eigenvector) - eigenvalue * eigenvector`` once more in
    the original space.

    The "smallest Ritz value" is the smallest eigenvalue of the projected
    tridiagonal matrix, which is not necessarily a good approximation of the
    smallest eigenvalue of ``A``. Convergence to the true smallest eigenvalue
    requires a separated bottom of the spectrum and a starting vector with
    nonzero projection onto the corresponding eigenspace. For clustered low
    eigenvalues, increase ``max_iter`` or use multiple initial vectors.

    Hermiticity is enforced only when it can be verified structurally: known
    non-Hermitian operators raise ``ValueError``, while operators of unknown
    structure (many matrix-free operators, operators on custom spaces) are
    trusted.

    The function is JIT-compatible on the JAX backend when ``max_iter`` and
    ``check_every`` are static arguments. For plain :class:`VectorSpace`
    domains, Euclidean reorthogonalization is vectorized; custom spaces use
    :meth:`Space.inner` to preserve the declared geometry.

    Inner products and norms use ``A.domain.inner`` and ``A.domain.norm``.
    The method is correct on non-Euclidean geometries when the space supplies
    Riesz maps and ``A`` is self-adjoint in that geometry.

    References
    ----------
    Lanczos, C., "An Iteration Method for the Solution of the Eigenvalue
    Problem of Linear Differential and Integral Operators," J. Res. Natl.
    Bur. Stand., 45 (1950), 255-282.

    Examples
    --------
    Approximate the smallest eigenpair of a diagonal operator.

    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((3,), ctx)
    >>> A = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 4.0]), X, ctx)
    >>> result = sc.lanczos_smallest(A, ctx.asarray([1.0, 1.0, 1.0]), max_iter=3)
    >>> np.allclose(result.eigenvalue, 1.0)
    True
    """
    # --- validation -------------------------------------------------------
    A = require_linop(A)
    require_square(A, "lanczos_smallest")
    # Only a definite ``False`` is rejected; ``None`` (unknown) is trusted.
    if A.is_hermitian() is False:
        raise ValueError("lanczos_smallest requires A to be Hermitian/self-adjoint.")
    max_iter = _check_lanczos_max_iter(max_iter)
    check_every = check_interval(check_every)
    domain = A.domain
    domain.check_member(initial_vector)

    ops = A.ops
    real_dtype = ops.real_dtype(A.ctx.dtype)

    # --- Krylov basis and projected problem --------------------------------
    basis = _lanczos_basis_and_tridiag(
        A, initial_vector, max_iter, tol, real_dtype, check_every
    )
    krylov_dim = basis.krylov_dim

    # Column 0 is taken as the eigenvector of the smallest eigenvalue; this
    # assumes ``ops.eigh`` orders eigenvalues ascending (as NumPy/JAX do).
    _ritz_values, ritz_vectors = ops.eigh(basis.T)
    ritz_coeffs = ritz_vectors[:, 0]

    residual_norm = basis.betas[krylov_dim] * ops.abs(ritz_coeffs[krylov_dim - 1])
    converged = residual_norm < basis.tol

    # --- reconstruct the Ritz vector in the original coordinates -----------
    # Coefficients belonging to padded (inactive) basis rows are zeroed out.
    step_ids = ops.arange(max_iter)
    active = ops.where(
        step_ids < krylov_dim,
        ops.asarray(1.0, dtype=real_dtype),
        ops.asarray(0.0, dtype=real_dtype),
    )
    active = ops.astype(active, ritz_coeffs.dtype)
    active_coeffs = ritz_coeffs * active

    ritz_flat = ops.einsum("j,jn->n", active_coeffs, basis.V[:max_iter, :])
    ritz_member = domain.unflatten(ritz_flat)
    ritz_norm = domain.norm(ritz_member)

    def normalized_ritz(_: Any) -> Any:
        return domain.flatten(
            domain.scale(safe_inverse_nonneg(ops, ritz_norm), ritz_member)
        )

    def fallback_direction(_: Any) -> Any:
        # A (numerically) vanishing Ritz vector falls back to the unit
        # coordinate vector prepared by the basis builder.
        return basis.e0_unit

    ritz_flat = ops.cond(
        ritz_norm > ops.asarray(1e-12, dtype=real_dtype),
        normalized_ritz,
        fallback_direction,
        ops.asarray(0.0, dtype=real_dtype),
    )
    x = domain.unflatten(ritz_flat)

    # --- Rayleigh quotient in the original space ----------------------------
    Ax = A.apply(x)
    numerator = ops.real(domain.inner(x, Ax))
    denominator = ops.real(domain.inner(x, x))
    rayleigh = numerator / denominator

    return LanczosResult(rayleigh, x, residual_norm, krylov_dim, converged)