# pytorch
#
# Fork
# 0
# /
# _tensor.py
# 1606 lines · 62.2 KB
1
# mypy: allow-untyped-defs
2
import copyreg
3
import enum
4
import functools
5
import warnings
6
from collections import OrderedDict
7
from copy import deepcopy
8
from numbers import Number
9
from typing import Any, Dict, Optional, Tuple, Union
10

11
import torch
12
import torch._C as _C
13
from torch._namedtensor_internals import (
14
    check_serializing_named_tensor,
15
    is_ellipsis,
16
    resolve_ellipsis,
17
    single_ellipsis_index,
18
    unzip_namedshape,
19
    update_names,
20
)
21
from torch.overrides import (
22
    get_default_nowrap_functions,
23
    handle_torch_function,
24
    has_torch_function,
25
    has_torch_function_unary,
26
    has_torch_function_variadic,
27
)
28

29

30
def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
31
    assigned = functools.WRAPPER_ASSIGNMENTS
32

33
    @functools.wraps(f, assigned=assigned)
34
    def wrapped(*args, **kwargs):
35
        try:
36
            # See https://github.com/pytorch/pytorch/issues/75462
37
            if has_torch_function(args):
38
                return handle_torch_function(wrapped, args, *args, **kwargs)
39
            return f(*args, **kwargs)
40
        except TypeError:
41
            return NotImplemented
42

43
    return wrapped
44

45

46
# Should not be used, this is kept only for BC of loading old serialized Tensor subclasses
47
def _rebuild_from_type(func, type, args, dict):
48
    if type is Tensor:
49
        return func(*args)
50

51
    ret = func(*args).as_subclass(type)
52
    ret.__dict__ = dict
53
    return ret
54

55

56
def _rebuild_from_type_v2(func, new_type, args, state):
57
    ret = func(*args)
58
    if type(ret) is not new_type:
59
        ret = ret.as_subclass(new_type)
60
    # Tensor does define __setstate__ even though it doesn't define
61
    # __getstate__. So only use __setstate__ if it is NOT the one defined
62
    # on Tensor
63
    if (
64
        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
65
        is not Tensor.__setstate__
66
    ):
67
        ret.__setstate__(state)
68
    else:
69
        ret = torch._utils._set_obj_state(ret, state)
70
    return ret
71

72

