Source code for spacecore.space.checks._base

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar

from ..._check_policy import CheckLevel, check_level_at_least, normalize_check_level


class SpaceValidationError(ValueError, TypeError):
    """Signal that a value failed a space membership check.

    The exception derives from both :class:`ValueError` and
    :class:`TypeError`, so a handler for either built-in also catches it.
    """
def _shape_of(space: Any, x: Any) -> tuple[int, ...] | None:
    """Best-effort lookup of the shape of ``x``.

    ``space.ops.shape`` is tried first. If that call (or converting its
    result to a tuple) raises anything, fall back to an ``x.shape``
    attribute. ``None`` is returned when neither source yields a shape.
    """
    try:
        backend_shape = space.ops.shape(x)
        return tuple(backend_shape)
    except Exception:
        # Deliberate best-effort: any backend failure means "try the attribute".
        fallback = getattr(x, "shape", None)
        if fallback is None:
            return None
        return tuple(fallback)


def _dtype_of(space: Any, x: Any) -> Any:
    """Best-effort lookup of the dtype of ``x``.

    ``space.ops.get_dtype`` is tried first. On any exception the ``x.dtype``
    attribute is used instead, or ``None`` when that attribute is absent.
    """
    try:
        return space.ops.get_dtype(x)
    except Exception:
        # Deliberate best-effort: fall through to the attribute lookup below.
        pass
    return getattr(x, "dtype", None)
@dataclass(frozen=True)
class SpaceCheck(ABC):
    """
    Base class for one membership test applied to :class:`Space` objects.

    Subclasses supply :meth:`is_valid` and :meth:`error_message`. The class
    attributes control how :meth:`validate` treats the value's shape.

    Parameters
    ----------
    name : str
        Label for this check, used in diagnostics.
    """

    name: str

    # Number of trailing axes that form a single element for this check.
    core_rank: ClassVar[int] = 0
    # When True, ``validate`` requires the trailing axes to equal ``core_shape``.
    enforce_core_shape: ClassVar[bool] = False
    # Level passed to ``check_level_at_least`` when ``_run_checks`` decides
    # whether to run this check.
    minimum_level: ClassVar[CheckLevel] = "standard"

    def __call__(self, space: Any, x: Any) -> None:
        """Validate ``x`` as a single, unbatched member of ``space``.

        Raises
        ------
        SpaceValidationError
            If :meth:`validate` rejects ``x`` with ``allow_leading=False``.
        """
        if self.validate(space, x, allow_leading=False):
            return
        message = self.validation_message(space, x, allow_leading=False)
        raise SpaceValidationError(message)

    def core_shape(self, space: Any) -> tuple[int, ...]:
        """Give the trailing ``core_rank`` axes of ``space.shape``.

        An empty tuple is returned when ``core_rank`` is zero, without
        touching ``space.shape``.
        """
        rank = self.core_rank
        if rank == 0:
            return ()
        return tuple(space.shape)[-rank:]

    def leading_dims(self, x: Any, space: Any) -> tuple[int, ...] | None:
        """Give the axes of ``x`` that precede this check's core axes.

        Returns ``None`` when the shape of ``x`` is unknown or has fewer
        axes than the core shape. With an empty core shape, the whole shape
        counts as leading.
        """
        shape = _shape_of(space, x)
        if shape is None:
            return None
        rank = len(self.core_shape(space))
        if rank == 0:
            return shape
        if rank > len(shape):
            return None
        return shape[:-rank]

    def validate(self, space: Any, x: Any, *, allow_leading: bool) -> bool:
        """Decide validity of ``x`` under the member or batched shape policy.

        When ``enforce_core_shape`` is set, the trailing axes of ``x`` must
        equal :meth:`core_shape`. Extra leading axes are rejected unless
        ``allow_leading`` is true. :meth:`is_valid` gives the final verdict.
        """
        if self.enforce_core_shape:
            shape = _shape_of(space, x)
            expected = self.core_shape(space)
            rank = len(expected)
            if shape is None:
                return False
            if rank and (len(shape) < rank or shape[-rank:] != expected):
                return False
            if not allow_leading and self.leading_dims(x, space) != ():
                return False
        return self.is_valid(space, x)

    def validation_message(self, space: Any, x: Any, *, allow_leading: bool) -> str:
        """Build the diagnostic for a rejected value.

        The base implementation ignores ``allow_leading`` and delegates to
        :meth:`error_message`.
        """
        return self.error_message(space, x)

    @abstractmethod
    def is_valid(self, space: Any, x: Any) -> bool:
        """Report whether ``x`` satisfies this check for ``space``."""
        ...

    @abstractmethod
    def error_message(self, space: Any, x: Any) -> str:
        """Describe why ``x`` fails this check."""
        ...
