Source code for spacecore.linalg._expm

from __future__ import annotations

from typing import Any, NamedTuple

from ..linop import LinOp
from ._lanczos import _check_lanczos_max_iter, _lanczos_basis_and_tridiag
from ._utils import require_linop, require_square, result_repr


class ExpmMultiplyResult(NamedTuple):
    """
    Bundle the outputs of :func:`expm_multiply`.

    Parameters
    ----------
    result : array-like
        Element of the input operator's domain approximating
        ``exp(t * A) @ v``.
    krylov_dim : int-like
        Krylov dimension actually reached, either at breakdown or at
        ``max_iter``.
    residual_estimate : scalar
        Projected exponential residual estimate
        ``abs(beta[m] * phi[m - 1])``.
    converged : bool-like
        Whether ``residual_estimate < tol`` held.
    """

    result: Any
    krylov_dim: Any
    residual_estimate: Any
    converged: Any

    def __repr__(self) -> str:
        """Summarize the result compactly instead of printing the whole vector."""
        # Key order is part of the rendered output; keep it stable.
        summary = dict(
            converged=self.converged,
            krylov_dim=self.krylov_dim,
            residual_estimate=self.residual_estimate,
            result=self.result,
        )
        return result_repr("ExpmMultiplyResult", summary)
def expm_multiply(
    A: LinOp,
    v: Any,
    t: float | complex = 1.0,
    *,
    max_iter: int = 30,
    tol: float = 1e-10,
) -> ExpmMultiplyResult:
    r"""
    Approximate :math:`\exp(t A) v` with a Lanczos (Krylov) projection.

    ``A`` must be square in the SpaceCore sense (``A.domain == A.codomain``)
    and Hermitian with respect to ``A.domain.inner``. A Lanczos basis is
    grown from ``v``, and the exponential is applied only to the small
    tridiagonal projection, so ``A`` is never materialized densely.

    Parameters
    ----------
    A : LinOp
        Hermitian/self-adjoint operator with respect to ``A.domain.inner``.
        ``A.domain`` must equal ``A.codomain``, including the underlying
        space type and inner-product geometry. When ``A.is_hermitian()``
        returns ``None`` (Hermiticity structurally unknown) the operator is
        trusted, and the caller is responsible for it being Hermitian.
        Non-Hermitian inputs give undefined results.
    v : array-like
        Starting vector, a member of ``A.domain``.
    t : float or complex, optional
        Scalar multiplying ``A``. A complex ``t`` needs a complex-valued
        ``ctx.dtype`` such as ``complex64`` or ``complex128``. Combining a
        complex ``t`` with a real-valued context gives backend-dependent
        results. Default is 1.0.
    max_iter : int, optional
        Maximum Krylov dimension. Roughly 20-50 is usually enough when
        :math:`|t|\|A\|` is moderate. Must be a Python ``int`` rather than a
        traced JAX scalar. Under ``jax.jit`` it acts as a static argument,
        and changing it triggers retracing. Default is 30.
    tol : float, optional
        Threshold shared by Lanczos breakdown detection and the convergence
        flag. ``result.converged`` is ``True`` when the projected exponential
        residual estimate is below ``tol``. Default is 1e-10.

    Returns
    -------
    ExpmMultiplyResult
        Contains:

        - the approximation, an element of ``A.domain``;
        - the Krylov dimension reached;
        - the standard estimate ``abs(beta[m] * phi[m - 1])``;
        - a convergence flag.

    Raises
    ------
    TypeError
        If ``A`` is not a :class:`LinOp`.
    ValueError
        If ``A`` is not square, is known to be non-Hermitian, or
        ``max_iter`` is invalid.

    See Also
    --------
    lanczos_smallest : Build the related Hermitian Krylov projection.
    power_iteration : Estimate a dominant eigenpair.

    Notes
    -----
    The projected exponential :math:`\exp(t T) e_0` is obtained from an
    eigendecomposition of the small real symmetric tridiagonal matrix
    ``T``. This is JIT-compatible on the JAX backend when ``max_iter`` is
    static.

    Hermiticity is enforced only when it can be structurally verified.
    Known non-Hermitian operators raise ``ValueError``. Operators of unknown
    structure are trusted, for example many matrix-free operators and
    operators on custom spaces.

    The residual estimate is :math:`|\beta_m \phi_{m-1}|`, where ``phi`` is
    the projected exponential vector. The estimate is not multiplied by the
    initial norm of ``v``. Presumably it is therefore relative to ``||v||``;
    confirm against the Lanczos helper's normalization. Callers that need
    the true residual can spend one additional operator application.

    Examples
    --------
    Apply the exponential of a diagonal operator.

    >>> import numpy as np
    >>> import spacecore as sc
    >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
    >>> X = sc.DenseCoordinateSpace((2,), ctx)
    >>> A = sc.DiagonalLinOp(ctx.asarray([0.0, 1.0]), X, ctx)
    >>> v = ctx.asarray([2.0, 3.0])
    >>> result = sc.expm_multiply(A, v, t=0.5, max_iter=5)
    >>> np.allclose(result.result, [2.0, 3.0 * np.exp(0.5)], atol=1e-10)
    True
    """
    # --- Validation -------------------------------------------------------
    A = require_linop(A)
    require_square(A, "expm_multiply")
    # is_hermitian() is tri-state: only an explicit False is rejected;
    # None (unknown) is accepted on trust, as documented above.
    if A.is_hermitian() is False:
        raise ValueError("expm_multiply requires A to be Hermitian/self-adjoint.")
    max_iter = _check_lanczos_max_iter(max_iter)
    A.domain.check_member(v)

    # --- Lanczos projection -----------------------------------------------
    ops = A.ops
    krylov = _lanczos_basis_and_tridiag(
        A, v, max_iter, tol, ops.real_dtype(A.ctx.dtype), check_every=1
    )
    m = krylov.krylov_dim

    # Fixed-size masks of length max_iter select the leading m x m block.
    # Masking (rather than slicing to m) keeps shapes independent of m.
    keep = ops.arange(max_iter) < m
    keep_2d = keep[:, None] & keep[None, :]
    tridiag = ops.where(keep_2d, krylov.T, ops.zeros_like(krylov.T))

    # --- exp(t T) e_0 via T = Q diag(lam) Q^T --------------------------------
    # Q^T e_0 is the first row of Q, hence exp(tT) e_0 = Q (exp(t*lam) * Q[0, :]).
    lam, Q = ops.eigh(tridiag)
    phi = Q @ (ops.exp(t * lam) * Q[0, :])
    phi = ops.where(keep, phi, ops.zeros_like(phi))

    # --- Lift back to the original space -----------------------------------
    flat = krylov.initial_norm * ops.einsum("j,jn->n", phi, krylov.V[:max_iter, :])
    approximation = A.domain.unflatten(flat)

    # Standard a posteriori estimate beta_m * |phi_{m-1}|. When m == 0 the
    # index wraps to a masked (zero) entry.
    residual_estimate = krylov.betas[m] * ops.abs(phi[m - 1])
    return ExpmMultiplyResult(
        approximation,
        m,
        residual_estimate,
        residual_estimate < krylov.tol,
    )