73
# NB: If you subclass Tensor, and want to share the subclassed class
74
# across processes, you must also update torch/multiprocessing/reductions.py
75
# to define a ForkingPickler serialization mode for the class.
76
#
77
# NB: If you add a new method to Tensor, you must update
78
# torch/_C/__init__.pyi.in to add a type annotation for your method;
79
# otherwise, it will not show up in autocomplete.
80
class Tensor(torch._C.TensorBase):
81
    def __deepcopy__(self, memo):
        """Return a deep copy of this tensor for ``copy.deepcopy``.

        Only graph leaves can be deep-copied.  The copy is produced in one of
        three ways:

        * ``clone()`` for sparse tensors, tensors on the device types listed
          below, storage-less privateuse1 tensors, and subclasses whose
          ``data_ptr()`` is 0;
        * ``torch._utils._rebuild_qtensor`` over a deep-copied storage for
          quantized tensors;
        * ``new_empty([])`` followed by ``set_()`` onto a deep-copied storage
          for everything else (with conj/neg bits materialized).

        ``requires_grad``, ``.grad``, ``__slots__`` (for subclasses) and
        ``__dict__`` are then copied, and the result is stored in ``memo``.

        Args:
            memo (dict): the ``copy.deepcopy`` memo, keyed by ``id()`` of
                objects that were already copied.

        Raises:
            RuntimeError: if ``self`` is not a leaf, if the quantization
                scheme is unsupported, or if a ``clone()``/``new_empty()``/
                quantized rebuild yields an object of a different type.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
        if not self.is_leaf:
            raise RuntimeError(
                "Only Tensors created explicitly by the user "
                "(graph leaves) support the deepcopy protocol at the moment.  "
                "If you were attempting to deepcopy a module, this may be because "
                "of a torch.nn.utils.weight_norm usage, "
                "see https://github.com/pytorch/pytorch/pull/103001"
            )
        # Already copied during this deepcopy pass: reuse that copy.
        if id(self) in memo:
            return memo[id(self)]
        # All copying below runs under no_grad so it is not recorded by autograd.
        with torch.no_grad():
            # TODO: skipping storage copy is wrong for meta, as meta
            # does accurate alias tracking; however, the code below
            # doesn't work because of
            # https://github.com/pytorch/pytorch/issues/47442
            # Update the test in test_serialization if you remove 'meta' from here
            if (
                self.is_sparse
                or self.device.type
                in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"]
                or (
                    not torch._C._has_storage(self)
                    and self.device.type == torch._C._get_privateuse1_backend_name()
                )
                or (type(self) is not Tensor and self.data_ptr() == 0)
            ):
                # Path 1: no storage to deep-copy (or it cannot be used here),
                # so rely on clone() producing an independent tensor.
                new_tensor = self.clone()
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "The default implementation of __deepcopy__() for wrapper subclasses "
                        "only works for subclass types that implement clone() and for which "
                        "cloning returns another instance of the same subclass. You should either "
                        "properly implement clone() for your subclass or override __deepcopy__() "
                        "if it is intended behavior for clone() to return an instance of a "
                        "different type."
                    )
            else:
                # Deep-copy the storage through the memo so tensors that share
                # a storage keep sharing the copied storage.
                new_storage = self._typed_storage()._deepcopy(memo)
                if self.is_quantized:
                    # Path 2: quantized tensor rebuilt over the copied storage.
                    # quantizer_params can be different type based on torch attribute
                    quantizer_params: Union[
                        Tuple[torch.qscheme, float, int],
                        Tuple[torch.qscheme, Tensor, Tensor, int],
                    ]
                    if self.qscheme() == torch.per_tensor_affine:
                        quantizer_params = (
                            self.qscheme(),
                            self.q_scale(),
                            self.q_zero_point(),
                        )
                    elif self.qscheme() in (
                        torch.per_channel_affine,
                        torch.per_channel_affine_float_qparams,
                    ):
                        quantizer_params = (
                            self.qscheme(),
                            self.q_per_channel_scales(),
                            self.q_per_channel_zero_points(),
                            self.q_per_channel_axis(),
                        )
                    else:
                        raise RuntimeError(
                            f"Unsupported qscheme {self.qscheme()} in deepcopy"
                        )
                    # TODO: Once we decide to break serialization FC, no longer
                    # need to wrap with TypedStorage
                    new_tensor = torch._utils._rebuild_qtensor(
                        torch.storage.TypedStorage(
                            wrap_storage=new_storage._untyped_storage,
                            dtype=self.dtype,
                            _internal=True,
                        ),
                        self.storage_offset(),
                        self.size(),
                        self.stride(),
                        quantizer_params,
                        self.requires_grad,
                        self._backward_hooks,
                    )
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for quantized tensors "
                            "expects the tensor returned by torch._utils._rebuild_qtensor() to "
                            "match the type of the instance being copied. If you encounter this, "
                            "please open an issue on PyTorch's GitHub."
                        )
                else:
                    # Path 3: create an empty tensor of the same (sub)class and
                    # point it at the copied storage with the same geometry.
                    new_tensor = self.new_empty([])
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for non-wrapper subclasses "
                            "only works for subclass types that implement new_empty() and for which "
                            "that function returns another instance of the same subclass. You should "
                            "either properly implement new_empty() for your subclass or override "
                            "__deepcopy__() if it is intended behavior for new_empty() to return "
                            "an instance of a different type."
                        )
                    new_tensor.set_(
                        new_storage, self.storage_offset(), self.size(), self.stride()
                    )
                    # set_() does not carry the lazy conj/neg view bits, so
                    # apply them physically to the copy.
                    if self.is_conj():
                        new_tensor = new_tensor.conj_physical()
                    if self.is_neg():
                        new_tensor = new_tensor.neg()
            if self.requires_grad:
                new_tensor.requires_grad_()
            if self.grad is not None:
                new_tensor.grad = self.grad.__deepcopy__(memo)

            if type(self) is not Tensor:
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "Type of deepcopy result does not match the type of the source tensor. "
                        "If you encounter this, please open an issue on PyTorch's GitHub."
                    )

                # Plain Tensors don't have slots
                slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
                for slot in slots_to_save:
                    if hasattr(self, slot):
                        setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo))

            # NOTE(review): memo[id(self)] is only written after __dict__ (and
            # slots) are deep-copied, so a __dict__ entry that refers back to
            # this tensor would not find it in the memo — confirm whether that
            # can recurse.
            new_tensor.__dict__ = deepcopy(self.__dict__, memo)

            memo[id(self)] = new_tensor
            return new_tensor
210

211
    def __reduce_ex__(self, proto):
        """Pickle support: return the ``(callable, args)`` reduction of this tensor.

        Plain tensors without Python-level state (and FakeTensors being
        materialized under ``skip_data``) use the bare
        ``_reduce_ex_internal`` result.  Everything else is wrapped with
        ``_rebuild_from_type_v2`` so the subclass type and Python state are
        restored on load.
        """
        materialize_fake_tensors = (
            torch.serialization._serialization_tls.materialize_fake_tensors
        )
        state = torch._utils._get_obj_state(self)
        cls = type(self)
        # Ignore all state when using FakeTensor with
        # skip_data(materialize_fake_tensors) because FakeTensor has some
        # state that cannot be pickled.
        is_materialized_fake = (
            cls is torch._subclasses.fake_tensor.FakeTensor
            and materialize_fake_tensors
        )
        if is_materialized_fake or (cls is Tensor and not state):
            # Fast path for regular tensor without Python state.
            return self._reduce_ex_internal(proto)
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
        rebuild_func, rebuild_args = self._reduce_ex_internal(proto)
        return (_rebuild_from_type_v2, (rebuild_func, cls, rebuild_args, state))
228

229
    def storage(self):
        r"""
        storage() -> torch.TypedStorage

        Returns the underlying :class:`TypedStorage` of this tensor.

        .. warning::

            :class:`TypedStorage` is deprecated. It will be removed in the future, and
            :class:`UntypedStorage` will be the only storage class. To access the
            :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
        """
        if not has_torch_function_unary(self):
            # Public entry point: emit the TypedStorage deprecation warning,
            # then delegate to the warning-free internal helper.
            torch.storage._warn_typed_storage_removal(stacklevel=2)
            return self._typed_storage()
        return handle_torch_function(Tensor.storage, (self,), self)
246

247
    # For internal use only, to avoid raising deprecation warning
248
    def _typed_storage(self):
249
        untyped_storage = self.untyped_storage()
250
        return torch.TypedStorage(
251
            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
252
        )
253

254
    def _reduce_ex_internal(self, proto):
        """Build the ``(rebuild_function, args)`` pair used to pickle this tensor.

        Branches are tried in this order: XLA/MTIA/MAIA or storage-less
        privateuse1 tensors (serialized through a NumPy copy), meta tensors,
        quantized tensors, sparse COO, sparse compressed (CSR/CSC/BSR/BSC),
        nested tensors, wrapper subclasses, FakeTensors, and finally regular
        storage-backed tensors.

        Backward hooks are not serialized (see Note [Don't serialize hooks]):
        ``warn_if_has_hooks`` is called, and an empty ``OrderedDict`` is
        stored where hooks would go.

        Args:
            proto: the pickle protocol; not used by this implementation.

        Returns:
            tuple: ``(callable, args)`` such that ``callable(*args)`` rebuilds
            the tensor.

        Raises:
            RuntimeError: for storage-less backends, quantized tensors and
                nested tensors under the ``skip_data`` context manager, and
                for unsupported quantization schemes.
            NotImplementedError: for ``is_sparse`` tensors whose layout is
                not ``torch.sparse_coo``.
        """
        check_serializing_named_tensor(self)

        from torch.utils.hooks import warn_if_has_hooks

        # See Note [Don't serialize hooks]
        warn_if_has_hooks(self)
        backward_hooks: Dict[Any, Any] = OrderedDict()

        skip_data = torch.serialization._serialization_tls.skip_data
        materialize_fake_tensors = (
            torch.serialization._serialization_tls.materialize_fake_tensors
        )

        # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
        # We considered a few options:
        # 1. CPU tensor can't be used here.
        #    Otherwise in torch.load CPU storage is reconstructed with randomly
        #    initialized data, moved onto backend device, and then storage is updated
        #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
        #    their tensors are disconnected with storage so they don't get the update.
        # 2. Python list is not a good fit due to performance reason.
        #    `tolist()` converts every single element in the tensor into python objects
        #    and serialize them one by one.
        if self.device.type in ["xla", "mtia", "maia"] or (
            not torch._C._has_storage(self)
            and self.device.type == torch._C._get_privateuse1_backend_name()
        ):
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
                )
            # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
            # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
            # this would reconstruct the BFloat16 tensor from numpy.
            numpy_tensor = (
                self.cpu().numpy()
                if self.dtype != torch.bfloat16
                else self.cpu().to(torch.float32).numpy()
            )
            return (
                torch._utils._rebuild_device_tensor_from_numpy,
                (numpy_tensor, self.dtype, str(self.device), self.requires_grad),
            )
        if self.device.type == "meta":
            # NB: This implementation BREAKS storage sharing.  Current
            # hypothesis is that no one cares for meta tensors.
            if skip_data:
                warnings.warn(
                    "Serializing tensors on the meta device under skip_data context manager is a no-op"
                )
            arg_meta = (
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.requires_grad,
            )
            return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
        if self.is_quantized:
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
                )
            # quantizer_params can be different type based on torch attribute
            quantizer_params: Union[
                Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
            ]
            if self.qscheme() == torch.per_tensor_affine:
                quantizer_params = (
                    torch.per_tensor_affine,
                    self.q_scale(),
                    self.q_zero_point(),
                )
            elif self.qscheme() in (
                torch.per_channel_affine,
                torch.per_channel_affine_float_qparams,
            ):
                # convert scales and zero points to tuple to avoid recursive calls
                # when/if we get multi-axis quantized tensors in the future, the shape
                # is recoverable from the main tensor shape
                # NOTE(review): the scales/zero points below are passed as tensors,
                # not tuples, so the comment above looks stale. Also, the scheme
                # is recorded as torch.per_channel_affine even for
                # per_channel_affine_float_qparams (__deepcopy__ uses
                # self.qscheme()). Presumably _rebuild_qtensor recovers the
                # float-qparams variant from the zero-point dtype — confirm.
                quantizer_params = (
                    torch.per_channel_affine,
                    self.q_per_channel_scales(),
                    self.q_per_channel_zero_points(),
                    self.q_per_channel_axis(),
                )
            else:
                raise RuntimeError(
                    f"Serialization is not supported for tensors of type {self.qscheme()}"
                )
            # TODO: Once we decide to break serialization FC, no longer
            # need to wrap with TypedStorage
            args_qtensor = (
                torch.storage.TypedStorage(
                    wrap_storage=self._typed_storage()._untyped_storage,
                    dtype=self.dtype,
                    _internal=True,
                ),
                self.storage_offset(),
                tuple(self.size()),
                self.stride(),
                quantizer_params,
                self.requires_grad,
                backward_hooks,
            )
            return (torch._utils._rebuild_qtensor, args_qtensor)
        elif self.is_sparse:
            # Only the COO layout is supported for is_sparse tensors.
            if self.layout == torch.sparse_coo:
                args_sparse = (
                    self.layout,
                    (self._indices(), self._values(), self.size(), self.is_coalesced()),
                )
            else:
                raise NotImplementedError(
                    f"sparse tensor __reduce_ex__ for layout `{self.layout}`"
                )
            return (torch._utils._rebuild_sparse_tensor, args_sparse)
        elif self.layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            # Row-compressed layouts (CSR/BSR) use crow/col indices; the
            # column-compressed ones (CSC/BSC) use ccol/row indices.
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = (
                    self.crow_indices(),
                    self.col_indices(),
                )
            else:
                compressed_indices, plain_indices = (
                    self.ccol_indices(),
                    self.row_indices(),
                )
            args_sparse_compressed = (
                self.layout,
                (
                    compressed_indices,
                    plain_indices,
                    self.values(),
                    self.size(),
                ),
            )
            return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
        elif self.is_nested:
            if skip_data:
                raise RuntimeError(
                    "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
                )
            args_nested = (
                # NB: values() currently returns the storage as a buffer in an unsafe way.
                # Ideally, we'd use a private API for this instead. TODO: Switch to this if
                # we ever get around to adding it.
                self.values(),
                self._nested_tensor_size(),
                self._nested_tensor_strides(),
                self._nested_tensor_storage_offsets(),
            )
            return (torch._utils._rebuild_nested_tensor, args_nested)
        elif (
            type(self) is not torch.Tensor
            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
            and (
                isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
                or (
                    not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
                    and self.data_ptr() == 0
                )
            )
        ):
            # Wrapper subclass with its own __torch_dispatch__ (FunctionalTensor,
            # or a non-Fake subclass with no data pointer): serialize metadata only.
            arg_wrapper_subclass = (
                type(self),
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.storage_offset(),
                self.layout,
                self.device,
                self.requires_grad,
            )
            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
        elif (
            type(self) is not torch.Tensor
            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
            and (
                isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
                and not (skip_data and materialize_fake_tensors)
            )
        ):
            # FakeTensor that is not being materialized: same metadata-only
            # reduction as the wrapper-subclass branch above.
            arg_wrapper_subclass = (
                type(self),
                self.dtype,
                tuple(self.size()),
                self.stride(),
                self.storage_offset(),
                self.layout,
                self.device,
                self.requires_grad,
            )
            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
        else:
            # Regular storage-backed tensor. Dtypes from _new_dtypes() use the
            # v3 format with an UntypedStorage plus an explicit dtype argument.
            v3_dtypes = torch.storage._new_dtypes()
            if self.dtype in v3_dtypes:
                rebuild_func = torch._utils._rebuild_tensor_v3
                storage = self.untyped_storage()
            else:
                # TODO: Once we decide to break serialization FC, no longer
                # need to wrap with TypedStorage
                rebuild_func = torch._utils._rebuild_tensor_v2  # type: ignore[assignment]
                storage = torch.storage.TypedStorage(
                    wrap_storage=self._typed_storage()._untyped_storage,
                    dtype=self.dtype,
                    _internal=True,
                )  # type: ignore[assignment]

            # FakeTensor under skip_data: record the device the fake tensor claims.
            if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data:
                storage._fake_device = self.device

            args = (
                storage,
                self.storage_offset(),
                tuple(self.size()),
                self.stride(),
                self.requires_grad,
                backward_hooks,
            )  # previously was self._backward_hooks

            # The v3 rebuild takes the dtype explicitly (untyped storage).
            if isinstance(storage, torch.storage.UntypedStorage):
                args = args + (self.dtype,)  # type: ignore[assignment]

            metadata = torch._utils.get_tensor_metadata(self)
            if metadata:
                args = args + (metadata,)  # type: ignore[assignment]

            return (rebuild_func, args)
488

489
    def __setstate__(self, state):
490
        if has_torch_function_unary(self):
491
            return handle_torch_function(Tensor.__setstate__, (self,), self, state)
492
        # Warning: this method is NOT called when you torch.load() a tensor;
493
        # that is managed by _rebuild_tensor_v2
494
        if not self.is_leaf:
495
            raise RuntimeError("__setstate__ can be only called on leaf Tensors")
496
        if len(state) == 4:
497
            # legacy serialization of Tensor
498
            self.set_(*state)
499
            return
500
        elif len(state) == 5:
501
            # legacy serialization of Variable
502
            self.data = state[0]
503
            state = (state[3], state[4], state[2])
504
        # The setting of _backward_hooks is expected to be a no-op.
505
        # See Note [Don't serialize hooks]
506
        self.requires_grad, _, self._backward_hooks = state
507

508
    def __repr__(self, *, tensor_contents=None):
509
        if has_torch_function_unary(self):
510
            return handle_torch_function(
511
                Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
512
            )
513
        # All strings are unicode in Python 3.
514
        return torch._tensor_str._str(self, tensor_contents=tensor_contents)
515

516
    def backward(
        self, gradient=None, retain_graph=None, create_graph=False, inputs=None
    ):
        r"""Computes the gradient of current tensor wrt graph leaves.

        The graph is differentiated using the chain rule. If the tensor is
        non-scalar (i.e. its data has more than one element) and requires
        gradient, the function additionally requires specifying a ``gradient``.
        It should be a tensor of matching type and shape, that represents
        the gradient of the differentiated function w.r.t. ``self``.

        This function accumulates gradients in the leaves - you might need to zero
        ``.grad`` attributes or set them to ``None`` before calling it.
        See :ref:`Default gradient layouts<default-grad-layouts>`
        for details on the memory layout of accumulated gradients.

        .. note::

            If you run any forward ops, create ``gradient``, and/or call ``backward``
            in a user-specified CUDA stream context, see
            :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

        .. note::

            When ``inputs`` are provided and a given input is not a leaf,
            the current implementation will call its grad_fn (though it is not strictly
            needed to get these gradients).
            It is an implementation detail on which the user should not rely.
            See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

        Args:
            gradient (Tensor, optional): The gradient of the function
                being differentiated w.r.t. ``self``.
                This argument can be omitted if ``self`` is a scalar.
            retain_graph (bool, optional): If ``False``, the graph used to compute
                the grads will be freed. Note that in nearly all cases setting
                this option to True is not needed and often can be worked around
                in a much more efficient way. Defaults to the value of
                ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative
                products. Defaults to ``False``.
            inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be
                accumulated into ``.grad``. All other tensors will be ignored. If not
                provided, the gradient is accumulated into all the leaf Tensors that were
                used to compute the current tensor.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.backward,
                (self,),
                self,
                gradient=gradient,
                retain_graph=retain_graph,
                create_graph=create_graph,
                inputs=inputs,
            )
        # Delegate to the functional autograd API with ``self`` as the root.
        torch.autograd.backward(
            self, gradient, retain_graph, create_graph, inputs=inputs
        )
575

576
    def register_hook(self, hook):
        r"""Registers a backward hook.

        The hook will be called every time a gradient with respect to the
        Tensor is computed. The hook should have the following signature::

            hook(grad) -> Tensor or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the tensor.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v.grad
            tensor([2., 4., 6.])

            >>> h.remove()  # removes the hook
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
            )
        if self._backward_hooks is None:
            # First hook on this tensor: create the hook dict and, when the
            # tensor has a grad_fn, register the dict with it.
            self._backward_hooks = OrderedDict()
            if self.grad_fn is not None:
                self.grad_fn._register_hook_dict(self)

        from torch.utils.hooks import RemovableHandle

        handle = RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle
625

626
    def register_post_accumulate_grad_hook(self, hook):
627
        r"""Registers a backward hook that runs after grad accumulation.
628

629
        The hook will be called after all gradients for a tensor have been accumulated,
630
        meaning that the .grad field has been updated on that tensor. The post
631
        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
632
        .grad_fn field). Registering this hook on a non-leaf tensor will error!
