from __future__ import annotations
from typing import Any, Callable, Literal, Sequence, Tuple, Type, cast
import numpy as np
from .._eager import EagerControlFlowMixin
from .._family import BackendFamily
from .._ops import BackendOps, LazyNamespace
from ...types import ArrayLike, DenseArray, DType, Index, SparseArray
[docs]
class TorchOps(EagerControlFlowMixin, BackendOps):
"""
BackendOps implementation for PyTorch tensors.

This backend uses PyTorch for dense and sparse tensor operations.

Array types
-----------
Dense arrays
    ``torch.Tensor`` with the strided layout (``torch.strided``).
Sparse arrays
    ``torch.Tensor`` whose layout is one of the PyTorch sparse layouts
    listed in ``_sparse_layouts`` (COO, CSR, CSC, BSR, BSC).

Methods
-------
Most methods mirror the corresponding PyTorch public API signatures and
delegate to ``torch`` or ``torch.linalg``. Backend-specific behavior,
dtype promotion, broadcasting, device placement, autograd tracking, and
error modes therefore follow PyTorch semantics.

Backend handles
---------------
torch : module
    PyTorch module stored on the class and available through instances
    as ``ops.torch``. Advanced users may use it when SpaceCore's
    portable API does not expose a required PyTorch feature.

Notes
-----
Code intended to remain backend-portable should prefer ``BackendOps``
methods. Direct use of ``ops.torch`` is an explicit PyTorch-specific
escape hatch.

``TorchOps`` follows PyTorch dtype defaults. When no dtype is provided,
``sanitize_dtype(None)`` returns ``torch.get_default_dtype()``. Python
``complex`` maps to ``torch.complex64`` or ``torch.complex128`` based
on the active default floating dtype, and NumPy dtype specifiers are
mapped to their corresponding PyTorch dtypes when supported.

Array creation and conversion methods may accept a backend-specific
``device=`` keyword. Existing tensors stay on their device unless an
explicit device conversion is requested. Dense conversion and ordinary
math operations do not detach tensors; autograd metadata is preserved
according to normal PyTorch rules.
"""
# The import runs in the class body, so ``_torch`` is bound in the class
# namespace; the public alias is ``torch`` just below.
import torch as _torch
# Concrete library handle exposed as ``Any`` so the portable protocols
# (DenseArray/SparseArray) can be passed to typed PyTorch calls without a
# cast at every boundary; mirrors the base ``xp: ClassVar[Any]`` design.
torch: Any = _torch
# Array-API-compatible namespace for PyTorch. ``LazyNamespace`` presumably
# defers importing ``array_api_compat.torch`` until first use — confirm.
xp = LazyNamespace("array_api_compat.torch")
# Lower-cased value of ``BackendFamily.torch``, used as this backend's family tag.
_family = BackendFamily.torch.value.lower()
# Marks this backend as supporting sparse arrays (flag consumed outside this view).
_allow_sparse = True
# Layouts that ``is_sparse`` treats as sparse. Class-body statements can use
# ``torch`` here because it was bound in the class namespace above.
_sparse_layouts = (
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
)
def __init__(self) -> None:
super().__init__()
@staticmethod
def _defined_kwargs(**kwargs: Any) -> dict[str, Any]:
return {key: value for key, value in kwargs.items() if value is not None}
@property
def dense_array(self) -> Type[Any]:
"""
Dense array type using PyTorch.
Returns
-------
Concrete dense tensor class accepted by this backend.
See:
https://docs.pytorch.org/docs/stable/tensors.html
"""
return self.torch.Tensor
@property
def sparse_array(self) -> Tuple[Type[Any], ...]:
"""
Sparse array type tuple using PyTorch.
Returns
-------
Tensor class accepted by this backend for sparse tensor layouts.
See:
https://docs.pytorch.org/docs/stable/sparse.html
"""
return (self.torch.Tensor,)
[docs]
def is_dense(self, x: Any) -> bool:
"""
Check whether an object is a dense PyTorch tensor.
Input:
x: Object to inspect.
Output:
Boolean indicating whether x is a strided PyTorch tensor.
See:
https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-layout
"""
return isinstance(x, self.torch.Tensor) and x.layout == self.torch.strided
[docs]
def is_sparse(self, x: Any) -> bool:
"""
Check whether an object is a sparse PyTorch tensor.
Input:
x: Object to inspect.
Output:
Boolean indicating whether x is a PyTorch tensor with a sparse layout.
See:
https://docs.pytorch.org/docs/stable/sparse.html
"""
return isinstance(x, self.torch.Tensor) and x.layout in self._sparse_layouts
[docs]
def sanitize_dtype(self, dtype: DType | None) -> DType:
"""
Normalize a dtype specifier using PyTorch.
Input:
dtype: Optional dtype requested by SpaceCore or the caller.
Output:
Backend dtype object accepted by PyTorch tensor constructors.
See:
https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
Backend-specific notes:
``None`` follows ``torch.get_default_dtype()``. NumPy dtype
specifiers are mapped to equivalent PyTorch dtypes when supported.
"""
if dtype is None:
return self.torch.get_default_dtype()
if isinstance(dtype, self.torch.dtype):
return dtype
if dtype is float:
return self.torch.get_default_dtype()
if dtype is complex:
return (
self.torch.complex128
if self.torch.get_default_dtype() == self.torch.float64
else self.torch.complex64
)
if dtype is int:
return self.torch.int64
if dtype is bool:
return self.torch.bool
try:
np_dtype = np.dtype(dtype)
except Exception as e:
raise TypeError(f"Invalid dtype specifier for PyTorch: {dtype!r}.") from e
mapping = {
np.dtype("bool"): self.torch.bool,
np.dtype("uint8"): self.torch.uint8,
np.dtype("int8"): self.torch.int8,
np.dtype("int16"): self.torch.int16,
np.dtype("int32"): self.torch.int32,
np.dtype("int64"): self.torch.int64,
np.dtype("float16"): self.torch.float16,
np.dtype("float32"): self.torch.float32,
np.dtype("float64"): self.torch.float64,
np.dtype("complex64"): self.torch.complex64,
np.dtype("complex128"): self.torch.complex128,
}
if np_dtype in mapping:
return mapping[np_dtype]
raise TypeError(f"Dtype {np_dtype!r} is not supported by PyTorch.")
[docs]
def assparse(
    self,
    x: Any,
    *,
    format: Literal["coo", "csr", "csc"] = "coo",
    dtype: DType | None = None,
    device: Any | None = None,
) -> SparseArray:
    """
    Convert input to a sparse tensor using PyTorch.
    Input:
    x: Dense, sparse, or SciPy sparse input plus sparse format, dtype, and device.
    Output:
    Sparse backend tensor in COO, CSR, or CSC format. COO results are always
    coalesced, whether or not ``x`` was already a sparse tensor.
    Raises:
    ValueError: If ``format`` is not ``"coo"``, ``"csr"`` or ``"csc"``; this is
    checked before any conversion work is done.
    See:
    https://docs.pytorch.org/docs/stable/sparse.html
    Backend-specific notes:
    SciPy sparse inputs are converted through COO indices and values.
    Dense inputs are converted through PyTorch's sparse COO conversion.
    The COO intermediate is coalesced before conversion to CSR/CSC so that
    duplicate SciPy entries are summed first.
    """
    # Validate up front so a bad format fails before any conversion work.
    if format not in ("coo", "csr", "csc"):
        raise ValueError(f"Unknown sparse format: {format!r}")
    self._reject_complex_to_real(x, dtype, operation="assparse")
    dtype = self.sanitize_dtype(dtype) if dtype is not None else None

    def _to_format(t: Any) -> SparseArray:
        # Convert a sparse tensor to the requested layout (format already validated).
        if format == "coo":
            return t.to_sparse_coo().coalesce()
        if format == "csr":
            return t.to_sparse_csr()
        return t.to_sparse_csc()

    if self.is_sparse(x):
        needs_move = dtype is not None or device is not None
        return _to_format(x.to(dtype=dtype, device=device) if needs_move else x)
    try:
        import scipy.sparse as sps
    except Exception:
        # SciPy is optional; without it, non-tensor inputs take the dense path.
        sps = None
    if sps is not None and sps.issparse(x):
        scipy_coo = cast(Any, x).tocoo()
        indices = self.torch.as_tensor(
            np.vstack((scipy_coo.row, scipy_coo.col)),
            dtype=self.torch.int64,
            device=device,
        )
        values = self.torch.as_tensor(scipy_coo.data, dtype=dtype, device=device)
        coo = self.torch.sparse_coo_tensor(indices, values, scipy_coo.shape, device=device)
    else:
        coo = cast(Any, self.asarray(x, dtype=dtype, device=device)).to_sparse_coo()
    # Coalesce before any compressed conversion so duplicate entries are merged.
    return _to_format(coo.coalesce())
