import copyreg
import enum
import functools
import warnings
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch._C as _C
from torch._namedtensor_internals import (
    check_serializing_named_tensor,
    is_ellipsis,
    resolve_ellipsis,
    single_ellipsis_index,
    unzip_namedshape,
    update_names,
)
from torch.overrides import (
    get_default_nowrap_functions,
    handle_torch_function,
    has_torch_function,
    has_torch_function_unary,
    has_torch_function_variadic,
)

def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
    assigned = functools.WRAPPER_ASSIGNMENTS

    @functools.wraps(f, assigned=assigned)
    def wrapped(*args, **kwargs):
        try:
            if has_torch_function(args):
                return handle_torch_function(wrapped, args, *args, **kwargs)
            return f(*args, **kwargs)
        except TypeError:
            # As the name says: surface TypeError as NotImplemented so
            # Python can try the other operand's reflected method.
            return NotImplemented

    return wrapped
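
# Usage sketch (illustrative, not part of the original file): this decorator
# wraps binary dunder methods such as __rsub__ defined later in this file.
# A TypeError inside the op (e.g. an un-coercible operand) becomes
# NotImplemented, letting Python try the other operand's method:
#
#   >>> t = torch.tensor([1, 2])
#   >>> t.__rsub__("not a number")  # returns NotImplemented, does not raise
#   NotImplemented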

def _rebuild_from_type(func, type, args, dict):
    if type is Tensor:
        return func(*args)

    ret = func(*args).as_subclass(type)
    ret.__dict__ = dict
    return ret


def _rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)
    if type(ret) is not new_type:
        ret = ret.as_subclass(new_type)
    # Tensor does define __setstate__ even though it doesn't define
    # __getstate__. So only use __setstate__ if it is NOT the one defined
    # on Tensor.
    if (
        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
        is not Tensor.__setstate__
    ):
        ret.__setstate__(state)
    else:
        ret = torch._utils._set_obj_state(ret, state)
    return ret

class Tensor(torch._C.TensorBase):
    def __deepcopy__(self, memo):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
        if not self.is_leaf:
            raise RuntimeError(
                "Only Tensors created explicitly by the user "
                "(graph leaves) support the deepcopy protocol at the moment. "
                "If you were attempting to deepcopy a module, this may be because "
                "of a torch.nn.utils.weight_norm usage, "
                "see https://github.com/pytorch/pytorch/pull/103001"
            )
        if id(self) in memo:
            return memo[id(self)]
        with torch.no_grad():
            # Backends without real storage (and wrapper subclasses with a
            # null data pointer) can only be copied via clone().
            if (
                self.is_sparse
                or self.device.type
                in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"]
                or (
                    not torch._C._has_storage(self)
                    and self.device.type == torch._C._get_privateuse1_backend_name()
                )
                or (type(self) is not Tensor and self.data_ptr() == 0)
            ):
                new_tensor = self.clone()
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "The default implementation of __deepcopy__() for wrapper subclasses "
                        "only works for subclass types that implement clone() and for which "
                        "cloning returns another instance of the same subclass. You should either "
                        "properly implement clone() for your subclass or override __deepcopy__() "
                        "if it is intended behavior for clone() to return an instance of a "
                        "different type."
                    )
            else:
                new_storage = self._typed_storage()._deepcopy(memo)
                if self.is_quantized:
                    quantizer_params: Union[
                        Tuple[torch.qscheme, float, int],
                        Tuple[torch.qscheme, Tensor, Tensor, int],
                    ]
                    if self.qscheme() == torch.per_tensor_affine:
                        quantizer_params = (
                            self.qscheme(),
                            self.q_scale(),
                            self.q_zero_point(),
                        )
                    elif self.qscheme() in (
                        torch.per_channel_affine,
                        torch.per_channel_affine_float_qparams,
                    ):
                        quantizer_params = (
                            self.qscheme(),
                            self.q_per_channel_scales(),
                            self.q_per_channel_zero_points(),
                            self.q_per_channel_axis(),
                        )
                    else:
                        raise RuntimeError(
                            f"Unsupported qscheme {self.qscheme()} in deepcopy"
                        )
                    new_tensor = torch._utils._rebuild_qtensor(
                        torch.storage.TypedStorage(
                            wrap_storage=new_storage._untyped_storage,
                            dtype=self.dtype,
                            _internal=True,
                        ),
                        self.storage_offset(),
                        self.size(),
                        self.stride(),
                        quantizer_params,
                        self.requires_grad,
                        self._backward_hooks,
                    )
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for quantized tensors "
                            "expects the tensor returned by torch._utils._rebuild_qtensor() to "
                            "match the type of the instance being copied. If you encounter this, "
                            "please open an issue on PyTorch's GitHub."
                        )
                else:
                    new_tensor = self.new_empty([])
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for non-wrapper subclasses "
                            "only works for subclass types that implement new_empty() and for which "
                            "that function returns another instance of the same subclass. You should "
                            "either properly implement new_empty() for your subclass or override "
                            "__deepcopy__() if it is intended behavior for new_empty() to return "
                            "an instance of a different type."
                        )
                    new_tensor.set_(
                        new_storage, self.storage_offset(), self.size(), self.stride()
                    )
                    if self.is_conj():
                        new_tensor = new_tensor.conj_physical()
                    if self.is_neg():
                        new_tensor = new_tensor.neg()
            if self.requires_grad:
                new_tensor.requires_grad_()
            if self.grad is not None:
                new_tensor.grad = self.grad.__deepcopy__(memo)
            if type(self) is not Tensor:
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "Type of deepcopy result does not match the type of the source tensor. "
                        "If you encounter this, please open an issue on PyTorch's GitHub."
                    )

                # Plain Tensors don't have slots; subclasses might.
                slots_to_save = copyreg._slotnames(self.__class__)
                for slot in slots_to_save:
                    if hasattr(self, slot):
                        setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo))

            new_tensor.__dict__ = deepcopy(self.__dict__, memo)

            memo[id(self)] = new_tensor
            return new_tensor
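
    # Usage sketch (illustrative): deepcopy is only supported for leaves.
    #
    #   >>> import copy
    #   >>> a = torch.ones(2, requires_grad=True)
    #   >>> b = copy.deepcopy(a)   # ok: `a` is a graph leaf
    #   >>> b is a
    #   False
    #   >>> copy.deepcopy(a * 2)   # a non-leaf raises RuntimeError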

    def __reduce_ex__(self, proto):
        materialize_fake_tensors = (
            torch.serialization._serialization_tls.materialize_fake_tensors
        )
        state = torch._utils._get_obj_state(self)
        # Fast path: plain tensors without Python-level state (and fake
        # tensors being materialized) skip the subclass rebuild machinery.
        if (
            type(self) is torch._subclasses.fake_tensor.FakeTensor
            and materialize_fake_tensors
        ) or (type(self) is Tensor and not state):
            return self._reduce_ex_internal(proto)
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
        func, args = self._reduce_ex_internal(proto)
        return (_rebuild_from_type_v2, (func, type(self), args, state))
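
    # Usage sketch (illustrative): __reduce_ex__ is what lets pickle and
    # torch.save round-trip both the data and the subclass type.
    #
    #   >>> import io
    #   >>> buf = io.BytesIO()
    #   >>> torch.save(torch.arange(3), buf)
    #   >>> _ = buf.seek(0)
    #   >>> torch.load(buf)
    #   tensor([0, 1, 2])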

    def storage(self):
        r"""
        storage() -> torch.TypedStorage

        Returns the underlying :class:`TypedStorage`.

        .. warning::

            :class:`TypedStorage` is deprecated. It will be removed in the future, and
            :class:`UntypedStorage` will be the only storage class. To access the
            :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.storage, (self,), self)

        torch.storage._warn_typed_storage_removal(stacklevel=2)
        return self._typed_storage()

    # For internal use only, to avoid raising the deprecation warning
    def _typed_storage(self):
        untyped_storage = self.untyped_storage()
        return torch.TypedStorage(
            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
        )
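
    # Usage sketch (illustrative): prefer untyped_storage() in new code.
    #
    #   >>> t = torch.arange(4, dtype=torch.float32)
    #   >>> t.untyped_storage().nbytes()   # 4 elements * 4 bytes
    #   16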

    def _reduce_ex_internal(self, proto):
        check_serializing_named_tensor(self)

        from torch.utils.hooks import warn_if_has_hooks

        # Python-level backward hooks are intentionally not serialized;
        # warn if any are present.
        warn_if_has_hooks(self)
        backward_hooks: Dict[Any, Any] = OrderedDict()

        skip_data = torch.serialization._serialization_tls.skip_data
        materialize_fake_tensors = (
            torch.serialization._serialization_tls.materialize_fake_tensors
        )

        # Backends that don't expose storage serialize via a CPU numpy copy.
        if self.device.type in ["xla", "mtia", "maia"] or (
            not torch._C._has_storage(self)
            and self.device.type == torch._C._get_privateuse1_backend_name()
        ):
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
                )
            numpy_tensor = (
                self.cpu().numpy()
                if self.dtype != torch.bfloat16
                else self.cpu().to(torch.float32).numpy()
            )
            return (
                torch._utils._rebuild_device_tensor_from_numpy,
                (numpy_tensor, self.dtype, str(self.device), self.requires_grad),
            )
        if self.device.type == "meta":
            if skip_data:
                warnings.warn(
                    "Serializing tensors on the meta device under skip_data context manager is a no-op"
                )
            arg_meta = (
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.requires_grad,
            )
            return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
        if self.is_quantized:
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
                )
            # quantizer_params can be a different type based on the qscheme
            quantizer_params: Union[
                Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
            ]
            if self.qscheme() == torch.per_tensor_affine:
                quantizer_params = (
                    torch.per_tensor_affine,
                    self.q_scale(),
                    self.q_zero_point(),
                )
            elif self.qscheme() in (
                torch.per_channel_affine,
                torch.per_channel_affine_float_qparams,
            ):
                quantizer_params = (
                    torch.per_channel_affine,
                    self.q_per_channel_scales(),
                    self.q_per_channel_zero_points(),
                    self.q_per_channel_axis(),
                )
            else:
                raise RuntimeError(
                    f"Serialization is not supported for tensors of type {self.qscheme()}"
                )
            args_qtensor = (
                torch.storage.TypedStorage(
                    wrap_storage=self._typed_storage()._untyped_storage,
                    dtype=self.dtype,
                    _internal=True,
                ),
                self.storage_offset(),
                tuple(self.size()),
                self.stride(),
                quantizer_params,
                self.requires_grad,
                backward_hooks,
            )
            return (torch._utils._rebuild_qtensor, args_qtensor)
        if self.is_sparse:
            if self.layout == torch.sparse_coo:
                args_sparse = (
                    self.layout,
                    (self._indices(), self._values(), self.size(), self.is_coalesced()),
                )
            else:
                raise NotImplementedError(
                    f"sparse tensor __reduce_ex__ for layout `{self.layout}`"
                )
            return (torch._utils._rebuild_sparse_tensor, args_sparse)
        elif self.layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = (
                    self.crow_indices(),
                    self.col_indices(),
                )
            else:
                compressed_indices, plain_indices = (
                    self.ccol_indices(),
                    self.row_indices(),
                )
            args_sparse_compressed = (
                self.layout,
                (
                    compressed_indices,
                    plain_indices,
                    self.values(),
                    self.size(),
                ),
            )
            return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
        elif self.is_nested:
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
                )
            args_nested = (
                self.values(),
                self._nested_tensor_size(),
                self._nested_tensor_strides(),
                self._nested_tensor_storage_offsets(),
            )
            return (torch._utils._rebuild_nested_tensor, args_nested)
        elif (
            type(self) is not torch.Tensor
            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
            and (
                isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
                or (
                    not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
                    and self.data_ptr() == 0
                )
            )
        ):
            arg_wrapper_subclass = (
                type(self),
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.storage_offset(),
                self.layout,
                self.device,
                self.requires_grad,
            )
            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
        elif (
            type(self) is not torch.Tensor
            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
            and (
                isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
                and not (skip_data and materialize_fake_tensors)
            )
        ):
            arg_wrapper_subclass = (
                type(self),
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.storage_offset(),
                self.layout,
                self.device,
                self.requires_grad,
            )
            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)

        v3_dtypes = torch.storage._new_dtypes()
        if self.dtype in v3_dtypes:
            rebuild_func = torch._utils._rebuild_tensor_v3
            storage = self.untyped_storage()
        else:
            rebuild_func = torch._utils._rebuild_tensor_v2
            storage = torch.storage.TypedStorage(
                wrap_storage=self._typed_storage()._untyped_storage,
                dtype=self.dtype,
                _internal=True,
            )

        if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data:
            storage._fake_device = self.device

        args = (
            storage,
            self.storage_offset(),
            tuple(self.size()),
            self.stride(),
            self.requires_grad,
            backward_hooks,
        )

        if isinstance(storage, torch.storage.UntypedStorage):
            args = args + (self.dtype,)

        metadata = torch._utils.get_tensor_metadata(self)
        if metadata:
            args = args + (metadata,)

        return (rebuild_func, args)

    def __setstate__(self, state):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__setstate__, (self,), self, state)
        if not self.is_leaf:
            raise RuntimeError("__setstate__ can be only called on leaf Tensors")
        if len(state) == 4:
            # legacy serialization of Tensor
            self.set_(*state)
            return
        elif len(state) == 5:
            # legacy serialization of Variable
            self.data = state[0]
            state = (state[3], state[4], state[2])
        self.requires_grad, _, self._backward_hooks = state

    def __repr__(self, *, tensor_contents=None):
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
            )
        return torch._tensor_str._str(self, tensor_contents=tensor_contents)

    def backward(
        self, gradient=None, retain_graph=None, create_graph=False, inputs=None
    ):
        r"""Computes the gradient of the current tensor wrt graph leaves.

        The graph is differentiated using the chain rule. If the tensor is
        non-scalar (i.e. its data has more than one element) and requires
        gradient, the function additionally requires specifying a ``gradient``.
        It should be a tensor of matching type and shape, that represents
        the gradient of the differentiated function w.r.t. ``self``.

        This function accumulates gradients in the leaves - you might need to zero
        ``.grad`` attributes or set them to ``None`` before calling it.
        See :ref:`Default gradient layouts<default-grad-layouts>`
        for details on the memory layout of accumulated gradients.

        .. note::

            If you run any forward ops, create ``gradient``, and/or call ``backward``
            in a user-specified CUDA stream context, see
            :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

        .. note::

            When ``inputs`` are provided and a given input is not a leaf,
            the current implementation will call its grad_fn (though it is not strictly needed to get these gradients).
            It is an implementation detail on which the user should not rely.
            See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

        Args:
            gradient (Tensor, optional): The gradient of the function
                being differentiated w.r.t. ``self``.
                This argument can be omitted if ``self`` is a scalar.
            retain_graph (bool, optional): If ``False``, the graph used to compute
                the grads will be freed. Note that in nearly all cases setting
                this option to True is not needed and often can be worked around
                in a much more efficient way. Defaults to the value of
                ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative
                products. Defaults to ``False``.
            inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be
                accumulated into ``.grad``. All other tensors will be ignored. If not
                provided, the gradient is accumulated into all the leaf Tensors that were
                used to compute the :attr:`tensors`.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.backward,
                (self,),
                self,
                gradient=gradient,
                retain_graph=retain_graph,
                create_graph=create_graph,
                inputs=inputs,
            )
        torch.autograd.backward(
            self, gradient, retain_graph, create_graph, inputs=inputs
        )
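
    # Usage sketch (illustrative): scalar output vs. explicit `gradient`.
    #
    #   >>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    #   >>> y = x * 2
    #   >>> y.backward(torch.ones_like(y))   # non-scalar output needs `gradient`
    #   >>> x.grad
    #   tensor([2., 2., 2.])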

    def register_hook(self, hook):
        r"""Registers a backward hook.

        The hook will be called every time a gradient with respect to the
        Tensor is computed. The hook should have the following signature::

            hook(grad) -> Tensor or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v.grad
             2
             4
             6
            [torch.FloatTensor of size (3,)]
            >>> h.remove()  # removes the hook
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
            )
        if self._backward_hooks is None:
            self._backward_hooks = OrderedDict()
            if self.grad_fn is not None:
                self.grad_fn._register_hook_dict(self)

        from torch.utils.hooks import RemovableHandle

        handle = RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

    def register_post_accumulate_grad_hook(self, hook):
        r"""Registers a backward hook that runs after grad accumulation.

        The hook will be called after all gradients for a tensor have been accumulated,
        meaning that the .grad field has been updated on that tensor. The post
        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
        .grad_fn field). Registering this hook on a non-leaf tensor will error!

        The hook should have the following signature::

            hook(param: Tensor) -> None

        Note that, unlike other autograd hooks, this hook operates on the tensor
        that requires grad and not the grad itself. The hook can in-place modify
        and access its Tensor argument, including its .grad field.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks. Since
            this hook runs during the backward pass, it will run in no_grad mode (unless
            create_graph is True). You can use torch.enable_grad() to re-enable autograd
            within the hook if you need it.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> lr = 0.01
            >>> # simulate a simple SGD update
            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v
            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
            >>> h.remove()  # removes the hook
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
            )
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
            )
        if self.grad_fn is not None:
            raise RuntimeError(
                "post accumulate grad hooks cannot be registered on non-leaf tensors"
            )
        if self._post_accumulate_grad_hooks is None:
            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()

        from torch.utils.hooks import RemovableHandle

        handle = RemovableHandle(self._post_accumulate_grad_hooks)
        self._post_accumulate_grad_hooks[handle.id] = hook
        return handle

    def reinforce(self, reward):
        def trim(str):
            return "\n".join([line.strip() for line in str.split("\n")])

        raise RuntimeError(
            trim(
                r"""reinforce() was removed.
            Use torch.distributions instead.
            See https://pytorch.org/docs/main/distributions.html

            Instead of:

            probs = policy_network(state)
            action = probs.multinomial()
            next_state, reward = env.step(action)
            action.reinforce(reward)
            action.backward()

            Use:

            probs = policy_network(state)
            # NOTE: categorical is equivalent to what used to be called multinomial
            m = torch.distributions.Categorical(probs)
            action = m.sample()
            next_state, reward = env.step(action)
            loss = -m.log_prob(action) * reward
            loss.backward()
        """
            )
        )

    detach = _C._add_docstr(
        _C.TensorBase.detach,
        r"""
    Returns a new Tensor, detached from the current graph.

    The result will never require gradient.

    This method also affects forward mode AD gradients and the result will never
    have forward mode AD gradients.

    .. note::

      Returned Tensor shares the same storage with the original one.
      In-place modifications on either of them will be seen, and may trigger
      errors in correctness checks.
    """,
    )

    detach_ = _C._add_docstr(
        _C.TensorBase.detach_,
        r"""
    Detaches the Tensor from the graph that created it, making it a leaf.
    Views cannot be detached in-place.

    This method also affects forward mode AD gradients and the result will never
    have forward mode AD gradients.
    """,
    )

    def is_shared(self):
        r"""Checks if tensor is in shared memory.

        This is always ``True`` for CUDA tensors.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.is_shared, (self,), self)
        return self._typed_storage()._is_shared()

    def share_memory_(self):
        r"""Moves the underlying storage to shared memory.

        This is a no-op if the underlying storage is already in shared memory
        and for CUDA tensors. Tensors in shared memory cannot be resized.

        See :meth:`torch.UntypedStorage.share_memory_` for more details.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.share_memory_, (self,), self)
        self._typed_storage()._share_memory_()
        return self

    def module_load(self, other, assign=False):
        r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.

        Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the
        value in the state dictionary with the corresponding key; this method defines
        how ``other`` is remapped before being swapped with ``self`` via
        :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.

        .. note::
            This method should always return a new object that is not ``self`` or ``other``.
            For example, the default implementation returns ``self.copy_(other).detach()``
            if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.

        Args:
            other (Tensor): value in state dict with key corresponding to ``self``
            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`

        """
        if has_torch_function_variadic(self, other):
            return handle_torch_function(
                Tensor.module_load, (self, other), self, other, assign=assign
            )

        if assign:
            return other.detach()
        else:
            return self.copy_(other).detach()
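
    # Usage sketch (illustrative): with the swap_tensors path enabled,
    # load_state_dict() routes each state-dict value through module_load().
    #
    #   >>> torch.__future__.set_swap_module_params_on_conversion(True)
    #   >>> m = torch.nn.Linear(2, 2)
    #   >>> m.load_state_dict(m.state_dict())  # calls module_load() per param
    #   <All keys matched successfully>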

    def __reversed__(self):
        r"""Reverses the tensor along dimension 0."""
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__reversed__, (self,), self)
        if self.dim() == 0:
            return self
        else:
            return self.flip(0)

    def norm(
        self,
        p: Optional[Union[float, str]] = "fro",
        dim=None,
        keepdim=False,
        dtype=None,
    ):
        r"""See :func:`torch.norm`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype
            )
        return torch.norm(self, p, dim, keepdim, dtype=dtype)

    def solve(self, other):
        from torch._linalg_utils import solve

        return solve(self, other)

    def lstsq(self, other):
        from torch._linalg_utils import lstsq

        return lstsq(self, other)

    def eig(self, eigenvectors=False):
        from torch._linalg_utils import eig

        return eig(self, eigenvectors=eigenvectors)

    def symeig(self, eigenvectors=False):
        from torch._linalg_utils import _symeig

        return _symeig(self, eigenvectors=eigenvectors)

    def lu(self, pivot=True, get_infos=False):
        r"""See :func:`torch.lu`"""
        # If get_infos is True, then we don't need to check for errors and vice versa
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
            )

        LU, pivots, infos = torch._lu_with_info(
            self, pivot=pivot, check_errors=(not get_infos)
        )
        if get_infos:
            return LU, pivots, infos
        else:
            return LU, pivots
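
    # Usage sketch (illustrative): the return arity depends on get_infos.
    #
    #   >>> A = torch.randn(3, 3)
    #   >>> LU, pivots = A.lu()
    #   >>> LU, pivots, infos = A.lu(get_infos=True)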

    def stft(
        self,
        n_fft: int,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window: "Optional[Tensor]" = None,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: Optional[bool] = None,
        return_complex: Optional[bool] = None,
    ):
        r"""See :func:`torch.stft`

        .. warning::
          This function changed signature at version 0.4.1. Calling with
          the previous signature may cause an error or return an incorrect result.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.stft,
                (self,),
                self,
                n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                center=center,
                pad_mode=pad_mode,
                normalized=normalized,
                onesided=onesided,
                return_complex=return_complex,
            )
        return torch.stft(
            self,
            n_fft,
            hop_length,
            win_length,
            window,
            center,
            pad_mode,
            normalized,
            onesided,
            return_complex=return_complex,
        )

    def istft(
        self,
        n_fft: int,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window: "Optional[Tensor]" = None,
        center: bool = True,
        normalized: bool = False,
        onesided: Optional[bool] = None,
        length: Optional[int] = None,
        return_complex: bool = False,
    ):
        r"""See :func:`torch.istft`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.istft,
                (self,),
                self,
                n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                center=center,
                normalized=normalized,
                onesided=onesided,
                length=length,
                return_complex=return_complex,
            )
        return torch.istft(
            self,
            n_fft,
            hop_length,
            win_length,
            window,
            center,
            normalized,
            onesided,
            length,
            return_complex=return_complex,
        )

    def resize(self, *sizes):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.resize, (self,), self, *sizes)
        warnings.warn("non-inplace resize is deprecated")
        from torch.autograd._functions import Resize

        return Resize.apply(self, sizes)

    def resize_as(self, tensor):
        if has_torch_function_variadic(self, tensor):
            return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
        warnings.warn("non-inplace resize_as is deprecated")
        from torch.autograd._functions import Resize

        return Resize.apply(self, tensor.size())

    def split(self, split_size, dim=0):
        r"""See :func:`torch.split`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.split, (self,), self, split_size, dim=dim
            )
        if isinstance(split_size, Tensor):
            try:
                split_size = int(split_size)
            except ValueError:
                pass

        if isinstance(split_size, (int, torch.SymInt)):
            return torch._VF.split(self, split_size, dim)
        else:
            return torch._VF.split_with_sizes(self, split_size, dim)
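
    # Usage sketch (illustrative): an int gives equal chunks (last may be
    # smaller); a list of sizes routes to split_with_sizes.
    #
    #   >>> t = torch.arange(10)
    #   >>> [c.shape[0] for c in t.split(4)]
    #   [4, 4, 2]
    #   >>> [c.shape[0] for c in t.split([3, 7])]
    #   [3, 7]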

    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
        r"""Returns the unique elements of the input tensor.

        See :func:`torch.unique`
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.unique,
                (self,),
                self,
                sorted=sorted,
                return_inverse=return_inverse,
                return_counts=return_counts,
                dim=dim,
            )
        return torch.unique(
            self,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
        r"""Eliminates all but the first element from every consecutive group of equivalent elements.

        See :func:`torch.unique_consecutive`
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.unique_consecutive,
                (self,),
                self,
                return_inverse=return_inverse,
                return_counts=return_counts,
                dim=dim,
            )
        return torch.unique_consecutive(
            self, return_inverse=return_inverse, return_counts=return_counts, dim=dim
        )

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rsub__(self, other):
        return _C._VariableFunctions.rsub(self, other)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rdiv__(self, other):
        return self.reciprocal() * other

    __rtruediv__ = __rdiv__
    __itruediv__ = _C.TensorBase.__idiv__

    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
        _C.TensorBase.pow
    )
    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
        _C.TensorBase.pow_
    )

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rmod__(self, other):
        return torch.remainder(other, self)

    def __format__(self, format_spec):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
        if self.dim() == 0 and not self.is_meta and type(self) is Tensor:
            return self.item().__format__(format_spec)
        return object.__format__(self, format_spec)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rpow__(self, other):
        return torch.pow(other, self)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __floordiv__(self, other):
        return torch.floor_divide(self, other)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rlshift__(self, other):
        return torch.bitwise_left_shift(other, self)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rrshift__(self, other):
        return torch.bitwise_right_shift(other, self)

    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rmatmul__(self, other):
        return torch.matmul(other, self)
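
    # Usage sketch (illustrative): Python falls back to these reflected
    # methods when the left operand defers, e.g. `int <op> Tensor`.
    #
    #   >>> 10 - torch.tensor([1, 2])   # dispatches to Tensor.__rsub__
    #   tensor([9, 8])
    #   >>> 2 ** torch.tensor([1, 2])   # dispatches to Tensor.__rpow__
    #   tensor([2, 4])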

    __pos__ = _C.TensorBase.positive
    __neg__ = _C.TensorBase.neg
    __abs__ = _C.TensorBase.abs

    def __len__(self):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__len__, (self,), self)
        if self.dim() == 0:
            raise TypeError("len() of a 0-d tensor")
        if torch._C._get_tracing_state():
            warnings.warn(
                "Using len to get tensor shape might cause the trace to be incorrect. "
                "Recommended usage would be tensor.shape[0]. "
                "Passing a tensor of different shape might lead to errors or silently give "
                "incorrect results.",
                category=torch.jit.TracerWarning,
                stacklevel=2,
            )
        return self.shape[0]

    def __iter__(self):
        # NB: __torch_function__ dispatch is intentionally skipped here;
        # iteration goes through unbind(), which dispatches normally.
        if self.dim() == 0:
            raise TypeError("iteration over a 0-d tensor")
        if torch._C._get_tracing_state():
            warnings.warn(
                "Iterating over a tensor might cause the trace to be incorrect. "
                "Passing a tensor of different shape won't change the number of "
                "iterations executed (and might lead to errors or silently give "
                "incorrect results).",
                category=torch.jit.TracerWarning,
                stacklevel=2,
            )
        return iter(self.unbind(0))

    def __dir__(self):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dir__, (self,), self)
        tensor_methods = dir(self.__class__)
        tensor_methods.remove("volatile")  # deprecated
        attrs = list(self.__dict__.keys())
        keys = tensor_methods + attrs

        # property only available for dense, CUDA tensors
        if (not self.is_cuda) or self.is_sparse:
            keys.remove("__cuda_array_interface__")

        return sorted(keys)

    # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
    __array_priority__ = 1000  # prefer Tensor ops over numpy ones

    def __array__(self, dtype=None):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

    # Wrap Numpy array again in a suitable tensor when done, to support e.g.
    # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
    def __array_wrap__(self, array):
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.__array_wrap__, (self,), self, array=array
            )
        if array.dtype == bool:
            # Workaround, torch has no built-in bool tensor
            array = array.astype("uint8")
        return torch.from_numpy(array)

    def __contains__(self, element: Any, /) -> bool:
        r"""Check if `element` is present in tensor

        Args:
            element (Tensor or scalar): element to be checked
                for presence in current tensor
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__contains__, (self,), self, element)
        if isinstance(
            element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
        ):
            return bool((element == self).any().item())

        raise RuntimeError(
            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
        )

    @property
    def __cuda_array_interface__(self):
        """Array view description for cuda tensors.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.__cuda_array_interface__.__get__,
                (self,),
                self,
            )

        # raise AttributeError for unsupported tensors, so that
        # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
        if not self.is_cuda:
            raise AttributeError(
                f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} "
                "If CUDA data is required use tensor.cuda() to copy tensor to device memory."
            )

        if self.is_sparse:
            raise AttributeError(
                f"Can't get __cuda_array_interface__ on sparse type: {self.type()} "
                "Use Tensor.to_dense() to convert to a dense tensor first."
            )

        # RuntimeError, matching tensor.__array__() behavior.
        if self.requires_grad:
            raise RuntimeError(
                "Can't get __cuda_array_interface__ on Variable that requires grad. "
                "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
            )

        # CUDA devices are little-endian and tensors are stored in native byte
        # order. 1-byte entries are endian-agnostic.
        typestr = {
            torch.complex64: "<c8",
            torch.complex128: "<c16",
            torch.bfloat16: "<f2",
            torch.float16: "<f2",
            torch.float32: "<f4",
            torch.float64: "<f8",
            torch.uint8: "|u1",
            torch.int8: "|i1",
            torch.uint16: "<u2",
            torch.int16: "<i2",
            torch.uint32: "<u4",
            torch.int32: "<i4",
            torch.uint64: "<u8",
            torch.int64: "<i8",
            torch.bool: "|b1",
        }[self.dtype]

        itemsize = self.element_size()

        shape = tuple(self.shape)
        if self.is_contiguous():
            # __cuda_array_interface__ v2 requires the strides to be omitted
            # (either not set or set to None) for C-contiguous arrays.
            strides = None
        else:
            strides = tuple(s * itemsize for s in self.stride())
        data_ptr = self.data_ptr() if self.numel() > 0 else 0
        data = (data_ptr, False)  # read-only is false

        return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2)
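
    # Usage sketch (illustrative; requires a CUDA device and numba installed):
    # the interface lets numba view a CUDA tensor without copying.
    #
    #   >>> # requires CUDA, skipped on CPU-only builds
    #   >>> import numba.cuda
    #   >>> t = torch.arange(4, device="cuda")
    #   >>> numba.cuda.as_cuda_array(t).copy_to_host()
    #   array([0, 1, 2, 3])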

    def storage_type(self):
        r"""storage_type() -> type

        Returns the type of the underlying storage.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.storage_type, (self,), self)

        torch.storage._warn_typed_storage_removal()

        return self._typed_storage()._get_legacy_storage_class()

    def refine_names(self, *names):
        r"""Refines the dimension names of :attr:`self` according to :attr:`names`.

        Refining is a special case of renaming that "lifts" unnamed dimensions.
        A ``None`` dim can be refined to have any name; a named dim can only be
        refined to have the same name.

        Because named tensors can coexist with unnamed tensors, refining names
        gives a nice way to write named-tensor-aware code that works with both
        named and unnamed tensors.

        :attr:`names` may contain up to one Ellipsis (``...``).
        The Ellipsis is expanded greedily; it is expanded in-place to fill
        :attr:`names` to the same length as ``self.dim()`` using names from the
        corresponding indices of ``self.names``.

        Python 2 does not support Ellipsis but one may use a string literal
        instead (``'...'``).

        Args:
            names (iterable of str): The desired names of the output tensor. May
                contain up to one Ellipsis.

        Examples::

            >>> imgs = torch.randn(32, 3, 128, 128)
            >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
            >>> named_imgs.names
            ('N', 'C', 'H', 'W')

            >>> tensor = torch.randn(2, 3, 5, 7, 11)
            >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
            >>> tensor.names
            ('A', None, None, 'B', 'C')

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.refine_names, (self,), self, *names)
        names = resolve_ellipsis(names, self.names, "refine_names")
        return super().refine_names(names)

    def align_to(self, *names):
        r"""Permutes the dimensions of the :attr:`self` tensor to match the order
        specified in :attr:`names`, adding size-one dims for any new names.

        All of the dims of :attr:`self` must be named in order to use this method.
        The resulting tensor is a view on the original tensor.

        All dimension names of :attr:`self` must be present in :attr:`names`.
        :attr:`names` may contain additional names that are not in ``self.names``;
        the output tensor has a size-one dimension for each of those new names.

        :attr:`names` may contain up to one Ellipsis (``...``).
        The Ellipsis is expanded to be equal to all dimension names of :attr:`self`
        that are not mentioned in :attr:`names`, in the order that they appear
        in :attr:`self`.

        Python 2 does not support Ellipsis but one may use a string literal
        instead (``'...'``).

        Args:
            names (iterable of str): The desired dimension ordering of the
                output tensor. May contain up to one Ellipsis that is expanded
                to all unmentioned dim names of :attr:`self`.

        Examples::

            >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
            >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')

            # Move the F and E dims to the front while keeping the rest in order
            >>> named_tensor.align_to('F', 'E', ...)

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.align_to, (self,), self, *names)
        ellipsis_idx = single_ellipsis_index(names, "align_to")  # allows None
        if ellipsis_idx is None:
            return super().align_to(names)
        return super().align_to(
            [name for name in names if not is_ellipsis(name)], ellipsis_idx
        )

    def unflatten(self, dim, sizes):
        r"""
        unflatten(dim, sizes) -> Tensor

        See :func:`torch.unflatten`.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)

        if not sizes:
            raise RuntimeError("unflatten: sizes must be non-empty")

        names = None
        if isinstance(sizes, OrderedDict) or (
            isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
        ):
            names, sizes = unzip_namedshape(sizes)
            return super().unflatten(dim, sizes, names)
        else:
            return super().unflatten(dim, sizes)
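
    # Usage sketch (illustrative): splitting one dim into several.
    #
    #   >>> torch.arange(6).unflatten(0, (2, 3)).shape
    #   torch.Size([2, 3])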

    def rename_(self, *names, **rename_map):
        """In-place version of :meth:`~Tensor.rename`."""

        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.rename_, (self,), self, *names, **rename_map
            )

        # update_names resolves whether positional names or a rename_map
        # keyword mapping was used, and errors if both were supplied.
        return update_names(self, names, rename_map, inplace=True)

    def rename(self, *names, **rename_map):
        """Renames dimension names of :attr:`self`.

        There are two main usages:

        ``self.rename(**rename_map)`` returns a view on tensor that has dims
        renamed as specified in the mapping :attr:`rename_map`.

        ``self.rename(*names)`` returns a view on tensor, renaming all
        dimensions positionally using :attr:`names`.
        Use ``self.rename(None)`` to drop names on a tensor.

        One cannot specify both positional args :attr:`names` and keyword args
        :attr:`rename_map`.

        Examples::

            >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
            >>> renamed_imgs = imgs.rename(N='batch', C='channels')
            >>> renamed_imgs.names
            ('batch', 'channels', 'H', 'W')

            >>> renamed_imgs = imgs.rename(None)
            >>> renamed_imgs.names
            (None, None, None, None)

            >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
            >>> renamed_imgs.names
            ('batch', 'channel', 'height', 'width')

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.rename, (self,), self, *names, **rename_map
            )

        return update_names(self, names, rename_map, inplace=False)

    def to_sparse_coo(self):
        """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`.

        Examples::

            >>> dense = torch.randn(5, 5)
            >>> sparse = dense.to_sparse_coo()

        """
        return self.to_sparse()

    def dim_order(self):
        """
        dim_order() -> tuple

        Returns a tuple of int describing the dim order or physical layout of :attr:`self`.

        Dim order represents how dimensions are laid out in memory,
        starting from the outermost to the innermost dimension.

        Example::
            >>> torch.empty((2, 3, 5, 7)).dim_order()
            (0, 1, 2, 3)
            >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
            (0, 2, 3, 1)

        .. warning::
            The dim_order tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.dim_order, (self,), self)

        import torch._prims_common as utils

        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))

    def _update_names(self, names, inplace):
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor._update_names, (self,), self, names, inplace
            )

        if inplace:
            return super().rename_(names)
        else:
            return super().rename(names)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        This __torch_function__ implementation wraps subclasses such that
        methods called on subclasses return a subclass instance instead of
        a ``torch.Tensor`` instance.

        One corollary to this is that you need coverage for torch.Tensor
        methods if implementing __torch_function__ for subclasses.

        We recommend always calling ``super().__torch_function__`` as the base
        case when doing the above.

        While not mandatory, we recommend making `__torch_function__` a classmethod.
        """
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with _C.DisableTorchFunctionSubclass():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            else:
                return _convert(ret, cls)

    __torch_dispatch__ = _C._disabled_torch_dispatch_impl
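
    # Usage sketch (illustrative): the default __torch_function__ keeps
    # subclass identity through torch ops via _convert below.
    #
    #   >>> class MyTensor(torch.Tensor):
    #   ...     pass
    #   >>> t = torch.ones(2).as_subclass(MyTensor)
    #   >>> type(t + 1).__name__
    #   'MyTensor'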

    def __dlpack__(self, stream=None):
        """
        Creates a DLpack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
        of the current tensor to be exported to other libraries.

        This function will be called from the `from_dlpack` method
        of the library that will consume the capsule. `from_dlpack` passes the current
        stream to this method as part of the specification.

        Args:
            stream (integer or None): An optional Python integer representing a
                pointer to a CUDA stream. The current stream is synchronized with
                this stream before the capsule is created, and since the capsule
                shares its storage with the tensor this makes it safe to access from
                both streams. If None or -1 is passed then no synchronization is performed.
                If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
                synchronization.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)

        # DLPack capsules can't capture all of PyTorch's semantics, so we
        # prohibit exporting tensors that would lose their properties.
        if self.requires_grad:
            raise RuntimeError(
                "Can't export tensors that require gradient, use tensor.detach()"
            )
        if self.is_conj():
            raise RuntimeError("Can't export tensors with the conjugate bit set")
        if self.layout != torch.strided:
            raise RuntimeError(
                "Can't export tensors with layout other than torch.strided"
            )

        if stream is not None and type(stream) is not int:
            # Stream pointers in CUDA/ROCm are uniquely numbered and can
            # be retrieved from their integer value.
            raise TypeError("stream must be ``int`` or ``none``")
        elif stream is not None and stream != -1:
            if self.device.type == "cuda":
                # NB: This logic handles the special case values for default
                # streams and must be kept in sync with from_dlpack in
                # torch/utils/dlpack.py
                if stream == 1 and torch.version.hip is None:
                    stream = torch.cuda.default_stream()
                elif stream == 0 and torch.version.hip is not None:
                    stream = torch.cuda.default_stream()
                else:
                    stream = torch.cuda.ExternalStream(stream)
                # Only synchronize on different streams
                sync_stream = torch.cuda.current_stream()
                if stream != sync_stream:
                    event = torch.cuda.Event()
                    event.record(sync_stream)
                    stream.wait_event(event)
        return torch.to_dlpack(self)
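
    # Usage sketch (illustrative): zero-copy exchange via the DLPack protocol.
    #
    #   >>> a = torch.arange(4.)
    #   >>> b = torch.from_dlpack(a)   # consumes a.__dlpack__()
    #   >>> b[0] = 7.0
    #   >>> a[0]                       # shares memory with b
    #   tensor(7.)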

    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack_device__, (self,), self)

        from torch.utils.dlpack import DLDeviceType

        device = self.device
        idx = device.index if device.index is not None else 0
        torch_device_type = device.type
        if torch_device_type == "cuda" and torch.version.hip is not None:
            device_type = DLDeviceType.kDLROCM
        elif torch_device_type == "cpu" and self.is_pinned():
            device_type = DLDeviceType.kDLCPUPinned
        elif torch_device_type == "cuda":
            device_type = DLDeviceType.kDLGPU
        elif torch_device_type == "cpu":
            device_type = DLDeviceType.kDLCPU
        elif torch_device_type == "xpu":
            device_type = DLDeviceType.kDLOneAPI
        else:
            raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
        return (device_type, idx)

    __module__ = "torch"


def _convert(ret, cls):
    if cls is Tensor:
        return ret

    if isinstance(ret, Tensor) and not isinstance(ret, cls):
        ret = ret.as_subclass(cls)

    if isinstance(ret, (tuple, list)):
        ret = type(ret)(_convert(r, cls) for r in ret)

    return ret