from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Any, Iterable, Tuple
from warnings import warn
from .._check_policy import (
CheckLevel,
minimum_check_level,
normalize_check_level,
require_mutually_exclusive,
)
from ..types import DType
from ..backend._family import BackendFamily
from ..backend._ops import BackendOps
from ..backend.numpy import NumpyOps
from ._policies import (
ContextConflictError,
ContextInferenceError,
UnknownBackendError,
)
try:
from ..backend.jax import JaxOps
except ImportError:
JaxOps = None
try:
from ..backend.cupy import CuPyOps
except ImportError:
CuPyOps = None
try:
from ..backend.torch import TorchOps
except ImportError:
TorchOps = None
if TYPE_CHECKING:
from ..backend._context import Context
def _context_type() -> type[Context]:
    """Return the ``Context`` class, importing it only when first needed.

    The import lives inside the function so it runs at call time rather than
    at module import time (NOTE(review): presumably to break an import cycle
    with ``backend._context`` — confirm).
    """
    from ..backend._context import Context as context_cls

    return context_cls
class Contextual:
    """Resolve contexts and backend registrations.

    An instance keeps two pieces of state: a registry mapping normalized
    backend keys (see ``_backend_key``) to ``BackendOps`` classes, and the
    default ``Context`` returned when nothing more specific can be resolved.
    """

    # Declared here, assigned in ``__init__`` (``_default_ctx`` through the
    # ``default_ctx`` property setter).
    _default_ctx: Context
    _available_ops: Dict[str, type[BackendOps]]
    # Settings used to build the initial default context in ``__init__``.
    _default_dtype: DType | None = None
    _default_check_level: CheckLevel = "none"

    def __init__(self) -> None:
        """Register every importable backend and install a NumPy default context."""
        Context = _context_type()
        # NumPy is always registered. The optional backends are ``None`` when
        # their module failed to import and are skipped. Insertion order
        # (numpy, jax, cupy, torch) is kept; it drives error messages and the
        # order in which ``infer_context`` probes backends.
        self._available_ops = {self._backend_key(NumpyOps): NumpyOps}
        for optional_ops in (JaxOps, CuPyOps, TorchOps):
            if optional_ops is not None:
                self._available_ops[self._backend_key(optional_ops)] = optional_ops
        ops = NumpyOps()
        # Assigning through the property runs ``normalize_context`` on it.
        self.default_ctx = Context(
            ops=ops,
            dtype=ops.sanitize_dtype(self._default_dtype),
            check_level=self._default_check_level,
        )

    def normalize_context(
        self,
        ctx: Context | BackendFamily | str | None = None,
        dtype: Any = None,
        enable_checks: bool | None = None,
        *,
        check_level: CheckLevel | None = None,
    ) -> Context:
        """Turn a context specification into a concrete ``Context``.

        Parameters
        ----------
        ctx : Context, BackendFamily, str, or None, optional
            ``None`` returns the current default context. A ``Context`` is
            rebuilt with its dtype re-sanitized. A backend family or name is
            looked up in the registry and a fresh context is built for it.
        dtype : Any, optional
            Dtype for backend-name contexts. Ignored, with a ``UserWarning``,
            when ``ctx`` is ``None`` or a concrete ``Context``.
        enable_checks : bool or None, optional
            Deprecated Boolean validation override. It is checked against
            ``check_level`` via ``require_mutually_exclusive``.
        check_level : CheckLevel or None, optional
            Validation policy for backend-name contexts; ignored like ``dtype``.

        Returns
        -------
        Context
            The resolved context.

        Raises
        ------
        TypeError
            If ``ctx`` is not one of the supported types.
        UnknownBackendError
            If a backend name or family is not registered.
        """
        Context = _context_type()
        require_mutually_exclusive("check_level", check_level, "enable_checks", enable_checks)
        if ctx is None:
            if dtype is not None or enable_checks is not None or check_level is not None:
                warn(
                    "Provided context is None; dtype and check policy parameters are ignored.",
                    UserWarning,
                )
            return self.default_ctx
        if isinstance(ctx, Context):
            if dtype is not None or enable_checks is not None or check_level is not None:
                warn(
                    "Provided concrete context; dtype and check policy parameters are ignored.",
                    UserWarning,
                )
            return Context(
                ops=ctx.ops,
                dtype=ctx.ops.sanitize_dtype(ctx.dtype),
                check_level=ctx.check_level,
            )
        if isinstance(ctx, (str, BackendFamily)):
            ops = self.get_ops(self._backend_key(ctx))
            return self.ctx_from_ops(
                ops,
                dtype=dtype,
                enable_checks=enable_checks,
                check_level=check_level,
            )
        raise TypeError(f"Expected Context, BackendFamily, str, or None, got {type(ctx)}.")

    def ctx_from_ops(
        self,
        ops: BackendOps,
        dtype: DType | None = None,
        enable_checks: bool | None = None,
        *,
        check_level: CheckLevel | None = None,
    ) -> Context:
        """Build a ``Context`` around an existing ``BackendOps`` instance.

        ``dtype`` is passed through ``ops.sanitize_dtype``. The check level
        comes from ``normalize_check_level``, falling back to
        ``_default_check_level``. A legacy warning is requested only when
        ``enable_checks`` was given.
        """
        Context = _context_type()
        dtype = ops.sanitize_dtype(dtype)
        level = normalize_check_level(
            check_level,
            enable_checks=enable_checks,
            default=self._default_check_level,
            warn_legacy=enable_checks is not None,
        )
        return Context(ops=ops, dtype=dtype, check_level=level)

    @property
    def default_ctx(self) -> Context:
        """Context returned when no explicit or inferable context exists."""
        return self._default_ctx

    @default_ctx.setter
    def default_ctx(self, ctx: Context | BackendFamily | str | None = None) -> None:
        # Every assignment is normalized, so a backend name or family also
        # works here. ``None`` keeps the current default, because
        # ``normalize_context(None)`` returns it.
        self._default_ctx = self.normalize_context(ctx)

    def get_ops(
        self, name: str | BackendFamily | BackendOps | type[BackendOps] | Context
    ) -> BackendOps:
        """Instantiate the registered ``BackendOps`` class matching ``name``.

        Returns
        -------
        BackendOps
            A new instance of the registered class.

        Raises
        ------
        UnknownBackendError
            If no backend is registered under the normalized key.
        TypeError
            If ``name`` cannot be turned into a backend key.
        """
        key = self._backend_key(name)
        ops_cls = self.available_ops.get(key)
        if ops_cls is None:
            allowed = ", ".join(self.available_ops)
            raise UnknownBackendError(f"Unknown backend: {key!r}. Expected one of: {allowed}")
        return ops_cls()

    @property
    def available_ops(self) -> Dict[str, type[BackendOps]]:
        """Registry of backend key -> ``BackendOps`` class (the live dict)."""
        return self._available_ops

    def register_ops(self, ops: type[BackendOps]) -> type[BackendOps]:
        """Add a ``BackendOps`` subclass to the registry and return it unchanged.

        Because the class is returned, this can also be used as a class
        decorator.

        Raises
        ------
        TypeError
            If ``ops`` is not a ``BackendOps`` subclass.
        ContextConflictError
            If a backend with the same key is already registered.
        """
        if not isinstance(ops, type) or not issubclass(ops, BackendOps):
            raise TypeError(f"Expected type[BackendOps], got {type(ops)!r}")
        family = self._backend_key(ops)
        if family in self.available_ops:
            raise ContextConflictError(f"BackendOps {family} is already registered.")
        self._available_ops[family] = ops
        return ops

    def infer_context(
        self,
        x: Any,
        enable_checks: bool | None = None,
        *,
        check_level: CheckLevel | None = None,
    ) -> Context | None:
        """Infer context from `.ctx` first, then registered backend arrays.

        Resolution order:

        1. ``x`` itself if it is a ``Context``.
        2. ``x.ctx`` if that attribute is a ``Context``.
        3. The single registered backend whose ``is_array`` accepts ``x``.

        ``enable_checks`` and ``check_level`` only affect step 3. Contexts
        found in steps 1-2 are returned as-is.

        Returns
        -------
        Context or None
            ``None`` when no backend recognizes ``x``.

        Raises
        ------
        ContextInferenceError
            If more than one backend recognizes ``x``.
        """
        Context = _context_type()
        if isinstance(x, Context):
            return x
        attached = getattr(x, "ctx", None)
        if isinstance(attached, Context):
            return attached
        matched: list[BackendOps] = []
        for ops_cls in self.available_ops.values():
            try:
                candidate = ops_cls()
                if candidate.is_array(x):
                    matched.append(candidate)
            except Exception:
                # Keep inference conservative: a backend that cannot be
                # instantiated or cannot inspect ``x`` simply does not match.
                continue
        if not matched:
            return None
        if len(matched) > 1:
            raise ContextInferenceError(
                f"Ambiguous backend inference for object of type {type(x)!r}: {matched!r}."
            )
        ops = matched[0]
        try:
            dtype = ops.get_dtype(x)
        except Exception:
            # Fall back to a ``dtype`` attribute, then to the default dtype.
            dtype = getattr(x, "dtype", self.default_ctx.dtype)
        return self.ctx_from_ops(
            ops,
            dtype,
            enable_checks,
            check_level=check_level,
        )

    def infer_contexts(self, values: Iterable[Any]) -> Tuple[Context, ...]:
        """Infer a context for each value, dropping values with none."""
        inferred = (self.infer_context(x) for x in values)
        return tuple(ctx for ctx in inferred if ctx is not None)

    def are_compatible_contexts(self, *ctxs: Context) -> bool:
        """Return True when all contexts share one backend family (vacuous for < 2)."""
        if len(ctxs) < 2:
            return True
        first = ctxs[0]
        return all(ctx.ops.family == first.ops.family for ctx in ctxs[1:])

    def are_compatible_values(self, *values: Any) -> bool:
        """Return True when every inferable value belongs to one backend family."""
        return self.are_compatible_contexts(*self.infer_contexts(values))

    def are_compatible_ops(self, *ops: BackendOps) -> bool:
        """Return True when all ops instances share one backend family."""
        if not ops:
            return True
        first = ops[0]
        return all(op.family == first.family for op in ops)

    def enforce_convert_policy(
        self, x: Any, to: Context | BackendFamily | str | None = None
    ) -> Tuple[Any, Context]:
        """Resolve the target context for ``x``.

        Returns ``x`` unchanged together with ``normalize_context(to)``.

        Raises
        ------
        ContextInferenceError
            If ``x`` is recognized by more than one backend.
        """
        # Evaluated only for its validation side effect (ambiguous inference
        # raises). NOTE(review): the inferred context is not used as a
        # fallback when ``to`` is None — the default context is returned
        # instead; confirm this is intended.
        self.infer_context(x)
        ctx = self.normalize_context(to)
        return x, ctx

    def _backend_key(self, x: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> str:
        """Map any backend specification to its lowercase registry key.

        Contexts, ops instances and ops classes resolve through their backend
        family. Families and strings are lowercased, with ``"pytorch"``
        aliased to ``"torch"``.

        Raises
        ------
        TypeError
            If ``x`` is none of the supported kinds.
        """
        Context = _context_type()
        if isinstance(x, Context):
            return self._backend_key(x.ops)
        if isinstance(x, BackendOps):
            return self._backend_key(x.family)
        if isinstance(x, type) and issubclass(x, BackendOps):
            return self._backend_key(x._family)
        if isinstance(x, BackendFamily):
            # Normalize the enum value with the same rules as a plain string,
            # so aliases apply no matter how the backend was spelled.
            return self._backend_key(x.value)
        if isinstance(x, str):
            key = x.lower()
            return "torch" if key == "pytorch" else key
        raise TypeError(f"Unsupported backend key source: {type(x)!r}")

    def resolve_context_priority(
        self,
        priority_ctx: Context | BackendFamily | str | None = None,
        *other_ctx: object,
    ) -> Context:
        """Resolve explicit context first, then compatible inferred contexts.

        If ``priority_ctx`` is given it is normalized and returned. Otherwise
        contexts are inferred from ``other_ctx``:

        - If none are found, the default context is returned.
        - Otherwise they must share one backend family. The result uses that
          backend, the promoted dtype and the minimum check level.

        Raises
        ------
        ValueError
            If the inferred contexts belong to different backend families.
        """
        if priority_ctx is not None:
            return self.normalize_context(priority_ctx)
        inferred = self.infer_contexts(other_ctx)
        if not inferred:
            return self.default_ctx
        if not self.are_compatible_contexts(*inferred):
            fams = tuple(ctx.ops.family for ctx in inferred)
            raise ValueError(f"Incompatible inferred contexts: {fams!r}")
        first = inferred[0]
        ops = type(first.ops)()
        dtype = self._join_dtypes(ops, *(ctx.dtype for ctx in inferred))
        return self.ctx_from_ops(
            ops=ops,
            dtype=dtype,
            check_level=minimum_check_level(tuple(ctx.check_level for ctx in inferred)),
        )

    def _join_dtypes(self, ops: BackendOps, *dtypes: DType | None) -> DType | None:
        """Promote ``dtypes`` (ignoring ``None``) with ``ops``' own namespace.

        With no non-None dtype, ``ops.sanitize_dtype(None)`` is returned.
        """
        clean = [ops.sanitize_dtype(dt) for dt in dtypes if dt is not None]
        if not clean:
            return ops.sanitize_dtype(None)
        # Promote through the operands' OWN backend namespace. NumPy's
        # ``result_type`` cannot interpret a torch/jax dtype, so joining the
        # inferred contexts of a non-NumPy operator (for example a
        # ``BlockDiagonalLinOp`` built with ``from_operators``, which infers a
        # ``TreeSpace`` and joins its leaf dtypes) would otherwise raise
        # ``TypeError: Cannot interpret 'torch.float64' as a data type``.
        joined = ops.xp.result_type(*clean)
        return ops.sanitize_dtype(joined)
_contextual: Contextual | None = None


def _state() -> Contextual:
    """Return the process-wide ``Contextual`` instance, creating it on first use.

    The instance is cached in the module-level ``_contextual`` variable, so
    every later call returns the same object.
    """
    global _contextual
    if _contextual is not None:
        return _contextual
    _contextual = Contextual()
    return _contextual
[docs]
def set_context(
    ctx: Context | BackendFamily | str | None = None,
    dtype: Any = None,
    enable_checks: bool | None = None,
    *,
    check_level: CheckLevel | None = None,
) -> None:
    """
    Replace the process-wide default SpaceCore context.

    The specification is normalized with the shared state's
    ``normalize_context`` and the result is stored as its default context.

    Parameters
    ----------
    ctx : Context, BackendFamily, str, or None, optional
        Context or backend specification.
    dtype : Any, optional
        Default dtype override.
    enable_checks : bool or None, optional
        Deprecated Boolean validation override.
    check_level : CheckLevel or None, optional
        Validation policy override for backend-name contexts.
    """
    contextual = _state()
    options = {
        "dtype": dtype,
        "enable_checks": enable_checks,
        "check_level": check_level,
    }
    resolved = contextual.normalize_context(ctx, **options)
    contextual.default_ctx = resolved
[docs]
def get_context() -> Context:
    """
    Fetch the process-wide default SpaceCore context.

    Returns
    -------
    Context
        The default context currently held by the shared state.
    """
    contextual = _state()
    current = contextual.default_ctx
    return current
[docs]
def resolve_context_priority(
    priority_ctx: Context | BackendFamily | str | None = None,
    *other_ctx: object,
) -> Context:
    """
    Pick the context for a newly created object via the shared state.

    Parameters
    ----------
    priority_ctx : Context, BackendFamily, str, or None, optional
        Explicit context; wins whenever it is not ``None``.
    *other_ctx : object
        Objects or contexts consulted as fallbacks when no explicit context
        is given.

    Returns
    -------
    Context
        Resolved context.
    """
    contextual = _state()
    return contextual.resolve_context_priority(priority_ctx, *other_ctx)
[docs]
def register_ops(ops: type[BackendOps]) -> type[BackendOps]:
    """
    Add a backend operations class to the shared registry.

    Because the class is returned unchanged, this can decorate a
    ``BackendOps`` subclass definition.

    Parameters
    ----------
    ops : type of BackendOps
        Backend operations class to register.

    Returns
    -------
    type of BackendOps
        The same class that was passed in.
    """
    registry = _state()
    registered = registry.register_ops(ops)
    return registered
[docs]
def normalize_context(
    ctx: Context | BackendFamily | str | None = None,
    dtype: Any = None,
    enable_checks: bool | None = None,
    *,
    check_level: CheckLevel | None = None,
) -> Context:
    """
    Convert a context specification into a ``Context`` using the shared state.

    Parameters
    ----------
    ctx : Context, BackendFamily, str, or None, optional
        Context or backend specification.
    dtype : Any, optional
        Default dtype override.
    enable_checks : bool or None, optional
        Deprecated Boolean validation override.
    check_level : CheckLevel or None, optional
        Validation policy override for backend-name contexts.

    Returns
    -------
    Context
        Normalized context.
    """
    options = {
        "dtype": dtype,
        "enable_checks": enable_checks,
        "check_level": check_level,
    }
    return _state().normalize_context(ctx, **options)
[docs]
def normalize_ops(ops: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> BackendOps:
    """
    Resolve a backend operations specification to a ``BackendOps`` instance.

    An existing ``BackendOps`` instance is returned unchanged. Any other
    specification is looked up through the shared state's ``get_ops``.

    Parameters
    ----------
    ops : str, BackendFamily, BackendOps, type of BackendOps, or Context
        Backend operations specification.

    Returns
    -------
    BackendOps
        Backend operations instance.
    """
    if not isinstance(ops, BackendOps):
        return _state().get_ops(ops)
    return ops
def enforce_convert_policy(
    x: Any,
    to: Context | BackendFamily | str | None = None,
) -> tuple[Any, Context]:
    """Resolve a conversion target context.

    Delegates to the shared state's ``enforce_convert_policy`` and returns
    its ``(value, context)`` pair unchanged.
    """
    contextual = _state()
    return contextual.enforce_convert_policy(x, to)