[docs]
def asarray(
    self,
    x: Any,
    dtype: DType | None = None,
    *,
    device: Any | None = None,
    copy: bool | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    **extra_kwargs: Any,
) -> DenseArray:
    """
    Convert input to a dense tensor using PyTorch.
    Input:
    x: Array-like, dense tensor, or sparse tensor; dtype, device, copy: conversion
    controls; backend_kwargs / extra_kwargs: extra keywords forwarded to
    ``torch.as_tensor`` (``device`` overrides a ``"device"`` entry there).
    Output:
    Dense (strided) backend tensor.
    See:
    https://docs.pytorch.org/docs/stable/generated/torch.as_tensor.html
    Backend-specific notes:
    Sparse tensors are densified with ``to_dense()`` first. ``copy=True`` clones
    the result; otherwise the tensor produced by ``torch.as_tensor`` is returned
    and may share memory with ``x``.
    """
    self._reject_complex_to_real(x, dtype, operation="asarray")
    options: dict[str, Any] = {}
    if backend_kwargs is not None:
        options.update(backend_kwargs)
    options.update(extra_kwargs)
    if device is not None:
        options["device"] = device
    target = None if dtype is None else self.sanitize_dtype(dtype)
    source = x.to_dense() if self.is_sparse(x) else x
    tensor = self.torch.as_tensor(source, dtype=target, **options)
    if copy:
        return tensor.clone()
    return tensor
