from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: import torch array = torch.Tensor from torch import dtype as Dtype from typing import Optional from torch.linalg import * # torch.linalg doesn't define __all__ # from torch.linalg import __all__ as linalg_all from torch import linalg as torch_linalg linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] # outer is implemented in torch but aren't in the linalg namespace from torch import outer from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch_linalg.cross(x1, x2, dim=axis) def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # torch.linalg.vecdot doesn't support integer dtypes if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): if kwargs: raise RuntimeError("vecdot kwargs not supported for integral dtypes") ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) if x1_shape[axis] != x2_shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") x1_, x2_ = torch.broadcast_tensors(x1, x2) x1_ = torch.moveaxis(x1_, axis, -1) x2_ = torch.moveaxis(x2_, axis, -1) res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) def solve(x1: array, x2: array, /, **kwargs) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) __all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot', 'solve'] del linalg_all