Source code for spacecore.backend.torch._ops

from __future__ import annotations

from typing import Any, Callable, Literal, Optional, Sequence, Tuple, Type

import numpy as np

from .._family import BackendFamily
from .._ops import BackendOps
from ...types import ArrayLike, DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry


class TorchOps(BackendOps):
    """
    BackendOps implementation for PyTorch tensors.

    Dense arrays are ``torch.Tensor`` objects with the strided layout; sparse
    arrays are ``torch.Tensor`` objects whose layout is one of the PyTorch
    sparse layouts listed in ``_sparse_layouts``.

    Most methods mirror the corresponding PyTorch public API signatures and
    delegate to ``torch`` or ``torch.linalg``. Dtype promotion, broadcasting,
    device placement, autograd tracking, and error modes therefore follow
    PyTorch semantics.

    Backend handles
    ---------------
    torch : module
        The PyTorch module, stored on the class and reachable from instances
        as ``ops.torch``. It is an explicit PyTorch-specific escape hatch for
        features SpaceCore's portable API does not expose; backend-portable
        code should prefer ``BackendOps`` methods.

    Notes
    -----
    * ``sanitize_dtype(None)`` returns ``torch.get_default_dtype()``. Python
      ``complex`` maps to ``torch.complex128`` when the default floating dtype
      is ``torch.float64`` and to ``torch.complex64`` otherwise. NumPy dtype
      specifiers are mapped to matching PyTorch dtypes when supported.
    * Creation and conversion methods may accept a backend-specific
      ``device=`` keyword. Existing tensors stay on their device unless a
      device conversion is explicitly requested.
    * Dense conversion and ordinary math operations do not detach tensors;
      autograd metadata is preserved according to normal PyTorch rules.
    """

    import torch

    _family = BackendFamily.torch.value.lower()
    _allow_sparse = True
    # Layouts recognized by ``is_sparse``; ``is_dense`` accepts only strided.
    _sparse_layouts = (
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    )

    @staticmethod
    def _defined_kwargs(**kwargs: Any) -> dict[str, Any]:
        """Return ``kwargs`` without the entries whose value is ``None``."""
        defined: dict[str, Any] = {}
        for name, value in kwargs.items():
            if value is not None:
                defined[name] = value
        return defined

    @property
    def dense_array(self) -> Type[Any]:
        """
        Dense array type used by this backend.

        Returns:
            ``torch.Tensor``, the concrete class of dense tensors.
        See:
            https://docs.pytorch.org/docs/stable/tensors.html
        """
        tensor_cls = self.torch.Tensor
        return tensor_cls

    @property
    def sparse_array(self) -> Tuple[Type[Any], ...]:
        """
        Sparse array types used by this backend.

        Returns:
            A one-element tuple ``(torch.Tensor,)``; sparse tensors share the
            dense class and are distinguished by their layout.
        See:
            https://docs.pytorch.org/docs/stable/sparse.html
        """
        return tuple([self.torch.Tensor])