633

634
        The hook should have the following signature::
635

636
            hook(param: Tensor) -> None
637

638
        Note that, unlike other autograd hooks, this hook operates on the tensor
639
        that requires grad and not the grad itself. The hook can in-place modify
640
        and access its Tensor argument, including its .grad field.
641

642
        This function returns a handle with a method ``handle.remove()``
643
        that removes the hook from the module.
644

645
        .. note::
646
            See :ref:`backward-hooks-execution` for more information on how when this hook
647
            is executed, and how its execution is ordered relative to other hooks. Since
648
            this hook runs during the backward pass, it will run in no_grad mode (unless
649
            create_graph is True). You can use torch.enable_grad() to re-enable autograd
650
            within the hook if you need it.
651

652
        Example::
653

654
            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
655
            >>> lr = 0.01
656
            >>> # simulate a simple SGD update
657
            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
658
            >>> v.backward(torch.tensor([1., 2., 3.]))
659
            >>> v
660
            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
661

662
            >>> h.remove()  # removes the hook
663
        """
664
        if has_torch_function_unary(self):
665
            return handle_torch_function(
666
                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
667
            )
668
        if not self.requires_grad:
669
            raise RuntimeError(
670
                "cannot register a hook on a tensor that doesn't require gradient"
671
            )
672
        if self.grad_fn is not None:
673
            raise RuntimeError(
674
                "post accumulate grad hooks cannot be registered on non-leaf tensors"
675
            )
676
        if self._post_accumulate_grad_hooks is None:
677
            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
678

679
        from torch.utils.hooks import RemovableHandle
680

681
        handle = RemovableHandle(self._post_accumulate_grad_hooks)
682
        self._post_accumulate_grad_hooks[handle.id] = hook
683
        return handle
684

685
    def reinforce(self, reward):
        r"""Removed API; always raises.

        ``reinforce()`` was removed in favor of :mod:`torch.distributions`.
        Calling it raises a :class:`RuntimeError` whose message shows how to
        port an old REINFORCE-style training loop.

        Raises:
            RuntimeError: always.
        """
        migration_text = r"""reinforce() was removed.
            Use torch.distributions instead.
            See https://pytorch.org/docs/main/distributions.html

            Instead of:

            probs = policy_network(state)
            action = probs.multinomial()
            next_state, reward = env.step(action)
            action.reinforce(reward)
            action.backward()

            Use:

            probs = policy_network(state)
            # NOTE: categorical is equivalent to what used to be called multinomial
            m = torch.distributions.Categorical(probs)
            action = m.sample()
            next_state, reward = env.step(action)
            loss = -m.log_prob(action) * reward
            loss.backward()
        """
        # Drop the source-code indentation from every line of the message.
        dedented = "\n".join(part.strip() for part in migration_text.split("\n"))
        raise RuntimeError(dedented)
715

716
    # Re-export the C++ ``TensorBase.detach`` as ``Tensor.detach``;
    # ``_add_docstr`` attaches the text below as its Python docstring, so the
    # literal is user-facing documentation and must contain only doc text.
    detach = _C._add_docstr(
        _C.TensorBase.detach,
        r"""
    Returns a new Tensor, detached from the current graph.

    The result will never require gradient.

    This method also affects forward mode AD gradients and the result will never
    have forward mode AD gradients.

    .. note::

      Returned Tensor shares the same storage with the original one.
      In-place modifications on either of them will be seen, and may trigger
      errors in correctness checks.
    """,
    )
733

734
    # Re-export the C++ in-place ``TensorBase.detach_`` with its Python
    # docstring attached via ``_add_docstr``; the literal is user-facing docs.
    detach_ = _C._add_docstr(
        _C.TensorBase.detach_,
        r"""
    Detaches the Tensor from the graph that created it, making it a leaf.
    Views cannot be detached in-place.

    This method also affects forward mode AD gradients and the result will never
    have forward mode AD gradients.
    """,
    )
744

745
    def is_shared(self):
        r"""Checks if tensor is in shared memory.

        The answer comes from the tensor's typed storage. This is always
        ``True`` for CUDA tensors.
        """
        if not has_torch_function_unary(self):
            return self._typed_storage()._is_shared()
        return handle_torch_function(Tensor.is_shared, (self,), self)
753

754
    def share_memory_(self):
        r"""Moves the underlying storage to shared memory.

        This is a no-op if the underlying storage is already in shared memory
        and for CUDA tensors. Tensors in shared memory cannot be resized.

        See :meth:`torch.UntypedStorage.share_memory_` for more details.

        Returns:
            ``self`` (the operation is in-place), to allow chaining.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.share_memory_, (self,), self)
        storage = self._typed_storage()
        storage._share_memory_()
        return self