[docs]
def astype(
    self,
    x: DenseArray,
    dtype: DType | None,
    *,
    copy: bool = True,
    non_blocking: bool = False,
    memory_format: Any | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    **extra_kwargs: Any,
) -> DenseArray:
    """
    Cast a dense tensor to another dtype using PyTorch.
    Input:
    x: Dense backend tensor; dtype: Target dtype (``None`` returns ``x`` unchanged);
    copy, non_blocking, memory_format: Forwarded to ``Tensor.to``;
    backend_kwargs / extra_kwargs: Extra keywords for ``Tensor.to``.
    Output:
    Dense backend tensor with the requested dtype.
    See:
    https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html
    """
    if dtype is None:
        return x
    self._reject_complex_to_real(x, dtype, operation="astype")
    options: dict[str, Any] = {}
    if backend_kwargs is not None:
        options.update(backend_kwargs)
    options.update(extra_kwargs)
    if memory_format is not None:
        options["memory_format"] = memory_format
    target = self.sanitize_dtype(dtype)
    tensor: Any = x
    return tensor.to(dtype=target, non_blocking=non_blocking, copy=copy, **options)
[docs]
def empty(
self,
shape: Tuple[int, ...],
dtype: DType | None = None,
*,
out: DenseArray | None = None,
layout: Any | None = None,
device: Any | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
memory_format: Any | None = None,
) -> DenseArray:
return self.torch.empty(
shape,
out=out,
dtype=self.sanitize_dtype(dtype) if dtype is not None else None,
requires_grad=requires_grad,
pin_memory=pin_memory,
**self._defined_kwargs(layout=layout, device=device, memory_format=memory_format),
)
[docs]
def zeros(
self,
shape: Tuple[int, ...],
dtype: DType | None = None,
*,
out: DenseArray | None = None,
layout: Any | None = None,
device: Any | None = None,
requires_grad: bool = False,
) -> DenseArray:
return self.torch.zeros(
shape,
out=out,
dtype=self.sanitize_dtype(dtype) if dtype is not None else None,
requires_grad=requires_grad,
**self._defined_kwargs(layout=layout, device=device),
)
[docs]
def zeros_like(
self,
x: DenseArray,
dtype: DType | None = None,
*,
layout: Any | None = None,
device: Any | None = None,
requires_grad: bool = False,
memory_format: Any | None = None,
) -> DenseArray:
return self.torch.zeros_like(
x,
dtype=self.sanitize_dtype(dtype) if dtype is not None else None,
requires_grad=requires_grad,
**self._defined_kwargs(layout=layout, device=device, memory_format=memory_format),
)
[docs]
def arange(
self,
start: int,
stop: int | None = None,
step: int | None = None,
dtype: DType | None = None,
*,
out: DenseArray | None = None,
layout: Any | None = None,
device: Any | None = None,
requires_grad: bool = False,
) -> DenseArray:
kwargs = self._defined_kwargs(out=out, layout=layout, device=device)
kwargs["requires_grad"] = requires_grad
dtype = self.sanitize_dtype(dtype) if dtype is not None else None
if stop is None:
return self.torch.arange(start, dtype=dtype, **kwargs)
if step is None:
return self.torch.arange(start, stop, dtype=dtype, **kwargs)
return self.torch.arange(start, stop, step, dtype=dtype, **kwargs)
[docs]
def sum(
    self,
    x: DenseArray,
    axis: int | Sequence[int] | None = None,
    keepdims: bool = False,
    dtype: DType | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Sum tensor elements over the given axes using PyTorch.
    Input:
    x: Dense backend tensor; axis: Axes to reduce (normalised by
    ``_to_axis_tuple``); keepdims: Keep reduced axes with length 1;
    dtype: Optional accumulation/result dtype; out: Optional destination.
    Output:
    Dense backend tensor.
    See:
    https://docs.pytorch.org/docs/stable/generated/torch.sum.html
    """
    reduce_dims = self._to_axis_tuple(axis)
    extra: dict[str, Any] = {}
    if dtype is not None:
        extra["dtype"] = self.sanitize_dtype(dtype)
    if out is not None:
        extra["out"] = out
    return self.torch.sum(x, dim=reduce_dims, keepdim=keepdims, **extra)
[docs]
def matmul(
self,
a: DenseArray,
b: DenseArray,
backend_kwargs: dict[str, Any] | None = None,
*,
out: DenseArray | None = None,
) -> DenseArray:
kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
if out is not None:
kwargs["out"] = out
return self.torch.matmul(a, b, **kwargs)
[docs]
def sparse_matmul(
self,
a: SparseArray,
b: DenseArray,
*,
reduce: Literal["sum", "mean", "amax", "amin"] = "sum",
) -> DenseArray:
"""
Matrix-multiply a sparse tensor by a dense tensor using PyTorch.
Input:
a: Sparse backend tensor; b: Dense backend tensor or vector.
Output:
Dense backend tensor.
See:
https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html
"""
kwargs = {"reduce": reduce} if reduce != "sum" else {}
if b.ndim == 1:
return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0]
return self.torch.sparse.mm(a, b, **kwargs)
[docs]
def vmap(
self,
fn: Callable,
in_axes: int | Sequence[int | None] | None = 0,
out_axes: int | Sequence[int | None] | None = 0,
) -> Callable:
"""Vectorize a function using PyTorch's native vmap when available."""
vmap = getattr(self.torch, "vmap", None)
if vmap is None and hasattr(self.torch, "func"):
vmap = getattr(self.torch.func, "vmap", None)
if vmap is None:
return super().vmap(fn, in_axes=in_axes, out_axes=out_axes)
return vmap(fn, in_dims=in_axes, out_dims=out_axes)
@property
def has_native_vmap(self) -> bool:
"""Return ``True`` because supported PyTorch versions provide native ``vmap``."""
return True
[docs]
def eigh(
self,
x: DenseArray,
backend_kwargs: dict[str, Any] | None = None,
UPLO: Literal["L", "U"] = "L",
*,
out: tuple[DenseArray, DenseArray] | None = None,
) -> tuple[DenseArray, DenseArray]:
if self.is_sparse(x):
raise TypeError("eigh requires a dense array; sparse input is not supported.")
kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
kwargs.update(self._defined_kwargs(out=out))
return self.torch.linalg.eigh(x, UPLO=UPLO, **kwargs)
[docs]
def norm(
self,
x: DenseArray,
ord: int | str | None = None,
axis: int | Sequence[int] | None = None,
keepdims: bool = False,
*,
dtype: DType | None = None,
out: DenseArray | None = None,
) -> DenseArray:
return self.torch.linalg.norm(
x,
ord=ord,
dim=axis,
keepdim=keepdims,
dtype=self.sanitize_dtype(dtype) if dtype is not None else None,
out=out,
)
[docs]
def solve(
self,
A: DenseArray,
b: DenseArray,
backend_kwargs: dict[str, Any] | None = None,
*,
left: bool = True,
out: DenseArray | None = None,
) -> DenseArray:
kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
kwargs.update(self._defined_kwargs(out=out))
return self.torch.linalg.solve(A, b, left=left, **kwargs)
[docs]
def svd(
self,
A: DenseArray,
full_matrices: bool = True,
backend_kwargs: dict[str, Any] | None = None,
*,
driver: str | None = None,
out: DenseArray | tuple[DenseArray, DenseArray, DenseArray] | None = None,
) -> tuple[DenseArray, DenseArray, DenseArray]:
kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
kwargs.update(self._defined_kwargs(driver=driver, out=out))
return self.torch.linalg.svd(A, full_matrices=full_matrices, **kwargs)
[docs]
def cholesky(
self,
A: DenseArray,
backend_kwargs: dict[str, Any] | None = None,
*,
upper: bool = False,
out: DenseArray | None = None,
) -> DenseArray:
kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
kwargs.update(self._defined_kwargs(out=out))
return self.torch.linalg.cholesky(A, upper=upper, **kwargs)
[docs]
def logsumexp(
    self,
    a: DenseArray,
    axis: int | Sequence[int] | None = None,
    b: DenseArray | None = None,
    keepdims: bool = False,
    return_sign: bool = False,
    *,
    out: DenseArray | None = None,
) -> DenseArray | tuple[DenseArray, DenseArray]:
    """
    Compute log-sum-exp using PyTorch.
    Input:
    a: Dense backend tensor; axis, b, keepdims, return_sign: Reduction controls.
    Output:
    Dense backend tensor, or ``(value, sign)`` when ``return_sign`` is true.
    See:
    https://docs.pytorch.org/docs/stable/generated/torch.logsumexp.html
    Backend-specific notes:
    Weighted and signed variants are implemented in SpaceCore because
    PyTorch's public ``logsumexp`` does not expose SciPy-style ``b`` or
    ``return_sign`` parameters. They follow ``scipy.special.logsumexp``:
    ``a`` and ``b`` are broadcast together; entries with zero weight are kept
    out of the stabilising max; a non-finite max is replaced by 0 so an
    all ``-inf`` slice yields ``-inf`` (not NaN); and without ``return_sign``
    a negative weighted sum yields NaN. ``out`` is only written when
    ``return_sign`` is false.
    """
    torch = self.torch
    if b is None and not return_sign:
        dim = tuple(range(a.ndim)) if axis is None else axis
        return torch.logsumexp(a, dim=dim, keepdim=keepdims, out=out)
    if b is None:
        weights = self.ones_like(a)
    else:
        weights = b if isinstance(b, torch.Tensor) else torch.as_tensor(b, device=a.device)
        # Broadcast like SciPy so the reduction axes refer to the common shape.
        a, weights = torch.broadcast_tensors(a, weights)
        if a.is_floating_point():
            # Zero-weight entries contribute nothing; keeping them out of the
            # max stops a large masked entry from underflowing every other term.
            a = torch.where(weights == 0, torch.full_like(a, float("-inf")), a)
    dim = tuple(range(a.ndim)) if axis is None else axis
    m = torch.amax(a, dim=dim, keepdim=True)
    # A +/-inf (or NaN) max would turn ``a - m`` into NaN; shift by 0 instead.
    m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    total = self.sum(weights * torch.exp(a - m), axis=dim, keepdims=True)
    if return_sign:
        sign = torch.sign(total)
        result = torch.log(torch.abs(total)) + m
        if not keepdims:
            result = self.squeeze(result, axis)
            sign = self.squeeze(sign, axis)
        return result, sign
    # Without sign information a negative sum has no real logarithm: NaN.
    result = torch.log(total) + m
    if not keepdims:
        result = self.squeeze(result, axis)
    if out is not None:
        cast(Any, out).copy_(result)
        return out
    return result
[docs]
def where(
self,
condition: DenseArray | bool,
x: ArrayLike,
y: ArrayLike,
*,
out: DenseArray | None = None,
) -> DenseArray:
if out is None:
return self.torch.where(condition, x, y)
return self.torch.where(condition, x, y, out=out)
[docs]
def concatenate(
    self,
    arrays: Sequence[DenseArray],
    axis: int = 0,
    dtype: DType | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Join tensors along an existing axis using PyTorch.
    Input:
    arrays: Dense tensors whose shapes match except along ``axis``;
    axis: Concatenation axis; dtype: Optional result dtype; out: Optional
    destination tensor.
    Output:
    Dense backend tensor (``out`` itself when it is given).
    See:
    https://docs.pytorch.org/docs/stable/generated/torch.cat.html
    Backend-specific notes:
    When ``dtype`` is given, every input is cast first (``astype`` with
    ``copy=False``, so inputs already of that dtype are not copied), as
    ``numpy.concatenate`` does. The result is therefore written straight into
    ``out`` and no second full-size copy of the result is made.
    """
    parts = tuple(arrays)
    if dtype is not None:
        parts = tuple(self.astype(part, dtype, copy=False) for part in parts)
    if out is None:
        return self.torch.cat(parts, dim=axis)
    return self.torch.cat(parts, dim=axis, out=out)
def _copy(self, x: DenseArray) -> DenseArray:
"""Return a PyTorch clone of ``x`` (mutation primitive for index ops)."""
return cast(Any, x).clone()
def _scatter_add_inplace(self, y: DenseArray, index: Index, values: ArrayLike) -> None:
    """Add ``values`` into ``y`` at ``index`` in place.
    Implemented as a read-modify-write ``y[index] = y[index] + values``:
    the selected elements are read, ``values`` is added to them, and the sum
    is written back into ``y`` through indexed assignment. Returns ``None``.
    Unlike NumPy's ``add.at``, repeated indices do not accumulate: when
    ``index`` selects the same element more than once, only one of the
    contributions lands (PyTorch leaves which one unspecified).
    NOTE(review): an earlier version of this docstring said this matches
    "PyTorch's prior index_add behavior", but ``Tensor.index_add_`` and
    ``Tensor.index_put_(..., accumulate=True)`` *do* accumulate duplicates.
    Confirm that non-accumulating semantics are what callers expect.
    """
    y[index] = y[index] + values
[docs]
def ix_(self, *args: Any) -> Any:
"""
Construct open mesh indices using PyTorch.
Input:
args: One-dimensional index arrays or array-like objects.
Output:
Tuple of broadcastable index tensors.
See:
https://docs.pytorch.org/docs/stable/generated/torch.meshgrid.html
"""
tensors = tuple(
arg if isinstance(arg, self.torch.Tensor) else self.asarray(arg) for arg in args
)
return self.torch.meshgrid(*tensors, indexing="ij")
[docs]
def allclose_sparse(
    self, a: SparseArray, b: SparseArray, rtol: float = 1e-5, atol: float = 1e-8
) -> bool:
    """
    Compare sparse tensors elementwise within tolerances using PyTorch.
    Input:
    a, b: Sparse backend tensors; rtol, atol: Relative and absolute tolerances.
    Output:
    Boolean indicating whether the two tensors are close everywhere.
    See:
    https://docs.pytorch.org/docs/stable/sparse.html
    Backend-specific notes:
    Both operands must be sparse (checked by ``_require_two_sparse``). They are
    densified with ``to_dense()`` and then compared with ``allclose``.
    """
    self._require_two_sparse(a, b, noun="sparse tensors")
    lhs = cast(Any, a).to_dense()
    rhs = cast(Any, b).to_dense()
    return self.allclose(lhs, rhs, rtol=rtol, atol=atol)