"""Source code for ``spacecore.backend.jax._ops``: the JAX implementation of SpaceCore's ``BackendOps``."""

from __future__ import annotations

from typing import Any, Sequence, Literal, Tuple, Callable, Optional, Type, cast
from warnings import warn

from .._family import BackendFamily
from .._ops import BackendOps
from ..numpy import NumpyOps
from ...types import DenseArray, ArrayLike, SparseArray, DType, Index, X, T, Y, R, Carry


class JaxOps(BackendOps):
    """
    BackendOps implementation for the JAX ecosystem.

    This backend uses JAX for dense array operations and JAX experimental
    sparse arrays for sparse operations.

    Dense arrays
        jax.Array
    Sparse arrays
        jax.experimental.sparse.BCOO
        jax.experimental.sparse.BCSR

    Methods
    -------
    Most methods mirror the corresponding JAX public API signatures and
    delegate to `jax.numpy`, `jax.numpy.linalg`, `jax.scipy`, or
    `jax.experimental.sparse`. Backend-specific behavior, tracing rules,
    dtype canonicalization, device placement, sharding, and error modes
    therefore follow JAX semantics.

    Backend handles
    - jax : module
        JAX module stored on the class and available through instances as
        `ops.jax`. Advanced users may use it when SpaceCore's portable API
        does not expose a required JAX feature.
    - jnp : module
        `jax.numpy` module stored on the class and available through
        instances as `ops.jnp`.
    - jsparse : module
        `jax.experimental.sparse` module stored on the class and available
        through instances as `ops.jsparse`.

    Notes
    -----
    Code intended to remain backend-portable should prefer `BackendOps`
    methods. Direct use of `ops.jax`, `ops.jnp`, or `ops.jsparse` is an
    explicit JAX-specific escape hatch.

    Some parameters are accepted for JAX signature compatibility even when
    JAX ignores them. Array-creation routines may expose `device` and
    `out_sharding` for explicit placement or sharding.
    """

    import jax as _jax
    import jax.numpy as _jnp
    import jax.experimental.sparse as _jsparse

    # Concrete library handles exposed as ``Any`` so the portable protocols can
    # flow into typed JAX calls without per-boundary casts; mirrors the base
    # ``xp: ClassVar[Any]`` design.
    jax: Any = _jax
    jnp: Any = _jnp
    jsparse: Any = _jsparse
    xp = jnp
    _family = BackendFamily.jax.value.lower()
    _allow_sparse = True

    def __init__(self) -> None:
        super().__init__()

    # ------------------------------------------------------------------
    # dtypes
    # ------------------------------------------------------------------

    def sanitize_dtype(self, dtype: DType | None) -> DType:
        """
        Normalize a dtype specifier using JAX.

        Input:
            dtype: Optional dtype requested by SpaceCore or the caller.
        Output:
            Backend dtype object accepted by array constructors.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.dtype.html
        Backend-specific notes:
            ``None`` maps to ``jnp.float64`` when ``jax_enable_x64`` is on and
            to ``jnp.float32`` (with a ``UserWarning``) otherwise. SpaceCore
            rejects dtypes that JAX would silently canonicalize under the
            active x64 setting.
        Raises:
            TypeError: if ``dtype`` is not a valid JAX dtype specifier, would
                be canonicalized to a different dtype, or cannot be allocated
                on the active JAX backend/device.
        """
        x64_enabled = bool(self.jax.config.read("jax_enable_x64"))
        if dtype is None:
            if not x64_enabled:
                warn(
                    "jax_enable_x64 is set to False, so default JAX dtype is set to float32. "
                    "If you need float64, run `jax.config.update('jax_enable_x64', True)`.",
                    UserWarning,
                )
                return self.jnp.float32
            return self.jnp.float64

        try:
            dt = self.jnp.dtype(dtype)
        except Exception as e:
            raise TypeError(f"Invalid dtype specifier for JAX: {dtype!r}.") from e

        # Forbid implicit coercion under the current JAX configuration. This is
        # checked *before* the allocation probe below so that a rejected dtype
        # (e.g. float64 with x64 disabled) neither allocates a device buffer nor
        # triggers JAX's own "dtype will be truncated" warning first.
        try:
            dt_canon = self.jax.dtypes.canonicalize_dtype(dt)
        except Exception as e:
            raise TypeError(
                f"Dtype {dt!r} is not supported by the active JAX backend/device."
            ) from e
        if dt_canon != dt:
            raise TypeError(
                f"Dtype {dt} is not permitted under current JAX configuration: "
                f"it would be canonicalized to {dt_canon}. "
                f"(jax_enable_x64={x64_enabled!r})"
            )

        # Ensure dtype is actually usable on this backend/device.
        try:
            self.jnp.empty((), dtype=dt)
        except Exception as e:
            raise TypeError(
                f"Dtype {dt!r} is not supported by the active JAX backend/device."
            ) from e
        return dt

    # ------------------------------------------------------------------
    # array types
    # ------------------------------------------------------------------

    @property
    def dense_array(self) -> Type[Any]:
        """
        Dense array type using JAX.

        Returns
        -------
        Concrete dense array class accepted by this backend.

        See: https://docs.jax.dev/en/latest/jax.Array.html
        """
        return self.jax.Array

    @property
    def sparse_array(self) -> Tuple[Type[Any], ...]:
        """
        Sparse array type tuple using JAX.

        Returns
        -------
        Tuple of concrete sparse array classes accepted by this backend:
        ``(BCOO, BCSR)``, or ``(BCOO,)`` when the installed JAX does not
        provide ``BCSR``.

        See: https://docs.jax.dev/en/latest/jax.experimental.sparse.html
        """
        bcsr = self._bcsr_class()
        if bcsr is None:
            return (self.jsparse.BCOO,)
        return (self.jsparse.BCOO, bcsr)

    def _bcsr_class(self) -> Any:
        """Return ``jsparse.BCSR``, or ``None`` if this JAX version lacks it."""
        return getattr(self.jsparse, "BCSR", None)

    # ------------------------------------------------------------------
    # sparse construction
    # ------------------------------------------------------------------

    def _sparse_format_class(self, format: str) -> Any:
        """Map a sparse ``format`` name to its JAX class.

        Raises:
            TypeError: if ``format == "bcsr"`` but BCSR is unavailable.
            ValueError: if ``format`` is not ``"bcoo"`` or ``"bcsr"``.
        """
        if format == "bcoo":
            return self.jsparse.BCOO
        if format == "bcsr":
            bcsr = self._bcsr_class()
            if bcsr is None:
                raise TypeError("BCSR is not available in this JAX version.")
            return bcsr
        raise ValueError(f"Unknown sparse format: {format!r}")

    @staticmethod
    def _sparse_kwargs(index_dtype: DType | None, nse: int | None) -> dict[str, Any]:
        """Build optional keyword arguments for JAX sparse constructors."""
        kwargs: dict[str, Any] = {}
        if index_dtype is not None:
            kwargs["index_dtype"] = index_dtype
        if nse is not None:
            kwargs["nse"] = nse
        return kwargs

    def assparse(
        self,
        x: Any,
        *,
        format: Literal["bcoo", "bcsr"] = "bcoo",
        index_dtype: DType | None = None,
        nse: int | None = None,
        dtype: DType | None = None,
    ) -> SparseArray:
        """
        Convert input to a sparse array using JAX.

        Input:
            x: Dense, sparse, or array-like input plus sparse-format options.
        Output:
            Sparse backend array.
        See:
            https://docs.jax.dev/en/latest/jax.experimental.sparse.html
        Backend-specific notes:
            Inputs that are already JAX sparse arrays are returned as-is
            (cast when ``dtype`` differs); ``format`` is not applied to them.
            Dense inputs are converted with the JAX sparse BCOO/BCSR
            ``fromdense`` constructors. SciPy sparse inputs use
            ``from_scipy_sparse``; their stored entries come from the SciPy
            matrix, so ``nse`` applies to dense inputs only.
        Raises:
            ValueError: for an unknown ``format``.
            TypeError: if ``format="bcsr"`` but BCSR is unavailable.
        """
        import scipy.sparse as sps

        self._reject_complex_to_real(x, dtype, operation="assparse")

        if self.is_sparse(x):
            if dtype is not None:
                target = self.sanitize_dtype(dtype)
                if self.get_dtype(x) != target:
                    return x.astype(target)
            return x

        sparse_cls = self._sparse_format_class(format)

        if sps.issparse(x):
            # ``from_scipy_sparse`` accepts ``index_dtype`` but has no ``nse``
            # parameter; forwarding ``nse`` would raise a TypeError.
            out = sparse_cls.from_scipy_sparse(x, **self._sparse_kwargs(index_dtype, None))
            return out.astype(self.sanitize_dtype(dtype)) if dtype is not None else out

        x_arr = self.asarray(x, dtype=dtype)
        return sparse_cls.fromdense(x_arr, **self._sparse_kwargs(index_dtype, nse))

    def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray:
        """
        Multiply sparse and dense arrays using JAX.

        Input:
            a: Sparse backend array; b: Dense backend array.
        Output:
            Dense backend array containing the product.
        See:
            https://docs.jax.dev/en/latest/jax.experimental.sparse.html
        Backend-specific notes:
            Uses JAX sparse matmul and returns a JAX array; sparse support
            remains experimental in JAX.
        """
        return a @ b

    # ------------------------------------------------------------------
    # transformations and reductions
    # ------------------------------------------------------------------

    def vmap(
        self,
        fn: Callable,
        in_axes: int | Sequence[int | None] | None = 0,
        out_axes: int | Sequence[int | None] | None = 0,
    ) -> Callable:
        """Vectorize a function using ``jax.vmap``."""
        return self.jax.vmap(fn, in_axes=in_axes, out_axes=out_axes)

    @property
    def has_native_vmap(self) -> bool:
        """Return ``True`` because JAX provides native ``vmap``."""
        return True

    def logsumexp(
        self,
        a: DenseArray,
        axis: int | Sequence[int] | None = None,
        b: DenseArray | None = None,
        keepdims: bool = False,
        return_sign: bool = False,
        where: DenseArray | None = None,
    ) -> DenseArray | Tuple[DenseArray, DenseArray]:
        """
        Compute a stable log-sum-exp reduction using JAX.

        Input:
            a: Dense backend array; axis, weights, and sign options control
            the reduction.
        Output:
            Dense backend array, or a ``(result, sign)`` tuple when
            ``return_sign`` is true.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html
        """
        return self.jax.scipy.special.logsumexp(
            a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where
        )

    # ------------------------------------------------------------------
    # indexing
    # ------------------------------------------------------------------

    def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True):
        """
        Set indexed values using JAX.

        Input:
            x: Dense backend array; index: Selection; values: Replacement
            values; copy controls mutation policy.
        Output:
            Dense backend array with indexed values set.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html
        Backend-specific notes:
            JAX arrays are immutable; copy=False raises NotImplementedError.
        """
        if not copy:
            raise NotImplementedError("JAX arrays are immutable; copy=False is not supported.")
        return cast(Any, x).at[index].set(values)

    def ix_(self, *args: Any) -> Any:
        r"""
        Build open mesh index arrays using JAX.

        Input:
            args: One-dimensional index arrays or sequences.
        Output:
            Tuple of dense backend arrays usable for open-mesh indexing.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\_.html
        """
        return self.jnp.ix_(*args)

    def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True):
        """
        Add into indexed values using JAX.

        Input:
            x: Dense backend array; index: Selection; values: Values to add;
            copy controls mutation policy.
        Output:
            Dense backend array with indexed values incremented.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html
        Backend-specific notes:
            JAX arrays are immutable; copy=False raises NotImplementedError
            and repeated indices follow JAX scatter-add semantics.
        """
        if not copy:
            raise NotImplementedError("JAX arrays are immutable; copy=False is not supported.")
        return cast(Any, x).at[index].add(values)

    # ------------------------------------------------------------------
    # control flow
    # ------------------------------------------------------------------

    def fori_loop(
        self,
        lower: int,
        upper: int,
        body_fun: Callable[[int, T], T],
        init_val: T,
        *,
        unroll: int | bool | None = None,
    ) -> T:
        """
        Run a counted loop primitive using JAX.

        Input:
            lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial
            carry value.
        Output:
            Final carry value after loop execution.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html
        Backend-specific notes:
            Loop bounds and unroll behavior follow JAX tracing and compilation
            rules.
        """
        return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)

    def while_loop(
        self,
        cond_fun: Callable[[T], bool],
        body_fun: Callable[[T], T],
        init_val: T,
    ) -> T:
        """
        Run a while-loop primitive using JAX.

        Input:
            cond_fun: Loop condition; body_fun: Loop body; init_val: Initial
            carry value.
        Output:
            Final carry value after loop execution.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html
        Backend-specific notes:
            Condition and body are staged according to JAX lax control-flow
            semantics.
        """
        return self.jax.lax.while_loop(cond_fun, body_fun, init_val)

    def scan(
        self,
        f: Callable[[Carry, X], Tuple[Carry, Y]],
        init: Carry,
        xs: X | None = None,
        length: Optional[int] = None,
        reverse: bool = False,
        unroll: int | bool = 1,
        _split_transpose: bool = False,
    ) -> Tuple[Carry, Y]:
        """
        Run a scan primitive using JAX.

        Input:
            f: Scan body; init: Initial carry; xs: Per-step inputs (may be
            ``None`` when ``length`` is given, as in ``jax.lax.scan``) plus
            scan options.
        Output:
            Tuple of final carry and stacked outputs.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html
        Backend-specific notes:
            Inputs and outputs may be pytrees and are staged according to JAX
            lax.scan semantics.
        """
        return self.jax.lax.scan(
            f,
            init,
            xs,
            length=length,
            reverse=reverse,
            unroll=unroll,
            _split_transpose=_split_transpose,
        )

    def cond(
        self,
        pred: bool,
        true_fun: Callable[[T], R],
        false_fun: Callable[[T], R],
        *operands: Any,
    ) -> R:
        """
        Run conditional branch selection using JAX.

        Input:
            pred: Predicate; true_fun and false_fun: Branch functions;
            operands: Branch inputs.
        Output:
            Result returned by the selected branch.
        See:
            https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html
        Backend-specific notes:
            Branches are staged according to JAX lax.cond semantics rather than
            Python eager branching.
        """
        return self.jax.lax.cond(pred, true_fun, false_fun, *operands)

    # ------------------------------------------------------------------
    # sparse comparison
    # ------------------------------------------------------------------

    def allclose_sparse(
        self,
        a: SparseArray,
        b: SparseArray,
        rtol: float = 1e-5,
        atol: float = 1e-8,
    ) -> bool:
        """
        Compare sparse arrays elementwise within tolerances using JAX.

        Input:
            a, b: Sparse backend arrays; rtol and atol configure comparison.
        Output:
            Boolean indicating whether sparse arrays are close.
        See:
            https://docs.jax.dev/en/latest/jax.experimental.sparse.html
        Backend-specific notes:
            SpaceCore converts JAX sparse arrays through SciPy sparse arrays
            for comparison.
        """
        self._require_two_sparse(a, b)
        np_ops = NumpyOps()
        a_sp = self._to_scipy_sparse(np_ops, a)
        b_sp = self._to_scipy_sparse(np_ops, b)
        return np_ops.allclose_sparse(a_sp, b_sp, rtol=rtol, atol=atol)

    def _to_scipy_sparse(self, np_ops: NumpyOps, x: SparseArray):
        """Convert a 2D unbatched JAX BCOO/BCSR array to a SciPy ``coo_array``.

        BCSR inputs are first converted to BCOO. ``sum_duplicates`` is applied
        (keeping explicit zeros) before the indices are read.

        Raises:
            NotImplementedError: for batched, dense-dimension, or non-2D arrays.
            TypeError: for any other input type.
        """
        bcoo: Any = x
        bcsr_cls = self._bcsr_class()
        if bcsr_cls is not None and isinstance(x, bcsr_cls):
            bcoo = bcoo.to_bcoo()
        if isinstance(bcoo, self.jsparse.BCOO):
            bcoo = bcoo.sum_duplicates(remove_zeros=False)
            if bcoo.n_batch != 0 or bcoo.n_dense != 0 or bcoo.n_sparse != 2:
                raise NotImplementedError(
                    "_to_scipy_sparse supports only 2D unbatched sparse arrays."
                )
            row = bcoo.indices[:, 0]
            col = bcoo.indices[:, 1]
            data = bcoo.data
            return np_ops.sp.sparse.coo_array((data, (row, col)), shape=bcoo.shape)
        raise TypeError(f"Unsupported sparse type: {type(x)!r}")