766

767
    def module_load(self, other, assign=False):
        r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.

        Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        ``self`` is expected to be a parameter or buffer of an ``nn.Module`` and
        ``other`` the state-dict value stored under the same key. The returned
        tensor is what gets swapped with ``self`` via
        :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.

        .. note::
            This method should always return a new object that is not ``self`` or ``other``.
            The default implementation returns ``self.copy_(other).detach()``
            when ``assign`` is ``False`` and ``other.detach()`` when it is ``True``.

        Args:
            other (Tensor): value in state dict with key corresponding to ``self``
            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
        """
        if has_torch_function_variadic(self, other):
            return handle_torch_function(
                Tensor.module_load, (self, other), self, other, assign=assign
            )
        # assign=True keeps ``other``'s data as-is; otherwise copy it into
        # ``self`` first. Either way detach() yields a fresh tensor object.
        source = other if assign else self.copy_(other)
        return source.detach()
796

797
    def __reversed__(self):
        r"""Reverses the tensor along dimension 0.

        A 0-d tensor has no dimension to flip and is returned unchanged.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__reversed__, (self,), self)
        return self if self.dim() == 0 else self.flip(0)
805

806
    def norm(
        self,
        p: Optional[Union[float, str]] = "fro",
        dim=None,
        keepdim=False,
        dtype=None,
    ):
        r"""See :func:`torch.norm`"""
        options = dict(p=p, dim=dim, keepdim=keepdim, dtype=dtype)
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.norm, (self,), self, **options)
        return torch.norm(self, **options)
