JAX integration
JIT usage notes
SpaceCore’s numerical kernels are written to run under jax.jit when values
live in a JAX-backed Context. The object model remains ordinary Python:
spaces, operators, and functionals are assembled before the numerical kernel is
traced, then passed into the jitted function.
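In plain JAX, an ordinary Python object can cross a jax.jit boundary as an argument only if JAX knows how to flatten it into array leaves, that is, if it is registered as a pytree. The sketch below shows that general JAX mechanism with a hypothetical `ScaledOp` class. It is not SpaceCore's operator implementation, which may handle this differently; it only illustrates how an object assembled before tracing can then be passed into a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ScaledOp:
    """Hypothetical toy operator: x -> scale * (matrix @ x)."""

    def __init__(self, matrix, scale):
        self.matrix = matrix  # array leaf: becomes a tracer under jit
        self.scale = scale    # scalar leaf: becomes a tracer under jit

    def apply(self, x):
        return self.scale * (self.matrix @ x)

    def tree_flatten(self):
        # Children are the array leaves; there is no static aux data here.
        return (self.matrix, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Assemble the Python object before any tracing happens...
op = ScaledOp(jnp.eye(3, dtype=jnp.float32), jnp.float32(2.0))


# ...then pass it into the jitted kernel like any other argument.
@jax.jit
def apply_op(op, x):
    return op.apply(x)


y = apply_op(op, jnp.arange(3.0, dtype=jnp.float32))
```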
Operator algebra such as A @ B and A + B executes Python-level
simplification rules at construction time. To avoid repeating that work during tracing:
- construct operator expressions outside the JIT-decorated function;
- pass the assembled operator as an argument to the jitted function;
- avoid calling make_sum or make_composed from inside a jax.jit body.
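This is a trace-time concern rather than a correctness concern. The algebra is
correct either way, but composing inside jax.jit means the simplification
runs again every time JAX traces the function: on the first call and on each
retrace, for example when an argument's shape or dtype changes. Cached calls
do not repeat it. For repeatedly invoked code with stable operator structure,
build the expression once outside the jitted function.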
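The trace-time behaviour can be observed with plain JAX. In this sketch, a hypothetical `make_sum` stands in for a Python-level simplification step (here it simply drops terms with a zero coefficient) and counts how often it executes. It is not SpaceCore's `make_sum`; it only illustrates that Python code inside a jax.jit body runs while tracing, once per new input signature, and not on cached calls.

```python
import jax
import jax.numpy as jnp

# Counts how many times the hypothetical simplifier executes in Python.
calls = {"make_sum": 0}


def make_sum(terms):
    """Hypothetical simplifier: drop (coefficient, matrix) terms with a zero coefficient."""
    calls["make_sum"] += 1
    return [(c, m) for c, m in terms if c != 0.0]


def apply_terms(terms, x):
    # Evaluate sum_i c_i * (M_i @ x) with ordinary JAX operations.
    out = jnp.zeros_like(x)
    for c, m in terms:
        out = out + c * (m @ x)
    return out


@jax.jit
def composed_inside(A, x):
    # make_sum runs while JAX traces this body, not on cached calls.
    terms = make_sum([(1.0, A), (0.0, A), (0.5, A)])
    return apply_terms(terms, x)


A3 = jnp.eye(3, dtype=jnp.float32)
x3 = jnp.ones(3, dtype=jnp.float32)
composed_inside(A3, x3)  # first call: traces, so make_sum runs
composed_inside(A3, x3)  # same shapes/dtypes: cached, make_sum does not run
after_cached_call = calls["make_sum"]
composed_inside(jnp.eye(4, dtype=jnp.float32), jnp.ones(4, dtype=jnp.float32))  # new shape: retrace
after_retrace = calls["make_sum"]

# Preferred pattern: simplify once in plain Python, then pass the result in.
terms = make_sum([(1.0, A3), (0.0, A3), (0.5, A3)])


@jax.jit
def composed_outside(terms, x):
    return apply_terms(terms, x)  # no simplification inside the traced body


y = composed_outside(terms, x3)
y = composed_outside(terms, x3)
```

Moving `make_sum` outside the jitted function does not change the result in this toy; it only changes when the Python-level work happens.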
Example:
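import jax
import jax.numpy as jnp
import spacecore as sc

ctx = sc.Context(sc.JaxOps(), dtype="float32")
X = sc.DenseCoordinateSpace((128,), ctx)

# build_operator and build_preconditioner are user-supplied helpers
# (placeholders here) that return SpaceCore linear operators on X.
A = build_operator(X)
B = build_preconditioner(X)

# Build algebra outside the JIT boundary.
system = B.H @ A @ B + 0.01 * sc.IdentityLinOp(X, ctx)

@jax.jit
def solve(op, rhs):
    return sc.cg(op, rhs, maxiter=50).x

# Assumption: a float32 JAX array of shape (128,) is accepted as an element of X.
rhs = jnp.ones((128,), dtype=jnp.float32)
x = solve(system, rhs)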
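Because `build_operator` and `build_preconditioner` are placeholders, the example above does not run on its own. The following self-contained analogue swaps in dense JAX arrays and `jax.scipy.sparse.linalg.cg` for the SpaceCore objects and `sc.cg`; the specific matrices and the name `solve_dense` are invented for illustration. It follows the same pattern: assemble B^H A B + 0.01 I once outside jax.jit, then pass the result in as an argument.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 128
# Hypothetical stand-ins for build_operator / build_preconditioner:
# a symmetric positive definite tridiagonal A and a positive diagonal B.
A = 2.0 * jnp.eye(n, dtype=jnp.float32) - 0.5 * (
    jnp.eye(n, k=1, dtype=jnp.float32) + jnp.eye(n, k=-1, dtype=jnp.float32)
)
B = jnp.diag(jnp.linspace(1.0, 2.0, n, dtype=jnp.float32))

# Build the algebra once, outside the JIT boundary.
system = B.conj().T @ A @ B + 0.01 * jnp.eye(n, dtype=jnp.float32)


@jax.jit
def solve_dense(op, rhs):
    # cg accepts a matvec callable; it returns (solution, info).
    x, _info = cg(lambda v: op @ v, rhs, maxiter=50)
    return x


rhs = jnp.ones(n, dtype=jnp.float32)
x = solve_dense(system, rhs)
```

The dense matrices only make the sketch runnable; in SpaceCore the corresponding objects would be operators built with the library's own API.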