[docs] def is_dense(self, x: Any) -> bool: """ Check whether an object is a dense PyTorch tensor. Input: x: Object to inspect. Output: Boolean indicating whether x is a strided PyTorch tensor. See: https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-layout """ return isinstance(x, self.torch.Tensor) and x.layout == self.torch.strided
[docs] def is_sparse(self, x: Any) -> bool: """ Check whether an object is a sparse PyTorch tensor. Input: x: Object to inspect. Output: Boolean indicating whether x is a PyTorch tensor with a sparse layout. See: https://docs.pytorch.org/docs/stable/sparse.html """ return isinstance(x, self.torch.Tensor) and x.layout in self._sparse_layouts
[docs] def sanitize_dtype(self, dtype: DType | None) -> DType: """ Normalize a dtype specifier using PyTorch. Input: dtype: Optional dtype requested by SpaceCore or the caller. Output: Backend dtype object accepted by PyTorch tensor constructors. See: https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype Backend-specific notes: ``None`` follows ``torch.get_default_dtype()``. NumPy dtype specifiers are mapped to equivalent PyTorch dtypes when supported. """ if dtype is None: return self.torch.get_default_dtype() if isinstance(dtype, self.torch.dtype): return dtype if dtype is float: return self.torch.get_default_dtype() if dtype is complex: return ( self.torch.complex128 if self.torch.get_default_dtype() == self.torch.float64 else self.torch.complex64 ) if dtype is int: return self.torch.int64 if dtype is bool: return self.torch.bool try: np_dtype = np.dtype(dtype) except Exception as e: raise TypeError(f"Invalid dtype specifier for PyTorch: {dtype!r}.") from e mapping = { np.dtype("bool"): self.torch.bool, np.dtype("uint8"): self.torch.uint8, np.dtype("int8"): self.torch.int8, np.dtype("int16"): self.torch.int16, np.dtype("int32"): self.torch.int32, np.dtype("int64"): self.torch.int64, np.dtype("float16"): self.torch.float16, np.dtype("float32"): self.torch.float32, np.dtype("float64"): self.torch.float64, np.dtype("complex64"): self.torch.complex64, np.dtype("complex128"): self.torch.complex128, } if np_dtype in mapping: return mapping[np_dtype] raise TypeError(f"Dtype {np_dtype!r} is not supported by PyTorch.")
[docs] def get_dtype(self, x: Any) -> DType: """ Return a tensor dtype using PyTorch. Input: x: Dense or sparse backend tensor. Output: Backend dtype associated with x. See: https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype """ if self.is_array(x): return x.dtype raise TypeError(f"Expected PyTorch tensor, got {type(x)}.")
[docs] def shape(self, x: Any) -> tuple[int, ...]: """ Return tensor shape metadata using PyTorch. Input: x: Dense or sparse backend tensor. Output: Tuple describing the logical shape of x. See: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.shape.html """ return tuple(x.shape)
[docs] def ndim(self, x: Any) -> int: """ Return tensor rank metadata using PyTorch. Input: x: Dense or sparse backend tensor. Output: Number of dimensions in x. See: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.ndim.html """ return int(x.ndim)
[docs] def size(self, x: Any) -> int: """ Return logical element count using PyTorch. Input: x: Dense or sparse backend tensor. Output: Total number of logical dense elements. See: https://docs.pytorch.org/docs/stable/generated/torch.numel.html """ return int(x.numel())
@property def inf(self): """ Positive infinity scalar using PyTorch. Returns: Backend tensor scalar representing positive infinity. See: https://docs.pytorch.org/docs/stable/generated/torch.tensor.html """ return self.torch.tensor(float("inf")) @property def nan(self): """ NaN scalar using PyTorch. Returns: Backend tensor scalar representing NaN. See: https://docs.pytorch.org/docs/stable/generated/torch.tensor.html """ return self.torch.tensor(float("nan")) @property def pi(self): """ Pi scalar using PyTorch. Returns: Backend tensor scalar representing pi. See: https://docs.pytorch.org/docs/stable/generated/torch.tensor.html """ return self.torch.tensor(np.pi) @property def e(self): """ Euler number scalar using PyTorch. Returns: Backend tensor scalar representing Euler's number. See: https://docs.pytorch.org/docs/stable/generated/torch.tensor.html """ return self.torch.tensor(np.e) @property def eps(self): """ Machine epsilon scalar using PyTorch. Returns: Backend tensor scalar for float64 machine epsilon. See: https://docs.pytorch.org/docs/stable/type_info.html#torch.finfo """ return self.torch.tensor(self.torch.finfo(self.torch.float64).eps)
[docs] def asarray( self, x: Any, dtype: DType | None = None, *, device: Any | None = None, copy: bool | None = None, ) -> DenseArray: """ Convert input to a dense tensor using PyTorch. Input: x/a: Array-like input and optional dtype, device, or copy controls. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.as_tensor.html Backend-specific notes: Sparse tensors are densified. Existing tensors keep autograd metadata according to normal PyTorch conversion rules. """ dtype = self.sanitize_dtype(dtype) if dtype is not None else None if self.is_sparse(x): x = x.to_dense() out = self.torch.as_tensor(x, dtype=dtype, device=device) if copy: out = out.clone() return out
[docs] def astype( self, x: DenseArray, dtype: DType, copy: bool = True, *, non_blocking: bool = False, memory_format: Any | None = None, ) -> DenseArray: """ Cast a tensor to a dtype using PyTorch. Input: x: Dense backend tensor; dtype: Target dtype; copy: Whether to force a copy. Output: Tensor with the requested dtype. See: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html """ return x.to( dtype=self.sanitize_dtype(dtype), non_blocking=non_blocking, copy=copy, **self._defined_kwargs(memory_format=memory_format), )
[docs] def assparse( self, x: Any, *, format: Literal["coo", "csr", "csc"] = "coo", dtype: DType | None = None, device: Any | None = None, ) -> SparseArray: """ Convert input to a sparse tensor using PyTorch. Input: x: Dense, sparse, or SciPy sparse input plus sparse format, dtype, and device. Output: Sparse backend tensor in COO, CSR, or CSC format. See: https://docs.pytorch.org/docs/stable/sparse.html Backend-specific notes: SciPy sparse inputs are converted through COO indices and values. Dense inputs are converted through PyTorch's sparse COO conversion. """ dtype = self.sanitize_dtype(dtype) if dtype is not None else None if self.is_sparse(x): y = x.to(dtype=dtype, device=device) if dtype is not None or device is not None else x if format == "coo": return y.to_sparse_coo() if format == "csr": return y.to_sparse_csr() if format == "csc": return y.to_sparse_csc() raise ValueError(f"Unknown sparse format: {format!r}") try: import scipy.sparse as sps except Exception: sps = None if sps is not None and sps.issparse(x): coo = x.tocoo() indices = self.torch.as_tensor( np.vstack((coo.row, coo.col)), dtype=self.torch.int64, device=device, ) values = self.torch.as_tensor(coo.data, dtype=dtype, device=device) out = self.torch.sparse_coo_tensor(indices, values, coo.shape, device=device) else: out = self.asarray(x, dtype=dtype, device=device).to_sparse_coo() if format == "coo": return out.coalesce() if format == "csr": return out.to_sparse_csr() if format == "csc": return out.to_sparse_csc() raise ValueError(f"Unknown sparse format: {format!r}")
[docs] def empty( self, shape: int | Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, pin_memory: bool = False, memory_format: Any | None = None, ) -> DenseArray: """ Create an uninitialized dense tensor using PyTorch. Input: shape: Output shape; dtype and device: Optional construction parameters. Output: Dense backend tensor with uninitialized values. See: https://docs.pytorch.org/docs/stable/generated/torch.empty.html """ return self.torch.empty( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, pin_memory=pin_memory, **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), )
[docs] def zeros( self, shape: int | Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: """ Create a dense tensor filled with zeros using PyTorch. Input: shape: Output shape; dtype and device: Optional construction parameters. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.zeros.html """ return self.torch.zeros( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device), )
[docs] def ones( self, shape: int | Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: """ Create a dense tensor filled with ones using PyTorch. Input: shape: Output shape; dtype and device: Optional construction parameters. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.ones.html """ return self.torch.ones( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device), )
[docs] def zeros_like( self, x: DenseArray, dtype: DType | None = None, *, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: """ Create a zero tensor matching another tensor using PyTorch. Input: x: Reference tensor; dtype and device: Optional overrides. Output: Dense backend tensor with shape matching x. See: https://docs.pytorch.org/docs/stable/generated/torch.zeros_like.html """ return self.torch.zeros_like( x, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), )
[docs] def ones_like( self, x: DenseArray, dtype: DType | None = None, *, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: """ Create a one tensor matching another tensor using PyTorch. Input: x: Reference tensor; dtype and device: Optional overrides. Output: Dense backend tensor with shape matching x. See: https://docs.pytorch.org/docs/stable/generated/torch.ones_like.html """ return self.torch.ones_like( x, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), )
[docs] def full_like( self, x: DenseArray, value: Any, dtype: DType | None = None, *, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: """ Create a filled tensor matching another tensor using PyTorch. Input: x: Reference tensor; value: Fill value; dtype and device: Optional overrides. Output: Dense backend tensor with shape matching x. See: https://docs.pytorch.org/docs/stable/generated/torch.full_like.html """ return self.torch.full_like( x, value, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), )
[docs] def arange( self, start: int | float = 0, stop: int | float | None = None, step: int | float | None = None, dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: """ Create a range tensor using PyTorch. Input: start, stop, step: Range parameters; dtype and device: Optional construction parameters. Output: Dense backend tensor containing evenly spaced values. See: https://docs.pytorch.org/docs/stable/generated/torch.arange.html """ dtype = self.sanitize_dtype(dtype) if dtype is not None else None kwargs = self._defined_kwargs(out=out, layout=layout, device=device) kwargs["requires_grad"] = requires_grad if stop is None: return self.torch.arange(start, dtype=dtype, **kwargs) if step is None: return self.torch.arange(start, stop, dtype=dtype, **kwargs) return self.torch.arange(start, stop, step, dtype=dtype, **kwargs)
[docs] def full( self, shape: int | Tuple[int, ...], fill_value: Any, dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: """ Create a filled dense tensor using PyTorch. Input: shape: Output shape; fill_value: Fill value; dtype and device: Optional parameters. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.full.html """ return self.torch.full( shape, fill_value, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device), )
[docs] def eye( self, n: int, m: int | None = None, k: int = 0, dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: """ Create a two-dimensional identity-like tensor using PyTorch. Input: n, m: Matrix dimensions; k: Diagonal offset; dtype and device: Optional parameters. Output: Dense backend tensor with ones on the requested diagonal. See: https://docs.pytorch.org/docs/stable/generated/torch.eye.html Backend-specific notes: PyTorch ``torch.eye`` has no diagonal offset parameter, so SpaceCore constructs the offset diagonal explicitly. """ m = n if m is None else m dtype = self.sanitize_dtype(dtype) if dtype is not None else None if k == 0: return self.torch.eye( n, m, out=out, dtype=dtype, requires_grad=requires_grad, **self._defined_kwargs(layout=layout, device=device), ) out = self.torch.zeros( (n, m), out=out, dtype=dtype, requires_grad=False, **self._defined_kwargs(layout=layout, device=device), ) diag_len = min(n, m - k) if k > 0 else min(n + k, m) if diag_len <= 0: return out rows = self.torch.arange(diag_len, device=device) cols = rows + k if k < 0: rows = rows - k cols = self.torch.arange(diag_len, device=device) out[rows, cols] = 1 if requires_grad: out.requires_grad_() return out
[docs] def ravel(self, x: DenseArray) -> DenseArray: """ Flatten a tensor using PyTorch. Input: x: Dense backend tensor. Output: One-dimensional tensor view or copy following PyTorch semantics. See: https://docs.pytorch.org/docs/stable/generated/torch.ravel.html """ return self.torch.ravel(x)
[docs] def reshape(self, x: DenseArray, shape: int | Tuple[int, ...], *, copy: bool | None = None) -> DenseArray: """ Reshape a tensor using PyTorch. Input: x: Dense backend tensor; shape: Target shape; copy: Whether to clone first. Output: Reshaped dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.reshape.html """ if copy: x = x.clone() return self.torch.reshape(x, shape if isinstance(shape, tuple) else (shape,))
[docs] def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: """ Permute tensor axes using PyTorch. Input: x: Dense backend tensor; axes: Optional axis order. Output: Tensor with permuted axes. See: https://docs.pytorch.org/docs/stable/generated/torch.permute.html """ if axes is None: axes = tuple(reversed(range(x.ndim))) return x.permute(tuple(axes))
[docs] def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: """ Swap two tensor axes using PyTorch. Input: x: Dense backend tensor; axis1, axis2: Axes to swap. Output: Tensor with the requested axes swapped. See: https://docs.pytorch.org/docs/stable/generated/torch.swapaxes.html """ return self.torch.swapaxes(x, axis1, axis2)
[docs] def broadcast_to(self, x: DenseArray, shape: int | Tuple[int, ...]) -> DenseArray: """ Broadcast a tensor to a shape using PyTorch. Input: x: Dense backend tensor; shape: Target broadcast shape. Output: Broadcasted tensor view following PyTorch broadcasting rules. See: https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html """ return self.torch.broadcast_to(x, shape)
[docs] def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: """ Insert singleton dimensions using PyTorch. Input: x: Dense backend tensor; axis: Axis or axes where dimensions are inserted. Output: Tensor with inserted singleton dimensions. See: https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html """ if isinstance(axis, int): return self.torch.unsqueeze(x, axis) ndim = x.ndim + len(axis) axes = sorted(a + ndim if a < 0 else a for a in axis) out = x for ax in axes: out = self.torch.unsqueeze(out, ax) return out
[docs] def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: """ Remove singleton dimensions using PyTorch. Input: x: Dense backend tensor; axis: Optional axis or axes to squeeze. Output: Tensor with singleton dimensions removed. See: https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html """ if axis is None: return self.torch.squeeze(x) if isinstance(axis, int): return self.torch.squeeze(x, dim=axis) out = x for ax in sorted(axis, reverse=True): out = self.torch.squeeze(out, dim=ax) return out
[docs] def moveaxis(self, x: DenseArray, source: int | Sequence[int], destination: int | Sequence[int]) -> DenseArray: """ Move tensor axes to new positions using PyTorch. Input: x: Dense backend tensor; source and destination: Axis positions. Output: Tensor with axes moved. See: https://docs.pytorch.org/docs/stable/generated/torch.moveaxis.html """ return self.torch.moveaxis(x, source, destination)
[docs] def stack( self, arrays: Sequence[DenseArray], axis: int = 0, *, out: DenseArray | None = None, ) -> DenseArray: """ Stack tensors along a new axis using PyTorch. Input: arrays: Sequence of tensors; axis: New axis; out: Optional output tensor. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.stack.html """ arrays = tuple(arrays) if out is None: return self.torch.stack(arrays, dim=axis) return self.torch.stack(arrays, dim=axis, out=out)
[docs] def conj(self, x: DenseArray) -> DenseArray: """ Return the complex conjugate using PyTorch. Input: x: Dense backend tensor. Output: Tensor containing complex conjugates. See: https://docs.pytorch.org/docs/stable/generated/torch.conj.html """ return self.torch.conj(x)
[docs] def real(self, x: DenseArray) -> DenseArray: """ Return the real part of a tensor using PyTorch. Input: x: Dense backend tensor. Output: Tensor view or value containing real components. See: https://docs.pytorch.org/docs/stable/generated/torch.real.html """ return self.torch.real(x)
[docs] def imag(self, x: DenseArray) -> DenseArray: """ Return the imaginary part of a tensor using PyTorch. Input: x: Dense backend tensor. Output: Tensor view or value containing imaginary components. See: https://docs.pytorch.org/docs/stable/generated/torch.imag.html """ return self.torch.imag(x)
[docs] def abs( self, x: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute elementwise absolute value using PyTorch. Input: x: Dense backend tensor. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.abs.html """ if out is None: return self.torch.abs(x) return self.torch.abs(x, out=out)
[docs] def sign( self, x: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute elementwise sign using PyTorch. Input: x: Dense backend tensor. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.sign.html """ if out is None: return self.torch.sign(x) return self.torch.sign(x, out=out)
[docs] def sqrt( self, x: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute elementwise square root using PyTorch. Input: x: Dense backend tensor. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.sqrt.html """ if out is None: return self.torch.sqrt(x) return self.torch.sqrt(x, out=out)
[docs] def sum( self, x: DenseArray, axis: int | Sequence[int] | None = None, dtype: DType | None = None, keepdims: bool = False, *, out: DenseArray | None = None, ) -> DenseArray: """ Sum tensor elements using PyTorch. Input: x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.sum.html """ if dtype is None: if out is None: if axis is None and not keepdims: return self.torch.sum(x) return self.torch.sum(x, dim=axis, keepdim=keepdims) return self.torch.sum(x, dim=axis, keepdim=keepdims, out=out) dtype = self.sanitize_dtype(dtype) if out is None: if axis is None and not keepdims: return self.torch.sum(x, dtype=dtype) return self.torch.sum(x, dim=axis, keepdim=keepdims, dtype=dtype) return self.torch.sum( x, dim=axis, keepdim=keepdims, dtype=dtype, out=out, )
[docs] def mean( self, x: DenseArray, axis: int | Sequence[int] | None = None, dtype: DType | None = None, keepdims: bool = False, *, out: DenseArray | None = None, ) -> DenseArray: """ Average tensor elements using PyTorch. Input: x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.mean.html """ if dtype is None: if out is None: if axis is None and not keepdims: return self.torch.mean(x) return self.torch.mean(x, dim=axis, keepdim=keepdims) return self.torch.mean(x, dim=axis, keepdim=keepdims, out=out) dtype = self.sanitize_dtype(dtype) if out is None: if axis is None and not keepdims: return self.torch.mean(x, dtype=dtype) return self.torch.mean(x, dim=axis, keepdim=keepdims, dtype=dtype) return self.torch.mean( x, dim=axis, keepdim=keepdims, dtype=dtype, out=out, )
[docs] def min( self, x: DenseArray, axis: int | Sequence[int] | None = None, keepdims: bool = False, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute minimum values using PyTorch. Input: x: Dense backend tensor; axis and keepdims: Reduction controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.amin.html """ if out is None: if axis is None and not keepdims: return self.torch.amin(x) return self.torch.amin(x, dim=axis, keepdim=keepdims) return self.torch.amin(x, dim=axis, keepdim=keepdims, out=out)
[docs] def max( self, x: DenseArray, axis: int | Sequence[int] | None = None, keepdims: bool = False, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute maximum values using PyTorch. Input: x: Dense backend tensor; axis and keepdims: Reduction controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.amax.html """ if out is None: if axis is None and not keepdims: return self.torch.amax(x) return self.torch.amax(x, dim=axis, keepdim=keepdims) return self.torch.amax(x, dim=axis, keepdim=keepdims, out=out)
[docs] def prod( self, x: DenseArray, axis: int | Sequence[int] | None = None, dtype: DType | None = None, keepdims: bool = False, *, out: DenseArray | None = None, ) -> DenseArray: """ Multiply tensor elements using PyTorch. Input: x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.prod.html Backend-specific notes: Multiple-axis products are applied one axis at a time because PyTorch's ``torch.prod`` reduces a single dimension per call. """ dtype = self.sanitize_dtype(dtype) if dtype is not None else None if axis is None: result = self.torch.prod(x) if dtype is None else self.torch.prod(x, dtype=dtype) if out is not None: out.copy_(result) return out return result if isinstance(axis, int): if out is None: if dtype is None: return self.torch.prod(x, dim=axis, keepdim=keepdims) return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims) return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) result = x for ax in sorted(axis, reverse=True): result = self.torch.prod(result, dim=ax, dtype=dtype, keepdim=keepdims) if out is not None: out.copy_(result) return out return result
[docs] def trace( self, x: DenseArray, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: DType | None = None, ) -> DenseArray: """ Sum diagonal entries using PyTorch. Input: x: Dense backend tensor; offset, axis1, axis2, dtype: Diagonal controls. Output: Dense backend tensor or scalar tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html """ return self.sum(self.diagonal(x, offset=offset, axis1=axis1, axis2=axis2), dtype=dtype)
[docs] def argsort( self, x: DenseArray, axis: int = -1, stable: bool = False, descending: bool = False, ) -> DenseArray: """ Return sorting indices using PyTorch. Input: x: Dense backend tensor; axis, stable, descending: Sorting controls. Output: Integer tensor of indices. See: https://docs.pytorch.org/docs/stable/generated/torch.argsort.html """ return self.torch.argsort(x, dim=axis, stable=stable, descending=descending)
[docs] def sort( self, x: DenseArray, axis: int = -1, stable: bool = False, descending: bool = False, *, out: tuple[DenseArray, DenseArray] | None = None, ) -> DenseArray: """ Sort tensor values using PyTorch. Input: x: Dense backend tensor; axis, stable, descending: Sorting controls. Output: Dense backend tensor of sorted values. See: https://docs.pytorch.org/docs/stable/generated/torch.sort.html """ return self.torch.sort(x, dim=axis, stable=stable, descending=descending, out=out).values
[docs] def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: """ Return indices of minimum values using PyTorch. Input: x: Dense backend tensor; axis and keepdims: Reduction controls. Output: Integer tensor of indices. See: https://docs.pytorch.org/docs/stable/generated/torch.argmin.html """ return self.torch.argmin(x, dim=axis, keepdim=keepdims)
[docs] def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: """ Return indices of maximum values using PyTorch. Input: x: Dense backend tensor; axis and keepdims: Reduction controls. Output: Integer tensor of indices. See: https://docs.pytorch.org/docs/stable/generated/torch.argmax.html """ return self.torch.argmax(x, dim=axis, keepdim=keepdims)
[docs] def vdot( self, x: DenseArray, y: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute conjugating vector dot product using PyTorch. Input: x, y: Dense backend tensors. Output: Scalar tensor containing the vector dot product. See: https://docs.pytorch.org/docs/stable/generated/torch.vdot.html """ x1 = x if x.ndim == 1 else self.torch.ravel(x) y1 = y if y.ndim == 1 else self.torch.ravel(y) if out is None: return self.torch.vdot(x1, y1) return self.torch.vdot(x1, y1, out=out)
[docs] def matmul( self, a: DenseArray, b: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Matrix-multiply tensors using PyTorch. Input: a, b: Dense backend tensors. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.matmul.html """ if out is None: return self.torch.matmul(a, b) return self.torch.matmul(a, b, out=out)
[docs] def sparse_matmul( self, a: SparseArray, b: DenseArray, *, reduce: Literal["sum", "mean", "amax", "amin"] = "sum", ) -> DenseArray: """ Matrix-multiply a sparse tensor by a dense tensor using PyTorch. Input: a: Sparse backend tensor; b: Dense backend tensor or vector. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html """ kwargs = {"reduce": reduce} if reduce != "sum" else {} if b.ndim == 1: return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs)
[docs] def kron( self, a: DenseArray, b: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: """ Compute the Kronecker product using PyTorch. Input: a, b: Dense backend tensors. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.kron.html """ return self.torch.kron(a, b, out=out)
[docs] def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: """ Evaluate an Einstein summation using PyTorch. Input: subscripts: Einsum expression; operands: Dense backend tensors. Output: Dense backend tensor. See: https://docs.pytorch.org/docs/stable/generated/torch.einsum.html """ return self.torch.einsum(subscripts, *operands)
[docs] def eigh( self, x: DenseArray, UPLO: Literal["L", "U"] = "L", *, out: tuple[DenseArray, DenseArray] | None = None, ) -> tuple[DenseArray, DenseArray]: """ Compute Hermitian eigenvalues and eigenvectors using PyTorch. Input: x: Dense Hermitian or symmetric backend tensor. Output: Tuple of eigenvalues and eigenvectors. See: https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigh.html """ if self.is_sparse(x): raise TypeError("eigh requires a dense array; sparse input is not supported.") return self.torch.linalg.eigh(x, UPLO=UPLO, out=out)
def norm(
    self,
    x: DenseArray,
    ord: int | str | None = None,
    axis: int | Sequence[int] | None = None,
    keepdims: bool = False,
    *,
    dtype: DType | None = None,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Vector or matrix norm via ``torch.linalg.norm``.

    Input:
        x: Dense backend tensor.
        ord: Norm order, forwarded unchanged.
        axis: Dimension(s) to reduce (PyTorch's ``dim``).
        keepdims: Keep reduced dimensions (PyTorch's ``keepdim``).
        dtype: Optional dtype, normalized through ``sanitize_dtype``.
    Output:
        Dense backend tensor (0-d when every dimension is reduced).
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.linalg.norm.html
    """
    torch_dtype = None if dtype is None else self.sanitize_dtype(dtype)
    return self.torch.linalg.norm(
        x,
        ord=ord,
        dim=axis,
        keepdim=keepdims,
        dtype=torch_dtype,
        out=out,
    )
def solve(
    self,
    A: DenseArray,
    b: DenseArray,
    *,
    left: bool = True,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Solve a linear system via ``torch.linalg.solve``.

    Input:
        A: Coefficient tensor.
        b: Right-hand side tensor.
        left: Forwarded to PyTorch; ``True`` solves ``A @ x = b``.
    Output:
        Dense backend tensor ``x``.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.linalg.solve.html
    """
    solver = self.torch.linalg.solve
    return solver(A, b, left=left, out=out)
def eigvalsh(
    self,
    A: DenseArray,
    UPLO: Literal["L", "U"] = "L",
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Eigenvalues of a Hermitian/symmetric tensor via ``torch.linalg.eigvalsh``.

    Input:
        A: Dense Hermitian or symmetric backend tensor.
        UPLO: Which triangle of ``A`` PyTorch reads.
    Output:
        Dense backend tensor of eigenvalues.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigvalsh.html
    """
    linalg = self.torch.linalg
    return linalg.eigvalsh(A, UPLO=UPLO, out=out)
def svd(
    self,
    A: DenseArray,
    full_matrices: bool = True,
    compute_uv: bool = True,
    hermitian: bool = False,
    *,
    driver: str | None = None,
    out: DenseArray | tuple[DenseArray, DenseArray, DenseArray] | None = None,
) -> DenseArray | tuple[DenseArray, DenseArray, DenseArray]:
    """
    Singular value decomposition via ``torch.linalg``.

    Input:
        A: Dense backend tensor.
        full_matrices: Forwarded to ``torch.linalg.svd``.
        compute_uv: When false, only singular values are computed.
        hermitian: Unsupported; must be false.
        driver: Optional PyTorch SVD driver.
    Output:
        Singular values, or the ``(U, S, Vh)`` tuple when ``compute_uv`` is true.
    Raises:
        NotImplementedError: If ``hermitian`` is true.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html

    Backend-specific notes:
        PyTorch exposes no ``hermitian`` switch for SVD. The values-only path
        delegates to ``torch.linalg.svdvals``.
    """
    if hermitian:
        raise NotImplementedError("PyTorch svd does not expose a hermitian option.")
    linalg = self.torch.linalg
    if compute_uv:
        return linalg.svd(A, full_matrices=full_matrices, driver=driver, out=out)
    return linalg.svdvals(A, driver=driver, out=out)
def cholesky(
    self,
    A: DenseArray,
    *,
    upper: bool = False,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Cholesky factorization via ``torch.linalg.cholesky``.

    Input:
        A: Positive-definite dense backend tensor.
        upper: Return the upper factor instead of the lower one.
    Output:
        Dense backend tensor holding the Cholesky factor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.linalg.cholesky.html
    """
    factorize = self.torch.linalg.cholesky
    return factorize(A, upper=upper, out=out)
def logsumexp(
    self,
    a: DenseArray,
    axis: int | Sequence[int] | None = None,
    b: DenseArray | None = None,
    keepdims: bool = False,
    return_sign: bool = False,
    *,
    out: DenseArray | None = None,
) -> DenseArray | tuple[DenseArray, DenseArray]:
    """
    Compute log-sum-exp using PyTorch.

    Input:
        a: Dense backend tensor.
        axis: Dimension(s) to reduce; ``None`` reduces over all dimensions.
        b: Optional weights multiplied into ``exp(a)`` before summation.
        keepdims: Keep reduced dimensions.
        return_sign: Also return the sign of the weighted sum.
    Output:
        Dense backend tensor, or ``(value, sign)`` when ``return_sign`` is true.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.logsumexp.html

    Backend-specific notes:
        Weighted and signed variants are implemented in SpaceCore because
        PyTorch's public ``logsumexp`` does not expose SciPy-style ``b`` or
        ``return_sign`` parameters. They follow SciPy's conventions:
        - A non-finite slice maximum is replaced by 0 before shifting, so an
          all ``-inf`` slice yields ``-inf`` rather than NaN.
        - Without ``return_sign``, a negative weighted sum yields NaN.
        ``out`` is only filled when ``return_sign`` is false.
    """
    dim = tuple(range(a.ndim)) if axis is None else axis
    if b is None and not return_sign:
        return self.torch.logsumexp(a, dim=dim, keepdim=keepdims, out=out)

    weights = self.ones_like(a) if b is None else b
    m = self.torch.amax(a, dim=dim, keepdim=True)
    # Shifting by a non-finite maximum would compute inf - inf = NaN; shift by 0 instead.
    m = self.torch.where(self.torch.isfinite(m), m, self.torch.zeros_like(m))
    total = self.sum(weights * self.torch.exp(a - m), axis=dim, keepdims=True)
    sign = self.torch.sign(total)
    # Without sign reporting, log of a negative sum must surface as NaN instead of
    # silently returning log|sum|.
    magnitude = self.torch.abs(total) if return_sign else total
    result = self.torch.log(magnitude) + m
    if not keepdims:
        result = self.squeeze(result, axis)
        sign = self.squeeze(sign, axis)
    if return_sign:
        return result, sign
    if out is not None:
        out.copy_(result)
        return out
    return result
def exp(
    self,
    x: DenseArray,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Elementwise exponential via ``torch.exp``.

    Input:
        x: Dense backend tensor.
    Output:
        Dense backend tensor; ``out`` is filled when supplied.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.exp.html
    """
    kwargs = {"out": out} if out is not None else {}
    return self.torch.exp(x, **kwargs)
def log(
    self,
    x: DenseArray,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Elementwise natural logarithm via ``torch.log``.

    Input:
        x: Dense backend tensor.
    Output:
        Dense backend tensor; ``out`` is filled when supplied.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.log.html
    """
    kwargs = {"out": out} if out is not None else {}
    return self.torch.log(x, **kwargs)
def where(
    self,
    condition: DenseArray | bool,
    x: ArrayLike | None = None,
    y: ArrayLike | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Select values conditionally using PyTorch.

    Input:
        condition: Boolean tensor, Python bool, or array-like of truth values.
        x, y: Values chosen where ``condition`` is true / false. Omit both
            to get the indices of true entries instead.
    Output:
        Dense backend tensor, or a tuple of index tensors in the single-argument form.
    Raises:
        TypeError: If exactly one of ``x`` and ``y`` is provided.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.where.html
    """
    if not isinstance(condition, self.torch.Tensor):
        # torch.where only accepts a tensor condition; convert bools/array-likes
        # onto the device of whichever operand is already a tensor.
        reference = x if isinstance(x, self.torch.Tensor) else y
        device = reference.device if isinstance(reference, self.torch.Tensor) else None
        condition = self.torch.as_tensor(condition, dtype=self.torch.bool, device=device)
    if x is None and y is None:
        return self.torch.where(condition)
    if x is None or y is None:
        raise TypeError("where requires both x and y when either is provided.")
    if out is None:
        return self.torch.where(condition, x, y)
    return self.torch.where(condition, x, y, out=out)
def maximum(
    self,
    x: ArrayLike,
    y: ArrayLike,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Compute elementwise maximum using PyTorch.

    Input:
        x, y: Array-like operands. A non-tensor operand is converted with
            ``asarray`` using the dtype and device of the tensor operand.
            When neither is a tensor, both are converted with defaults.
    Output:
        Dense backend tensor (or ``out`` when supplied).
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.maximum.html
    """
    # torch.maximum requires two tensors, so coerce whichever side is not one.
    x_is_tensor = isinstance(x, self.torch.Tensor)
    y_is_tensor = isinstance(y, self.torch.Tensor)
    if x_is_tensor and not y_is_tensor:
        y = self.asarray(y, dtype=x.dtype, device=x.device)
    elif y_is_tensor and not x_is_tensor:
        x = self.asarray(x, dtype=y.dtype, device=y.device)
    elif not x_is_tensor:
        x = self.asarray(x)
        y = self.asarray(y)
    if out is None:
        return self.torch.maximum(x, y)
    return self.torch.maximum(x, y, out=out)
def minimum(
    self,
    x: ArrayLike,
    y: ArrayLike,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Compute elementwise minimum using PyTorch.

    Input:
        x, y: Array-like operands. A non-tensor operand is converted with
            ``asarray`` using the dtype and device of the tensor operand.
            When neither is a tensor, both are converted with defaults.
    Output:
        Dense backend tensor (or ``out`` when supplied).
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.minimum.html
    """
    # torch.minimum requires two tensors, so coerce whichever side is not one.
    x_is_tensor = isinstance(x, self.torch.Tensor)
    y_is_tensor = isinstance(y, self.torch.Tensor)
    if x_is_tensor and not y_is_tensor:
        y = self.asarray(y, dtype=x.dtype, device=x.device)
    elif y_is_tensor and not x_is_tensor:
        x = self.asarray(x, dtype=y.dtype, device=y.device)
    elif not x_is_tensor:
        x = self.asarray(x)
        y = self.asarray(y)
    if out is None:
        return self.torch.minimum(x, y)
    return self.torch.minimum(x, y, out=out)
def clip(
    self,
    x: DenseArray,
    a_min: ArrayLike | None = None,
    a_max: ArrayLike | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Clamp tensor values into ``[a_min, a_max]`` via ``torch.clamp``.

    Input:
        x: Dense backend tensor.
        a_min, a_max: Lower and upper bounds (either may be ``None``).
    Output:
        Dense backend tensor (or ``out`` when supplied).
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.clamp.html
    """
    bounds = {"min": a_min, "max": a_max}
    if out is not None:
        bounds["out"] = out
    return self.torch.clamp(x, **bounds)
def isfinite(self, x: DenseArray) -> DenseArray:
    """
    Elementwise finiteness test via ``torch.isfinite``.

    Input:
        x: Dense backend tensor.
    Output:
        Boolean dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.isfinite.html
    """
    predicate = self.torch.isfinite
    return predicate(x)
def isnan(self, x: DenseArray) -> DenseArray:
    """
    Elementwise NaN test via ``torch.isnan``.

    Input:
        x: Dense backend tensor.
    Output:
        Boolean dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.isnan.html
    """
    predicate = self.torch.isnan
    return predicate(x)
def concatenate(
    self,
    arrays: Sequence[DenseArray],
    axis: int = 0,
    dtype: DType | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Join tensors along an existing axis via ``torch.cat``.

    Input:
        arrays: Sequence of dense backend tensors.
        axis: Axis along which to join.
        dtype: Optional dtype; when given, the joined result is cast with ``astype``.
    Output:
        Dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.cat.html
    """
    cat_kwargs = {"dim": axis}
    if out is not None:
        cat_kwargs["out"] = out
    joined = self.torch.cat(tuple(arrays), **cat_kwargs)
    if dtype is None:
        return joined
    return self.astype(joined, dtype)
def take(
    self,
    x: DenseArray,
    indices: DenseArray,
    axis: int | None = None,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Take tensor elements by index using PyTorch.

    Input:
        x: Dense backend tensor.
        indices: Integer index tensor of any rank.
        axis: Optional selection axis; ``None`` indexes the flattened tensor.
    Output:
        Dense backend tensor. With ``axis`` given, the result has shape
        ``x.shape[:axis] + indices.shape + x.shape[axis + 1:]`` (as in ``numpy.take``).
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.take.html
        https://docs.pytorch.org/docs/stable/generated/torch.index_select.html
    """
    if axis is None:
        result = self.torch.take(x, indices)
        if out is not None:
            out.copy_(result)
            return out
        return result
    if indices.ndim == 1:
        return self.torch.index_select(x, dim=axis, index=indices, out=out)
    # index_select only accepts 1-D indices: select with the flattened index and
    # then restore the index tensor's own shape in place of the selected axis.
    dim = axis + x.ndim if axis < 0 else axis
    selected = self.torch.index_select(x, dim=dim, index=indices.reshape(-1))
    result = selected.reshape(tuple(x.shape[:dim]) + tuple(indices.shape) + tuple(x.shape[dim + 1:]))
    if out is not None:
        out.copy_(result)
        return out
    return result
def diag(
    self,
    x: DenseArray,
    k: int = 0,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Build a diagonal matrix from a vector, or extract a diagonal from a matrix,
    via ``torch.diag``.

    Input:
        x: Dense backend tensor.
        k: Diagonal offset (PyTorch's ``diagonal``).
    Output:
        Dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.diag.html
    """
    diag_fn = self.torch.diag
    return diag_fn(x, diagonal=k, out=out)
def diagonal(self, x: DenseArray, offset: int = 0, axis1: int = 0, axis2: int = 1) -> DenseArray:
    """
    Return the diagonal spanned by two axes via ``torch.diagonal``.

    Input:
        x: Dense backend tensor.
        offset: Diagonal offset.
        axis1, axis2: The two axes forming the 2-D planes (PyTorch's ``dim1``/``dim2``).
    Output:
        Dense backend tensor as returned by ``torch.diagonal``.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html
    """
    controls = {"offset": offset, "dim1": axis1, "dim2": axis2}
    return self.torch.diagonal(x, **controls)
def tril(
    self,
    x: DenseArray,
    k: int = 0,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Zero everything above the ``k``-th diagonal via ``torch.tril``.

    Input:
        x: Dense backend tensor.
        k: Diagonal offset.
    Output:
        Dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.tril.html
    """
    lower = self.torch.tril
    return lower(x, diagonal=k, out=out)
def triu(
    self,
    x: DenseArray,
    k: int = 0,
    *,
    out: DenseArray | None = None,
) -> DenseArray:
    """
    Zero everything below the ``k``-th diagonal via ``torch.triu``.

    Input:
        x: Dense backend tensor.
        k: Diagonal offset.
    Output:
        Dense backend tensor.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.triu.html
    """
    upper = self.torch.triu
    return upper(x, diagonal=k, out=out)
def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True):
    """
    Overwrite the entries of ``x`` selected by ``index`` with ``values``.

    Input:
        x: Dense backend tensor.
        index: Any PyTorch indexing expression.
        values: Replacement values.
        copy: When true (default), a clone of ``x`` is modified and ``x`` is left
            untouched; when false, ``x`` itself is mutated.
    Output:
        The modified tensor.
    See:
        https://docs.pytorch.org/docs/stable/tensor_view.html
    """
    target = x if not copy else x.clone()
    target[index] = values
    return target
def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True):
    """
    Increment the entries of ``x`` selected by ``index`` by ``values``.

    Input:
        x: Dense backend tensor.
        index: Any PyTorch indexing expression.
        values: Increments, broadcast against the selected entries.
        copy: When true (default), a clone of ``x`` is modified and ``x`` is left
            untouched; when false, ``x`` itself is mutated.
    Output:
        The modified tensor.
    See:
        https://docs.pytorch.org/docs/stable/tensor_view.html

    Backend-specific notes:
        The update is a read (``target[index] + values``) followed by an
        assignment. A position repeated in an advanced index therefore
        receives a single increment, not one per occurrence.
    """
    target = x.clone() if copy else x
    incremented = target[index] + values
    target[index] = incremented
    return target
def ix_(self, *args: Any) -> Any:
    """
    Construct open mesh indices using PyTorch, mirroring ``numpy.ix_``.

    Input:
        args: One-dimensional index tensors or array-like objects. Boolean
            inputs are converted to the positions of their true entries, and
            empty inputs are treated as integer (``torch.long``) indices.
    Output:
        Tuple of index tensors. The ``i``-th has shape ``(1, ..., n_i, ..., 1)``,
        so together they broadcast to the cross-product grid without
        materializing it; ``x[ix_(a, b)]`` selects rows ``a`` and columns ``b``.
    Raises:
        ValueError: If any input is not one-dimensional.
    See:
        https://numpy.org/doc/stable/reference/generated/numpy.ix_.html
    """
    ndim = len(args)
    result = []
    for position, arg in enumerate(args):
        index = arg if isinstance(arg, self.torch.Tensor) else self.asarray(arg)
        if index.ndim != 1:
            raise ValueError("Cross index must be 1 dimensional")
        if index.numel() == 0:
            # An empty array-like converts to a floating tensor, which cannot index.
            index = index.to(self.torch.long)
        if index.dtype == self.torch.bool:
            index = self.torch.nonzero(index, as_tuple=True)[0]
        shape = [1] * ndim
        shape[position] = index.shape[0]
        result.append(index.reshape(shape))
    return tuple(result)
def fori_loop(self, lower: int, upper: int, body_fun: Callable[[int, T], T], init_val: T) -> T:
    """
    Counted loop executed eagerly in Python for PyTorch.

    Input:
        lower, upper: Integer bounds; the counter takes ``lower, ..., upper - 1``.
        body_fun: Called as ``body_fun(i, value)`` and returns the next value.
        init_val: Starting value.
    Output:
        The value after the last iteration (``init_val`` if the range is empty).
    See:
        https://docs.python.org/3/reference/compound_stmts.html#the-for-statement

    Backend-specific notes:
        This is a plain Python loop, not a compiled PyTorch control-flow
        primitive; tensor operations in ``body_fun`` follow normal autograd rules.
    """
    state = init_val
    counter = int(lower)
    stop = int(upper)
    while counter < stop:
        state = body_fun(counter, state)
        counter += 1
    return state
def while_loop(self, cond_fun: Callable[[T], bool], body_fun: Callable[[T], T], init_val: T) -> T:
    """
    While loop executed eagerly in Python for PyTorch.

    Input:
        cond_fun: Predicate evaluated on the current value before each step.
        body_fun: Maps the current value to the next one.
        init_val: Starting value.
    Output:
        The first value for which ``cond_fun`` is false.
    See:
        https://docs.python.org/3/reference/compound_stmts.html#the-while-statement

    Backend-specific notes:
        This is a plain Python loop. The predicate result is converted with
        ``bool()`` on every iteration (so 0-d tensors are accepted).
    """
    state = init_val
    while True:
        if not bool(cond_fun(state)):
            return state
        state = body_fun(state)
def _tree_map(self, f: Callable[[Any], Any], tree: Any) -> Any:
    """
    Apply ``f`` to every leaf of a nested dict/tuple/list structure.

    Containers are rebuilt with the same shape, and namedtuples keep their class.
    ``None`` is an empty subtree (as in JAX pytrees): it is returned as-is and
    ``f`` is not called on it.
    """
    if tree is None:
        return None
    if isinstance(tree, dict):
        return {k: self._tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, tuple):
        mapped = [self._tree_map(f, v) for v in tree]
        # Namedtuple constructors take fields positionally, not a single iterable.
        return type(tree)(*mapped) if hasattr(tree, "_fields") else tuple(mapped)
    if isinstance(tree, list):
        return [self._tree_map(f, v) for v in tree]
    return f(tree)


def _tree_multimap(self, f: Callable[..., Any], *trees: Any) -> Any:
    """
    Apply ``f`` leaf-wise across several trees that share ``trees[0]``'s structure.

    Dict keys and sequence lengths are taken from the first tree. ``None`` in the
    first tree marks an empty subtree and is returned without calling ``f``.
    """
    t0 = trees[0]
    if t0 is None:
        return None
    if isinstance(t0, dict):
        return {k: self._tree_multimap(f, *(t[k] for t in trees)) for k in t0.keys()}
    if isinstance(t0, tuple):
        mapped = [self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))]
        return type(t0)(*mapped) if hasattr(t0, "_fields") else tuple(mapped)
    if isinstance(t0, list):
        return [self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))]
    return f(*trees)


def _tree_take0(self, xs: Any) -> Any:
    """
    Return the first non-``None`` leaf of ``xs`` (depth-first), or ``None`` if there is none.

    ``scan`` reads this leaf's leading dimension to infer the number of steps.
    """
    if isinstance(xs, dict):
        children = list(xs.values())
    elif isinstance(xs, (tuple, list)):
        children = list(xs)
    else:
        return xs
    for child in children:
        leaf = self._tree_take0(child)
        if leaf is not None:
            return leaf
    return None


def _tree_index(self, xs: Any, i: int) -> Any:
    """
    Index every leaf of ``xs`` with ``i``.

    A leaf whose ``leaf[i]`` raises any exception is passed through unchanged.
    Presumably this lets scalar leaves broadcast across steps (TODO confirm).
    """
    def _idx(a: Any) -> Any:
        try:
            return a[i]
        except Exception:
            return a

    return self._tree_map(_idx, xs)


def _tree_stack(self, ys_list: Sequence[Any]) -> Any:
    """
    Stack a list of same-structured per-step outputs leaf-wise along a new axis 0.

    Returns ``()`` for an empty list. ``None`` subtrees stay ``None``.
    """
    if not ys_list:
        return ()
    return self._tree_multimap(lambda *leaves: self.stack(leaves, axis=0), *ys_list)
def scan(
    self,
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int = 1,
) -> Tuple[Carry, Y]:
    """
    Eager Python emulation of ``jax.lax.scan`` for PyTorch.

    Input:
        f: Step function ``(carry, x) -> (carry, y)``.
        init: Initial carry.
        xs: Tree of per-step inputs indexed along axis 0, or ``None`` (then
            ``f`` receives ``None`` each step and ``length`` is required).
        length: Number of steps; inferred from the first leaf of ``xs`` when omitted.
        reverse: Visit steps from last to first. Outputs are still returned in
            index order.
        unroll: Accepted for API compatibility; unused here.
    Output:
        ``(final_carry, stacked_outputs)``.
    Raises:
        ValueError: If both ``xs`` and ``length`` are ``None``.
    See:
        https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html

    Backend-specific notes:
        PyTorch has no direct eager equivalent to ``jax.lax.scan``. This runs a
        Python loop and stacks the output leaves once at the end.
    """
    if xs is None and length is None:
        raise ValueError("scan(xs=None) requires an explicit `length`.")
    steps = int(length) if length is not None else int(self._tree_take0(xs).shape[0])
    order = reversed(range(steps)) if reverse else range(steps)

    carry = init
    outputs = []
    for step in order:
        step_input = None if xs is None else self._tree_index(xs, step)
        carry, output = f(carry, step_input)
        outputs.append(output)
    if reverse:
        # Outputs were produced last-to-first; put them back in index order.
        outputs.reverse()
    return carry, self._tree_stack(outputs)
def cond(self, pred: bool, true_fun: Callable[[T], R], false_fun: Callable[[T], R], *operands: Any) -> R:
    """
    Eager conditional branch selection for PyTorch.

    Input:
        pred: Predicate, converted with ``bool()`` (0-d tensors are accepted).
        true_fun, false_fun: Branches; exactly one is called.
        operands: Positional arguments passed to the chosen branch.
    Output:
        Whatever the selected branch returns.
    See:
        https://docs.python.org/3/reference/expressions.html#conditional-expressions

    Backend-specific notes:
        This is ordinary Python branching, not a staged or compiled
        control-flow primitive.
    """
    branch = true_fun if bool(pred) else false_fun
    return branch(*operands)
def allclose(self, a: DenseArray, b: DenseArray, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False) -> bool:
    """
    Tolerance-based equality of two dense tensors via ``torch.allclose``.

    Input:
        a, b: Dense backend tensors.
        rtol, atol: Relative and absolute tolerances.
        equal_nan: Treat NaNs in matching positions as equal.
    Output:
        Python ``bool``.
    See:
        https://docs.pytorch.org/docs/stable/generated/torch.allclose.html
    """
    close = self.torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
    return bool(close)
def allclose_sparse(self, a: SparseArray, b: SparseArray, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """
    Tolerance-based equality of two sparse tensors.

    Input:
        a, b: Sparse backend tensors.
        rtol, atol: Relative and absolute tolerances.
    Output:
        Python ``bool``.
    Raises:
        TypeError: If either operand is not sparse.
    See:
        https://docs.pytorch.org/docs/stable/sparse.html

    Backend-specific notes:
        Both operands are densified with ``to_dense()`` and compared with
        ``allclose``.
    """
    both_sparse = self.is_sparse(a) and self.is_sparse(b)
    if not both_sparse:
        raise TypeError("allclose_sparse expects two sparse tensors.")
    dense_a, dense_b = a.to_dense(), b.to_dense()
    return self.allclose(dense_a, dense_b, rtol=rtol, atol=atol)