819

820
    def solve(self, other):
        r"""Removed API kept for a helpful error; use :func:`torch.linalg.solve`.

        Delegates to :func:`torch._linalg_utils.solve`, which raises.
        """
        from torch._linalg_utils import solve as _removed_solve

        return _removed_solve(self, other)
824

825
    def lstsq(self, other):
        r"""Removed API kept for a helpful error; use :func:`torch.linalg.lstsq`.

        Delegates to :func:`torch._linalg_utils.lstsq`, which raises.
        """
        from torch._linalg_utils import lstsq as _removed_lstsq

        return _removed_lstsq(self, other)
829

830
    def eig(self, eigenvectors=False):
        r"""Removed API kept for a helpful error; use :func:`torch.linalg.eig`.

        Delegates to :func:`torch._linalg_utils.eig`, which raises.
        """
        from torch._linalg_utils import eig as _removed_eig

        return _removed_eig(self, eigenvectors=eigenvectors)
834

835
    def symeig(self, eigenvectors=False):
        r"""Removed API kept for a helpful error; use :func:`torch.linalg.eigh`.

        Delegates to :func:`torch._linalg_utils._symeig`, which raises.
        """
        from torch._linalg_utils import _symeig as _removed_symeig

        return _removed_symeig(self, eigenvectors=eigenvectors)
839

840
    def lu(self, pivot=True, get_infos=False):
        r"""See :func:`torch.lu`

        Returns ``(LU, pivots)``, or ``(LU, pivots, infos)`` when ``get_infos``
        is ``True``.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
            )
        # When the caller asks for the info tensor they inspect failures
        # themselves, so error checking is only enabled otherwise.
        LU, pivots, infos = torch._lu_with_info(
            self, pivot=pivot, check_errors=not get_infos
        )
        return (LU, pivots, infos) if get_infos else (LU, pivots)
855

856
    def stft(
        self,
        n_fft: int,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window: "Optional[Tensor]" = None,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: Optional[bool] = None,
        return_complex: Optional[bool] = None,
    ):
        r"""See :func:`torch.stft`

        .. warning::
          This function changed signature at version 0.4.1. Calling with
          the previous signature may cause error or return incorrect result.
        """
        # Every option after n_fft is forwarded by keyword, identically to
        # both the __torch_function__ override and torch.stft.
        options = dict(
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=center,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
            return_complex=return_complex,
        )
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.stft, (self,), self, n_fft, **options)
        return torch.stft(self, n_fft, **options)
901

902
    def istft(
        self,
        n_fft: int,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window: "Optional[Tensor]" = None,
        center: bool = True,
        normalized: bool = False,
        onesided: Optional[bool] = None,
        length: Optional[int] = None,
        return_complex: bool = False,
    ):
        r"""See :func:`torch.istft`"""
        # Every option after n_fft is forwarded by keyword, identically to
        # both the __torch_function__ override and torch.istft.
        options = dict(
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
            length=length,
            return_complex=return_complex,
        )
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.istft, (self,), self, n_fft, **options)
        return torch.istft(self, n_fft, **options)
942

943
    def resize(self, *sizes):
        r"""Deprecated out-of-place resize.

        Returns ``torch.autograd._functions.Resize.apply(self, sizes)``, an
        autograd-aware result with shape ``sizes``. Prefer
        :meth:`Tensor.reshape` / :meth:`Tensor.view`, or the in-place
        :meth:`Tensor.resize_`.

        Warns:
            UserWarning: on every call, because non-inplace resize is deprecated.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.resize, (self,), self, *sizes)
        # stacklevel=2 attributes the deprecation to the caller's line rather
        # than to this file, so users can find the call to update.
        warnings.warn("non-inplace resize is deprecated", stacklevel=2)
        from torch.autograd._functions import Resize

        return Resize.apply(self, sizes)
950

951
    def resize_as(self, tensor):
        r"""Deprecated out-of-place resize to the shape of ``tensor``.

        Returns ``torch.autograd._functions.Resize.apply(self, tensor.size())``.
        Prefer :meth:`Tensor.reshape_as` / :meth:`Tensor.view_as`, or the
        in-place :meth:`Tensor.resize_as_`.

        Warns:
            UserWarning: on every call, because non-inplace resize_as is deprecated.
        """
        if has_torch_function_variadic(self, tensor):
            return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
        # stacklevel=2 attributes the deprecation to the caller's line.
        warnings.warn("non-inplace resize_as is deprecated", stacklevel=2)
        from torch.autograd._functions import Resize

        return Resize.apply(self, tensor.size())
958

959
    def split(self, split_size, dim=0):
        r"""See :func:`torch.split`

        ``split_size`` may be an int / SymInt (chunk size), a sequence of
        sizes, or a tensor; a tensor that ``int()`` can convert is used as a
        chunk size, any other tensor is passed through as a list of sizes.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.split, (self,), self, split_size, dim=dim
            )
        if isinstance(split_size, Tensor):
            try:
                split_size = int(split_size)
            except ValueError:
                pass  # not convertible to a single int; treat as sizes below
        if not isinstance(split_size, (int, torch.SymInt)):
            return torch._VF.split_with_sizes(self, split_size, dim)
        return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
975

976
    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
        r"""Returns the unique elements of the input tensor.

        See :func:`torch.unique`
        """
        # The same keyword options go to the override and to torch.unique.
        options = dict(
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.unique, (self,), self, **options)
        return torch.unique(self, **options)
998

999
    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
        r"""Eliminates all but the first element from every consecutive group of equivalent elements.

        See :func:`torch.unique_consecutive`
        """
        # The same keyword options go to the override and to torch.unique_consecutive.
        options = dict(return_inverse=return_inverse, return_counts=return_counts, dim=dim)
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.unique_consecutive, (self,), self, **options
            )
        return torch.unique_consecutive(self, **options)
1016

1017
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1018
    def __rsub__(self, other):
1019
        return _C._VariableFunctions.rsub(self, other)
1020

1021
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1022
    def __rdiv__(self, other):
1023
        return self.reciprocal() * other
1024

1025
    __rtruediv__ = __rdiv__
1026
    __itruediv__ = _C.TensorBase.__idiv__
1027

1028
    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
1029
        _C.TensorBase.pow
1030
    )
1031
    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
1032
        _C.TensorBase.pow_
1033
    )
1034

1035
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1036
    def __rmod__(self, other):
1037
        return torch.remainder(other, self)
1038

1039
    def __format__(self, format_spec):
        """Format a plain 0-d tensor like its Python scalar; otherwise defer to ``object``.

        Only exact ``Tensor`` instances (not subclasses) that are 0-d and not
        on the meta device are formatted through ``self.item()``.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
        is_plain_scalar = self.dim() == 0 and not self.is_meta and type(self) is Tensor
        if not is_plain_scalar:
            return object.__format__(self, format_spec)
        return self.item().__format__(format_spec)
1045

1046
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1047
    def __rpow__(self, other):
1048
        return torch.pow(other, self)
1049

1050
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1051
    def __floordiv__(self, other):
1052
        return torch.floor_divide(self, other)
1053

1054
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1055
    def __rfloordiv__(self, other):
1056
        return torch.floor_divide(other, self)