@dataclass(frozen=True)
class BackendCheck(SpaceCheck):
    """
    Verify that a value is a dense array for the space's backend.

    Parameters
    ----------
    name : str, optional
        Identifier for the check. Defaults to ``"backend"``.
    """

    name: str = "backend"
    core_rank: ClassVar[int] = 0
    minimum_level: ClassVar[CheckLevel] = "cheap"

    def is_valid(self, space: Any, x: Any) -> bool:
        """Report whether ``space.ops.is_dense`` accepts ``x``."""
        dense = space.ops.is_dense(x)
        return bool(dense)

    def error_message(self, space: Any, x: Any) -> str:
        """Name the expected backend family and the received type."""
        family = space.ops.family
        received = type(x).__name__
        return f"Expected dense array for {family}, got {received}"
@dataclass(frozen=True)
class ShapeCheck(SpaceCheck):
    """
    Verify that a value carries the canonical shape of a space.

    Parameters
    ----------
    name : str, optional
        Identifier for the check. Defaults to ``"shape"``.
    """

    name: str = "shape"
    enforce_core_shape: ClassVar[bool] = True
    minimum_level: ClassVar[CheckLevel] = "cheap"

    def core_shape(self, space: Any) -> tuple[int, ...]:
        """Treat the full ``space.shape`` as the trailing element shape."""
        return tuple(space.shape)

    def is_valid(self, space: Any, x: Any) -> bool:
        """Report whether the trailing axes of ``x`` equal ``space.shape``.

        An unknown shape is invalid. An empty canonical shape accepts any
        known shape.
        """
        shape = _shape_of(space, x)
        if shape is None:
            return False
        expected = self.core_shape(space)
        rank = len(expected)
        return rank == 0 or (len(shape) >= rank and shape[-rank:] == expected)

    def error_message(self, space: Any, x: Any) -> str:
        """Show the expected shape next to the observed one."""
        expected = tuple(space.shape)
        actual = _shape_of(space, x)
        return f"Expected shape {expected}, got {actual}"

    def validation_message(self, space: Any, x: Any, *, allow_leading: bool) -> str:
        """Use batched wording when leading axes are allowed."""
        if not allow_leading:
            return self.error_message(space, x)
        return (
            f"Batched value trailing shape must be {tuple(space.shape)}, "
            f"got {_shape_of(space, x)}."
        )
@dataclass(frozen=True)
class DTypeCheck(SpaceCheck):
    """
    Check that a value has the dtype required by a space context.

    Parameters
    ----------
    name : str, optional
        Check name. Default is ``"dtype"``.
    """

    name: str = "dtype"
    core_rank: ClassVar[int] = 0
    minimum_level: ClassVar[CheckLevel] = "cheap"

    def is_valid(self, space: Any, x: Any) -> bool:
        """Return whether the dtype of ``x`` equals ``space.dtype``.

        A value whose dtype cannot be determined only matches a space whose
        ``dtype`` is itself ``None``.
        """
        actual = _dtype_of(space, x)
        if actual is None:
            # Do not evaluate ``None == space.dtype``: Python then falls back
            # to ``space.dtype.__eq__(None)``, and dtype types that coerce
            # their operand (e.g. NumPy's) may treat ``None`` as the default
            # float dtype and report a spurious match.
            return space.dtype is None
        return actual == space.dtype

    def error_message(self, space: Any, x: Any) -> str:
        """Return a diagnostic naming the expected and observed dtypes."""
        return f"Expected dtype {space.dtype}, got {_dtype_of(space, x)}"
@dataclass(frozen=True)
class FieldCheck(SpaceCheck):
    """
    Verify that a value fits the scalar field of a space.

    A complex-field space accepts any known dtype. Otherwise the dtype must
    not be complex according to ``space.ops.is_complex_dtype``.

    Parameters
    ----------
    name : str, optional
        Identifier for the check. Defaults to ``"field"``.
    """

    name: str = "field"
    core_rank: ClassVar[int] = 0
    minimum_level: ClassVar[CheckLevel] = "cheap"

    def is_valid(self, space: Any, x: Any) -> bool:
        """Report whether the dtype of ``x`` is allowed by ``space.field``."""
        dtype = _dtype_of(space, x)
        if dtype is None:
            return False
        if space.field == "complex":
            return True
        return not space.ops.is_complex_dtype(dtype)

    def error_message(self, space: Any, x: Any) -> str:
        """Mention the required field and the dtype that was seen."""
        field = space.field
        actual = _dtype_of(space, x)
        return (
            f"Expected an element compatible with the {field} scalar field, "
            f"got dtype {actual}"
        )