1057

1058
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1059
    def __rlshift__(self, other):
1060
        return torch.bitwise_left_shift(other, self)
1061

1062
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1063
    def __rrshift__(self, other):
1064
        return torch.bitwise_right_shift(other, self)
1065

1066
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1067
    def __rmatmul__(self, other):
1068
        return torch.matmul(other, self)
1069

1070
    __pos__ = _C.TensorBase.positive
1071
    __neg__ = _C.TensorBase.neg
1072
    __abs__ = _C.TensorBase.abs
1073

1074
    def __len__(self):
        """Return the size of dimension 0.

        Raises:
            TypeError: for a 0-d tensor, which has no dimension 0.

        Under JIT tracing a TracerWarning is emitted, because the traced
        program bakes in the length rather than reading it from the input.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__len__, (self,), self)
        if self.dim() == 0:
            raise TypeError("len() of a 0-d tensor")
        tracing = torch._C._get_tracing_state()
        if tracing:
            message = (
                "Using len to get tensor shape might cause the trace to be incorrect. "
                "Recommended usage would be tensor.shape[0]. "
                "Passing a tensor of different shape might lead to errors or silently give "
                "incorrect results."
            )
            warnings.warn(message, category=torch.jit.TracerWarning, stacklevel=2)
        return self.shape[0]
1089

1090
    def __iter__(self):
        """Iterate over the slices of the tensor along dimension 0.

        Raises:
            TypeError: for a 0-d tensor.
        """
        # NB: __torch_function__ dispatch is intentionally skipped here.
        # See gh-54457
        if self.dim() == 0:
            raise TypeError("iteration over a 0-d tensor")
        tracing = torch._C._get_tracing_state()
        if tracing:
            message = (
                "Iterating over a tensor might cause the trace to be incorrect. "
                "Passing a tensor of different shape won't change the number of "
                "iterations executed (and might lead to errors or silently give "
                "incorrect results)."
            )
            warnings.warn(message, category=torch.jit.TracerWarning, stacklevel=2)
        # unbind(0) produces every slice up front; iter() then yields them in
        # order.
        slices = self.unbind(0)
        return iter(slices)
1111

1112
    def __hash__(self):
        """Hash by object identity (``id(self)``), not by tensor contents."""
        # __torch_function__ is deliberately NOT dispatched here: a user's
        # catch-all override would most likely get hashing wrong. Subclasses
        # that need different hashing can simply override this method.
        return id(self)

    def __dir__(self):
        """List class attributes plus instance ``__dict__`` keys, sorted.

        Hides the deprecated ``volatile`` attribute, and hides
        ``__cuda_array_interface__`` unless the tensor is a dense CUDA tensor.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dir__, (self,), self)
        names = dir(self.__class__)
        names.remove("volatile")  # deprecated
        names.extend(self.__dict__.keys())
        # The property is only available on dense CUDA tensors.
        if self.is_sparse or not self.is_cuda:
            names.remove("__cuda_array_interface__")
        return sorted(names)
1132

1133
    # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
    __array_priority__ = 1000  # prefer Tensor ops over numpy ones

    def __array__(self, dtype=None, copy=None):
        r"""NumPy array-interface hook used by ``numpy.asarray(tensor)``.

        Args:
            dtype: optional NumPy dtype for the result; ``None`` keeps the
                dtype produced by :meth:`Tensor.numpy`.
            copy: NumPy 2 copy request. ``None`` (default) copies only when a
                dtype conversion needs it; ``True`` always returns a fresh
                array; ``False`` never copies and raises if a copy would be
                required. NumPy 1.x does not pass this argument.

        Returns:
            A ``numpy.ndarray``. Without a dtype conversion (and unless
            ``copy=True``) it is the array from :meth:`Tensor.numpy`, which
            shares memory with the tensor.

        Raises:
            ValueError: if ``copy=False`` but converting to ``dtype`` needs a copy.
        """
        if has_torch_function_unary(self):
            # Forward ``copy`` only when NumPy supplied it, so overrides
            # written against the old ``(self, dtype)`` signature keep working.
            extra = {} if copy is None else {"copy": copy}
            return handle_torch_function(
                Tensor.__array__, (self,), self, dtype=dtype, **extra
            )
        base = self.numpy()
        # astype(copy=False) returns ``base`` itself when no conversion is needed.
        result = base if dtype is None else base.astype(dtype, copy=False)
        if copy is False and result is not base:
            raise ValueError(
                "Unable to avoid a copy while converting a tensor of dtype "
                f"{self.dtype} to a NumPy array of dtype {dtype}."
            )
        if copy and result is base:
            result = result.copy()
        return result
1143

1144
    # Wrap Numpy array again in a suitable tensor when done, to support e.g.
    # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
    def __array_wrap__(self, array, context=None, return_scalar=False):
        r"""Convert the ndarray result of a NumPy ufunc back into a tensor.

        NumPy 2 calls this hook as ``__array_wrap__(array, context,
        return_scalar)`` and deprecates implementations that do not accept
        all three arguments. The extra parameters are optional, so
        single-argument (NumPy 1.x style) calls still work.

        Args:
            array: the ``numpy.ndarray`` produced by NumPy.
            context: ufunc call context supplied by NumPy; unused.
            return_scalar: NumPy 2 hint that a scalar result is expected;
                unused, and a (possibly 0-d) tensor is always returned.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.__array_wrap__, (self,), self, array=array
            )
        if array.dtype == bool:
            # Legacy behavior kept for backward compatibility: boolean results
            # come back as uint8 tensors (this predates ``torch.bool``).
            array = array.astype("uint8")
        return torch.from_numpy(array)
1155

1156
    def __contains__(self, element: Any, /) -> bool:
        r"""Check if ``element`` is present in the tensor.

        Args:
            element (Tensor or scalar): value to look for; it is compared via
                ``element == self`` and the result reduced with ``any()``.

        Raises:
            RuntimeError: if ``element`` is neither a Tensor nor a scalar
                (Number, SymInt, SymFloat or SymBool).
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__contains__, (self,), self, element)
        supported = (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
        if not isinstance(element, supported):
            raise RuntimeError(
                f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
            )
        # type hint doesn't understand the __contains__ result array
        return bool((element == self).any().item())  # type: ignore[union-attr]
1174

1175
    @property
    def __cuda_array_interface__(self):
        """Array view description for CUDA tensors (CUDA Array Interface, version 2).

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html

        The returned dict exposes the tensor's device pointer directly, so a
        consumer sees the tensor's memory rather than a copy.

        Raises:
            AttributeError: for non-CUDA or sparse tensors, so that
                ``hasattr(t, "__cuda_array_interface__")`` is ``False`` for them.
            RuntimeError: if the tensor requires grad, matching ``__array__``.
            KeyError: if the dtype has no entry in the typestr table below.
        """
        if has_torch_function_unary(self):
            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
            return handle_torch_function(
                Tensor.__cuda_array_interface__.__get__,  # type: ignore[attr-defined]
                (self,),
                self,
            )

        if not self.is_cuda:
            raise AttributeError(
                f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} "
                "If CUDA data is required use tensor.cuda() to copy tensor to device memory."
            )
        if self.is_sparse:
            raise AttributeError(
                f"Can't get __cuda_array_interface__ on sparse type: {self.type()} "
                "Use Tensor.to_dense() to convert to a dense tensor first."
            )
        if self.requires_grad:
            raise RuntimeError(
                "Can't get __cuda_array_interface__ on Variable that requires grad. "
                "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
            )

        # CUDA devices are little-endian and tensors are stored in native byte
        # order. 1-byte entries are endian-agnostic.
        # NOTE(review): bfloat16 is advertised as "<f2" (IEEE half); consumers
        # will interpret the bits as float16, not bfloat16 -- confirm intended.
        dtype_to_typestr = {
            torch.complex64: "<c8",
            torch.complex128: "<c16",
            torch.bfloat16: "<f2",
            torch.float16: "<f2",
            torch.float32: "<f4",
            torch.float64: "<f8",
            torch.uint8: "|u1",
            torch.int8: "|i1",
            torch.uint16: "<u2",
            torch.int16: "<i2",
            torch.uint32: "<u4",
            torch.int32: "<i4",
            torch.uint64: "<u8",
            torch.int64: "<i8",
            torch.bool: "|b1",
        }
        typestr = dtype_to_typestr[self.dtype]
        itemsize = self.element_size()

        # Version 2 requires strides to be omitted (None) for C-contiguous
        # arrays; otherwise they are given in bytes.
        strides = (
            None
            if self.is_contiguous()
            else tuple(stride * itemsize for stride in self.stride())
        )
        data_ptr = 0 if self.numel() == 0 else self.data_ptr()

        return {
            "typestr": typestr,
            "shape": tuple(self.shape),
            "strides": strides,
            "data": (data_ptr, False),  # read-only is false
            "version": 2,
        }
1244

1245
    def storage_type(self):
        r"""storage_type() -> type

        Returns the legacy storage class (e.g. ``torch.FloatStorage``) that
        corresponds to this tensor's typed storage. Calls
        ``torch.storage._warn_typed_storage_removal()`` first, which may emit
        a deprecation warning about TypedStorage.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.storage_type, (self,), self)
        torch.storage._warn_typed_storage_removal()
        typed = self._typed_storage()
        return typed._get_legacy_storage_class()
1257

1258
    def refine_names(self, *names):
        r"""Refines the dimension names of :attr:`self` according to :attr:`names`.

        Refining is a special case of renaming that "lifts" unnamed dimensions.
        A ``None`` dim can be refined to have any name; a named dim can only be
        refined to have the same name.

        Because named tensors can coexist with unnamed tensors, refining names
        lets you write named-tensor-aware code that works with both named and
        unnamed tensors.

        :attr:`names` may contain up to one Ellipsis (``...``), which is
        expanded greedily, in place, so that :attr:`names` has ``self.dim()``
        entries, taking the names at the corresponding indices of
        ``self.names``. The string literal ``'...'`` is accepted as well.

        Args:
            names (iterable of str): The desired names of the output tensor. May
                contain up to one Ellipsis.

        Examples::

            >>> imgs = torch.randn(32, 3, 128, 128)
            >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
            >>> named_imgs.names
            ('N', 'C', 'H', 'W')

            >>> tensor = torch.randn(2, 3, 5, 7, 11)
            >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
            >>> tensor.names
            ('A', None, None, 'B', 'C')

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.refine_names, (self,), self, *names)
        expanded = resolve_ellipsis(names, self.names, "refine_names")
        return super().refine_names(expanded)
1301

1302
    def align_to(self, *names):
        r"""Permutes the dimensions of the :attr:`self` tensor to match the order
        specified in :attr:`names`, adding size-one dims for any new names.

        All of the dims of :attr:`self` must be named in order to use this method.
        The resulting tensor is a view on the original tensor.

        All dimension names of :attr:`self` must be present in :attr:`names`.
        :attr:`names` may contain additional names that are not in ``self.names``;
        the output tensor has a size-one dimension for each of those new names.

        :attr:`names` may contain up to one Ellipsis (``...``), which expands to
        every dimension name of :attr:`self` not mentioned in :attr:`names`, in
        the order they appear in :attr:`self`. The string literal ``'...'`` is
        accepted as well.

        Args:
            names (iterable of str): The desired dimension ordering of the
                output tensor. May contain up to one Ellipsis that is expanded
                to all unmentioned dim names of :attr:`self`.

        Examples::

            >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
            >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
            >>> # Move the F and E dims to the front while keeping the rest in order
            >>> named_tensor.align_to('F', 'E', ...).names
            ('F', 'E', 'A', 'B', 'C', 'D')

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.align_to, (self,), self, *names)
        ellipsis_idx = single_ellipsis_index(names, "align_to")
        if ellipsis_idx is not None:
            # Pass the explicit names plus the Ellipsis position to the C++ side.
            explicit = [name for name in names if not is_ellipsis(name)]
            return super().align_to(explicit, ellipsis_idx)
        return super().align_to(names)
1346

1347
    def unflatten(self, dim, sizes):
        r"""
        unflatten(dim, sizes) -> Tensor

        See :func:`torch.unflatten`.

        ``sizes`` is either a sequence of ints or a named shape: an
        ``OrderedDict`` or a tuple/list of ``(name, size)`` pairs, which is
        split into names and sizes via ``unzip_namedshape``.

        Raises:
            RuntimeError: if ``sizes`` is empty.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)

        if not sizes:
            raise RuntimeError("unflatten: sizes must be non-empty")

        is_named_shape = isinstance(sizes, OrderedDict) or (
            isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
        )
        if not is_named_shape:
            return super().unflatten(dim, sizes)
        names, sizes = unzip_namedshape(sizes)
        return super().unflatten(dim, sizes, names)
1368

1369
    def rename_(self, *names, **rename_map):
        """In-place version of :meth:`~Tensor.rename`."""
        # Note [rename_ / rename API]
        # The Python API for these is different from the C++ API. In Python:
        # 1) tensor.rename(*names) takes a vararglist of names
        # 2) tensor.rename(**rename_map) takes a map of names to rename.
        # C++ is static, making it difficult to implement similar behavior.
        if not has_torch_function_unary(self):
            return update_names(self, names, rename_map, inplace=True)
        return handle_torch_function(
            Tensor.rename_, (self,), self, *names, **rename_map
        )
1383

1384
    def rename(self, *names, **rename_map):
        """Renames dimension names of :attr:`self`.

        There are two main usages:

        ``self.rename(**rename_map)`` returns a view on tensor that has dims
        renamed as specified in the mapping :attr:`rename_map`.

        ``self.rename(*names)`` returns a view on tensor, renaming all
        dimensions positionally using :attr:`names`.
        Use ``self.rename(None)`` to drop names on a tensor.

        Positional args :attr:`names` and keyword args :attr:`rename_map`
        cannot be combined in one call.

        Examples::

            >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
            >>> renamed_imgs = imgs.rename(N='batch', C='channels')
            >>> renamed_imgs.names
            ('batch', 'channels', 'H', 'W')

            >>> renamed_imgs = imgs.rename(None)
            >>> renamed_imgs.names
            (None, None, None, None)

            >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
            >>> renamed_imgs.names
            ('batch', 'channel', 'height', 'width')

        .. warning::
            The named tensor API is experimental and subject to change.

        """
        if not has_torch_function_unary(self):
            # See Note [rename_ / rename API]
            return update_names(self, names, rename_map, inplace=False)
        return handle_torch_function(
            Tensor.rename, (self,), self, *names, **rename_map
        )
1425