@dataclass(frozen=True)
class SquareMatrixCheck(SpaceCheck):
    """
    Verify that the last two axes of a value have equal length.

    Parameters
    ----------
    name : str, optional
        Identifier for the check. Defaults to ``"square_matrix"``.
    """

    name: str = "square_matrix"
    core_rank: ClassVar[int] = 2
    enforce_core_shape: ClassVar[bool] = True
    minimum_level: ClassVar[CheckLevel] = "cheap"

    def is_valid(self, space: Any, x: Any) -> bool:
        """Report whether ``x`` has a known shape with square trailing axes."""
        shape = _shape_of(space, x)
        if shape is None or len(shape) < 2:
            return False
        rows, cols = shape[-2], shape[-1]
        return cols == rows

    def error_message(self, space: Any, x: Any) -> str:
        """Report the observed shape of the rejected value."""
        observed = _shape_of(space, x)
        return f"Expected square matrix, got shape {observed}"
@dataclass(frozen=True)
class HermitianCheck(SpaceCheck):
    """
    Check that a value is Hermitian within tolerances.

    The comparison is between ``x`` and the conjugate of ``x`` with its last
    two axes swapped, via ``space.ops.allclose``.

    Parameters
    ----------
    name : str, optional
        Check name. Default is ``"hermitian"``.
    atol : float, optional
        Absolute tolerance for Hermitian comparison. Default is ``1e-8``.
    rtol : float, optional
        Relative tolerance for Hermitian comparison. Default is ``1e-8``.
    enforce : bool, optional
        Whether to enforce the Hermitian comparison. When False,
        :meth:`is_valid` always returns True. Default is True.
    """

    name: str = "hermitian"
    core_rank: ClassVar[int] = 2
    enforce_core_shape: ClassVar[bool] = True
    minimum_level: ClassVar[CheckLevel] = "standard"
    atol: float = 1e-8
    rtol: float = 1e-8
    enforce: bool = True

    def is_valid(self, space: Any, x: Any) -> bool:
        """Return whether ``x`` equals its conjugate transpose within tolerance.

        A value whose known shape has fewer than two axes, or whose last two
        axes differ in length, cannot be Hermitian. It is rejected here
        rather than letting ``swapaxes``/``allclose`` raise.
        """
        if not self.enforce:
            return True
        shape = _shape_of(space, x)
        # Only reject when the shape is known; an unknown shape keeps the
        # original behaviour of deferring entirely to the backend ops.
        if shape is not None and (len(shape) < 2 or shape[-1] != shape[-2]):
            return False
        ops = space.ops
        x_adj = ops.conj(ops.swapaxes(x, -1, -2))
        return bool(ops.allclose(x, x_adj, atol=self.atol, rtol=self.rtol))

    def error_message(self, space: Any, x: Any) -> str:
        """Return a diagnostic with the tolerances and the observed shape."""
        return (
            "Expected Hermitian matrix; input is not Hermitian. "
            "Expected x satisfying x = x.conj().T "
            f"within atol={self.atol} and "
            f"rtol={self.rtol}. "
            f"Got shape {_shape_of(space, x)}."
        )
def _run_checks(space: Any, x: Any, *, allow_leading: bool) -> None:
    """Apply each applicable membership check of ``space`` to ``x``.

    The space's ``check_level`` (``"standard"`` when absent) is normalised
    once. Checks for which ``check_level_at_least(level, check.minimum_level)``
    is false are skipped. The first check whose ``validate`` fails aborts
    the run.

    Raises
    ------
    SpaceValidationError
        Carrying the failing check's ``validation_message``.
    """
    level = normalize_check_level(getattr(space, "check_level", "standard"))
    # Lazy filter: each check's level test still runs immediately before its
    # validation, in the order returned by ``member_checks``.
    applicable = (
        check
        for check in space.member_checks()
        if check_level_at_least(level, check.minimum_level)
    )
    for check in applicable:
        if check.validate(space, x, allow_leading=allow_leading):
            continue
        raise SpaceValidationError(
            check.validation_message(space, x, allow_leading=allow_leading)
        )