1426
    def to_sparse_coo(self):
        """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`.

        Equivalent to calling :meth:`Tensor.to_sparse` with no arguments.

        Examples::

             >>> dense = torch.eye(3)
             >>> sparse = dense.to_sparse_coo()
             >>> sparse._nnz()
             3

        """
        coo = self.to_sparse()
        return coo
1438

1439
    def dim_order(self):
1440
        """
1441

1442
        dim_order() -> tuple
1443

1444
        Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
1445

1446
        Args:
1447
            None
1448

1449
        Dim order represents how dimensions are laid out in memory,
1450
        starting from the outermost to the innermost dimension.
1451

1452
        Example::
1453
            >>> torch.empty((2, 3, 5, 7)).dim_order()
1454
            (0, 1, 2, 3)
1455
            >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
1456
            (0, 2, 3, 1)
1457

1458
        .. warning::
1459
            The dim_order tensor API is experimental and subject to change.
1460

1461
        """
1462
        if has_torch_function_unary(self):
1463
            return handle_torch_function(Tensor.dim_order, (self,), self)
1464

1465
        import torch._prims_common as utils
1466

1467
        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
1468

1469
    def _update_names(self, names, inplace):
        # Forward `names` to the parent class's ``rename_`` when `inplace` is
        # truthy, otherwise to its ``rename``, and return that call's result.
        # Overriding subclasses are routed through __torch_function__ first.
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor._update_names, (self,), self, names, inplace
            )

        # See Note [rename_ / rename API]
        parent = super()
        renamer = parent.rename_ if inplace else parent.rename
        return renamer(names)
1480

1481
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        Wraps the results of torch functions so that calls made on a subclass
        hand back instances of that subclass rather than plain
        ``torch.Tensor`` instances.

        As a consequence, a subclass that implements ``__torch_function__``
        needs to provide coverage for ``torch.Tensor`` methods as well.

        Such overrides should use ``super().__torch_function__`` as their base
        case.

        Making ``__torch_function__`` a classmethod is not mandatory, but it is
        recommended.
        """
        kwargs = {} if kwargs is None else kwargs

        # Defer (by returning NotImplemented) as soon as one of the
        # participating types is not a superclass of ``cls``.
        if any(not issubclass(cls, t) for t in types):
            return NotImplemented

        with _C.DisableTorchFunctionSubclass():
            result = func(*args, **kwargs)
            # Functions in the default "nowrap" set return their raw result;
            # everything else is re-wrapped as ``cls``.
            if func not in get_default_nowrap_functions():
                result = _convert(result, cls)
            return result
1508

1509
    # Binds the class-level ``__torch_dispatch__`` hook to the C++-provided
    # ``_C._disabled_torch_dispatch_impl``. Judging by the name, this marks the
    # base class as having no Python dispatch handler, and subclasses that want
    # one presumably override this attribute. NOTE(review): confirm against the
    # torch/_C implementation.
    __torch_dispatch__ = _C._disabled_torch_dispatch_impl
1510

1511
    def __dlpack__(self, stream=None):
        """
        Creates a DLpack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
        of the current tensor to be exported to other libraries.

        This function will be called from the `from_dlpack` method
        of the library that will consume the capsule. `from_dlpack` passes the current
        stream to this method as part of the specification.

        Args:
            stream (integer or None): An optional Python integer representing a
            pointer to a CUDA stream. The current stream is synchronized with
            this stream before the capsule is created, and since the capsule
            shares its storage with the tensor this makes it safe to access from
            both streams.  If None or -1 is passed then no synchronization is performed.
            If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
            synchronization.

        Raises:
            RuntimeError: if the tensor requires grad, has the conjugate bit
                set, or does not use the ``torch.strided`` layout.
            TypeError: if ``stream`` is neither ``None`` nor an ``int``.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)

        # DLPack capsules can't capture all of PyTorch's semantics,
        # so we prohibit exporting tensors that would lose their properties like
        # requires_grad and having the conjugate bit set.
        if self.requires_grad:
            raise RuntimeError(
                "Can't export tensors that require gradient, use tensor.detach()"
            )
        if self.is_conj():
            raise RuntimeError("Can't export tensors with the conjugate bit set")
        if self.layout != torch.strided:
            raise RuntimeError(
                "Can't export tensors with layout other than torch.strided"
            )

        if stream is not None and type(stream) is not int:
            # Stream pointers in CUDA/ROCm are uniquely numbered and can
            # be retrieved from their integer value.
            raise TypeError("stream must be ``int`` or ``none``")
        elif stream is not None and stream != -1:
            if self.device.type == "cuda":
                # NB: This logic handles the special case values for default
                # streams and must be kept in sync with from_dlpack in
                # torch/utils/dlpack.py
                #
                # Every stream lookup is pinned to the tensor's own device. The
                # consumer's stream pointer refers to a stream on that device,
                # and the work that produced this tensor is presumably on that
                # device's current stream. Querying the *current* device instead
                # would synchronize unrelated streams whenever the two differ
                # (e.g. a tensor on cuda:1 exported while cuda:0 is current).
                device = self.device
                if stream == 1 and torch.version.hip is None:
                    stream = torch.cuda.default_stream(device)
                elif stream == 0 and torch.version.hip is not None:
                    stream = torch.cuda.default_stream(device)
                else:
                    stream = torch.cuda.ExternalStream(stream, device=device)
                # Only synchronize on different streams
                sync_stream = torch.cuda.current_stream(device)
                if stream != sync_stream:
                    event = torch.cuda.Event()
                    event.record(sync_stream)
                    stream.wait_event(event)
        return torch.to_dlpack(self)
1568

1569
    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
        """Return the ``(device_type, device_index)`` pair for :attr:`self`,
        using the ``DLDeviceType`` codes from ``torch.utils.dlpack``.

        A missing device index is reported as ``0``.

        Raises:
            ValueError: if the tensor's device type has no mapping here.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack_device__, (self,), self)

        from torch.utils.dlpack import DLDeviceType

        dev = self.device
        index = 0 if dev.index is None else dev.index
        kind = dev.type

        if kind == "cuda":
            # On HIP builds, "cuda" devices are reported as ROCm devices.
            on_rocm = torch.version.hip is not None
            code = DLDeviceType.kDLROCM if on_rocm else DLDeviceType.kDLGPU
        elif kind == "cpu":
            # Pinned host memory has its own DLPack device code.
            code = (
                DLDeviceType.kDLCPUPinned if self.is_pinned() else DLDeviceType.kDLCPU
            )
        elif kind == "xpu":
            code = DLDeviceType.kDLOneAPI
        else:
            raise ValueError(f"Unknown device type {kind} for Dlpack")
        return (code, index)
1591

1592
    # Overrides the class's reported module so that it appears as
    # ``torch.Tensor`` (e.g. in ``repr(Tensor)``) rather than under the module
    # that defines it. NOTE(review): pickling by reference also uses
    # ``__module__``; confirm before changing.
    __module__ = "torch"
1593

1594

1595
def _convert(ret, cls):
1596
    if cls is Tensor:
1597
        return ret
1598

1599
    if isinstance(ret, Tensor) and not isinstance(ret, cls):
1600
        ret = ret.as_subclass(cls)
1601

1602
    if isinstance(ret, (tuple, list)):
1603
        # Also handles things like namedtuples
1604
        ret = type(ret)(_convert(r, cls) for r in ret)
1605

1606
    return ret
1607

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.