pytorch

Форк
0
/
functional.py 
1983 строки · 83.6 Кб
1
from typing import (
2
    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
3
)
4
import operator
5
import itertools
6

7
import torch
8
from torch._C import _add_docstr
9
import torch.nn.functional as F
10
from ._lowrank import svd_lowrank, pca_lowrank
11
from .overrides import (
12
    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
13
    handle_torch_function)
14
from ._jit_internal import boolean_dispatch
15
from ._jit_internal import _overload as overload
16

17
Tensor = torch.Tensor
18
from torch import _VF
19

20
# Public API of this module (order preserved from the original listing).
__all__ = [
    'atleast_1d', 'atleast_2d', 'atleast_3d',
    'align_tensors',
    'broadcast_shapes', 'broadcast_tensors',
    'cartesian_prod',
    'block_diag',
    'cdist',
    'chain_matmul',
    'einsum',
    'istft',
    'lu',
    'norm',
    'meshgrid',
    'pca_lowrank',
    'split',
    'stft',
    'svd_lowrank',
    'tensordot',
    'unique', 'unique_consecutive',
    'unravel_index',
]
45

46

47
def broadcast_tensors(*tensors):
    r"""broadcast_tensors(*tensors) -> List of Tensors

    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.

    Args:
        *tensors: any number of tensors of the same type

    .. warning::

        More than one element of a broadcasted tensor may refer to a single
        memory location. As a result, in-place operations (especially ones that
        are vectorized) may result in incorrect behavior. If you need to write
        to the tensors, please clone them first.

    Example::

        >>> x = torch.arange(3).view(1, 3)
        >>> y = torch.arange(2).view(2, 1)
        >>> a, b = torch.broadcast_tensors(x, y)
        >>> a.size()
        torch.Size([2, 3])
        >>> a
        tensor([[0, 1, 2],
                [0, 1, 2]])
    """
    # The underlying _VF op receives the tensors as one sequence; this Python
    # wrapper only exists so callers can pass them variadically.
    if not has_torch_function(tensors):
        return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
    # Some argument overrides __torch_function__: hand the call over to it.
    return handle_torch_function(broadcast_tensors, tensors, *tensors)
77

78

79
def broadcast_shapes(*shapes):
80
    r"""broadcast_shapes(*shapes) -> Size
81

82
    Similar to :func:`broadcast_tensors` but for shapes.
83

84
    This is equivalent to
85
    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
86
    but avoids the need create to intermediate tensors. This is useful for
87
    broadcasting tensors of common batch shape but different rightmost shape,
88
    e.g. to broadcast mean vectors with covariance matrices.
89

90
    Example::
91

92
        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
93
        torch.Size([1, 3, 2])
94

95
    Args:
96
        \*shapes (torch.Size): Shapes of tensors.
97

98
    Returns:
99
        shape (torch.Size): A shape compatible with all input shapes.
100

101
    Raises:
102
        RuntimeError: If shapes are incompatible.
103
    """
104
    # This wrapper exists to support variadic args.
105
    # TODO Move this to C++ once the jit has better support for torch.Size.
106
    if not torch.jit.is_tracing():
107
        max_len = 0
108
        for shape in shapes:
109
            if isinstance(shape, (int, torch.SymInt)):
110
                if max_len < 1:
111
                    max_len = 1
112
            elif isinstance(shape, (tuple, list)):
113
                s = len(shape)
114
                if max_len < s:
115
                    max_len = s
116
        result = [1] * max_len
117

118
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
119

120
        for shape in shapes:
121
            if isinstance(shape, (int, torch.SymInt)):
122
                shape = (shape,)
123
            if isinstance(shape, (tuple, list)):
124
                for i in range(-1, -1 - len(shape), -1):
125
                    if shape[i] < 0:
126
                        raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
127
                    # NB: result is initialized to 1 so this is effectively an
128
                    # equals one test
129
                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
130
                        continue
131
                    if result[i] != 1:
132
                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
133
                    result[i] = shape[i]
134
            else:
135
                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
136
        return torch.Size(result)
137
    else:
138
        # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
139
        with torch.no_grad():
140
            scalar = torch.zeros((), device="cpu")
141
            tensors = [scalar.expand(shape) for shape in shapes]
142
            tensors = broadcast_tensors(*tensors)
143
            return tensors[0].shape
144

145

146
def split(
    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). The last chunk will be smaller
    if the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size_or_sections`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int or list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor. Default: ``0``

    Example::

        >>> a = torch.arange(10).reshape(5, 2)
        >>> a
        tensor([[0, 1],
                [2, 3],
                [4, 5],
                [6, 7],
                [8, 9]])
        >>> torch.split(a, 2)
        (tensor([[0, 1],
                 [2, 3]]),
         tensor([[4, 5],
                 [6, 7]]),
         tensor([[8, 9]]))
        >>> torch.split(a, [1, 4])
        (tensor([[0, 1]]),
         tensor([[2, 3],
                 [4, 5],
                 [6, 7],
                 [8, 9]]))
    """
    if not has_torch_function_unary(tensor):
        # Tensor.split (defined in _tensor.py) chooses between the two ATen
        # overloads depending on whether split_size_or_sections is an int or a
        # list, so defer to it rather than duplicating that branching here.
        return tensor.split(split_size_or_sections, dim)
    return handle_torch_function(
        split, (tensor,), tensor, split_size_or_sections, dim=dim)
196

197

198
def einsum(*args: Any) -> Tensor:
    r"""einsum(equation, *operands) -> Tensor

    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
    based on the Einstein summation convention.

    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
    with some subscript and define which subscripts are part of the output. The output is then computed by summing
    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).

    Equation:

        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
        comma (','), e.g. `'ij,jk'` specifies subscripts for two 2D operands. The dimensions labeled with the same
        subscript must be broadcastable, that is, their size must either match or be `1`. The exception is if a
        subscript is repeated for the same input operand, in which case the dimensions labeled with this subscript
        for this operand must match in size and the operand will be replaced by its diagonal along these dimensions.
        The subscripts that appear exactly once in the :attr:`equation` will be part of the output, sorted in
        increasing alphabetical order. The output is computed by multiplying the input :attr:`operands` element-wise,
        with their dimensions aligned based on the subscripts, and then summing out the dimensions whose subscripts
        are not part of the output.

        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
        followed by the subscripts for the output. For instance, the following equation computes the transpose of a
        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
        at most once for the output.

        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
        before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
        batch matrix multiplication `'...ij,...jk'`.

        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.

    .. note::

        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.

    .. note::

        This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to
        consume less memory by optimizing contraction order. This optimization occurs when there are at least three
        inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem,
        thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available,
        the default order is to contract from left to right.

        To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path
        calculation: `torch.backends.opt_einsum.enabled = False`

        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
        `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and
        'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
        the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).

    .. note::

        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
        subscripts for each operand are specified by sublists, lists of integers in the range [0, 52). These sublists
        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
        subscripts, e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out])`. Python's `Ellipsis` object
        may be provided in a sublist to enable broadcasting as described in the Equation section above.

    Args:
        equation (str): The subscripts for the Einstein summation.
        operands (List[Tensor]): The tensors to compute the Einstein summation of.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # trace
        >>> torch.einsum('ii', torch.randn(4, 4))
        tensor(-1.2104)

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # diagonal
        >>> torch.einsum('ii->i', torch.randn(4, 4))
        tensor([-0.1034,  0.7952, -0.2433,  0.4545])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # outer product
        >>> x = torch.randn(5)
        >>> y = torch.randn(4)
        >>> torch.einsum('i,j->ij', x, y)
        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
                [-0.3744,  0.9381,  1.2685, -1.6070],
                [ 0.7208, -1.8058, -2.4419,  3.0936],
                [ 0.1713, -0.4291, -0.5802,  0.7350],
                [ 0.5704, -1.4290, -1.9323,  2.4480]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # batch matrix multiplication
        >>> As = torch.randn(3, 2, 5)
        >>> Bs = torch.randn(3, 5, 4)
        >>> torch.einsum('bij,bjk->bik', As, Bs)
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # with sublist format and ellipsis
        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> # batch permute
        >>> A = torch.randn(2, 3, 4, 5)
        >>> torch.einsum('...ij->...ji', A).shape
        torch.Size([2, 3, 5, 4])

        >>> # equivalent to torch.nn.functional.bilinear
        >>> A = torch.randn(3, 5, 4)
        >>> l = torch.randn(2, 5)
        >>> r = torch.randn(2, 4)
        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
        tensor([[-0.3430, -5.2405,  0.4494],
                [ 0.3311,  5.5201, -3.0356]])
    """
    import torch.backends.opt_einsum as opt_einsum

    # This wrapper exists to support variadic args.
    if len(args) < 2:
        raise ValueError('einsum(): must specify the equation string and at least one operand, '
                         'or at least one operand and its subscripts list')

    if not isinstance(args[0], torch.Tensor):
        # Equation-string format: einsum(equation, *operands).
        equation, operands = args[0], args[1:]
    else:
        # Sublist format: operands interleaved with their subscript lists,
        # optionally followed by one output subscript list. Translate it into
        # an equation string plus a tuple of operands.
        def label_for(idx: int) -> str:
            # Ellipsis -> '...', 0..25 -> 'A'..'Z', 26..51 -> 'a'..'z'.
            if idx == Ellipsis:
                return '...'
            if 0 <= idx < 26:
                return chr(ord('A') + idx)
            if 26 <= idx < 52:
                return chr(ord('a') + idx - 26)
            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')

        def labels_for(sublist) -> str:
            return ''.join(label_for(idx) for idx in sublist)

        # Input sublists sit at the odd positions. An odd total argument count
        # means the trailing (even-position) entry is the output sublist.
        has_output_sublist = len(args) % 2 == 1
        equation = ','.join(labels_for(sublist) for sublist in args[1::2])
        if has_output_sublist:
            equation += '->' + labels_for(args[-1])
            operands = args[:-1:2]
        else:
            operands = args[::2]

    if has_torch_function(operands):
        return handle_torch_function(einsum, operands, equation, *operands)

    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        # Legacy interface: einsum(equation, [op1, op2, ...]). Re-enter with the
        # operands unpacked so __torch_function__ overrides inside the list are
        # still honoured (the original implementation skipped this recursion).
        return einsum(equation, *operands[0])

    if len(operands) <= 2 or not opt_einsum.enabled:
        # With at most two operands there is at most one contraction, so there is
        # no order to optimize; also honour a user who disabled opt_einsum.
        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]

    contraction_path = None
    if opt_einsum.is_available():
        planner = opt_einsum.get_opt_einsum()
        pairs = planner.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
        # The C++ kernel expects the pairwise path flattened into one list.
        contraction_path = list(itertools.chain.from_iterable(pairs))
    return _VF.einsum(equation, operands, path=contraction_path)  # type: ignore[attr-defined]
394

395

396
# This wrapper exists to support variadic args.
397
if TYPE_CHECKING:
    # The JIT doesn't understand Union, so only add type annotation for mypy
    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        return _meshgrid(*tensors, indexing=indexing)
else:
    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.

        This is helpful when you want to visualize data over some
        range of inputs. See below for a plotting example.

        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` (for
        ``indexing='ij'``; see :attr:`indexing` for ``'xy'``) where
        the output :math:`G_i` is constructed by expanding :math:`T_i`
        to the result shape.

        .. note::
            0D inputs are treated equivalently to 1D inputs of a
            single element.

        .. warning::
            `torch.meshgrid(*tensors)` currently has the same behavior
            as calling `numpy.meshgrid(*arrays, indexing='ij')`.

            In the future `torch.meshgrid` will transition to
            `indexing='xy'` as the default.

            https://github.com/pytorch/pytorch/issues/50276 tracks
            this issue with the goal of migrating to NumPy's behavior.

        .. seealso::

            :func:`torch.cartesian_prod` has the same effect but it
            collects the data in a tensor of vectors.

        Args:
            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
                treated as tensors of size :math:`(1,)` automatically

            indexing (str, optional): the indexing mode, either "xy"
                or "ij", defaults to "ij". See warning for future changes.

                If "xy" is selected, the first dimension corresponds
                to the cardinality of the second input and the second
                dimension corresponds to the cardinality of the first
                input.

                If "ij" is selected, the dimensions are in the same
                order as the cardinality of the inputs.

        Returns:
            seq (sequence of Tensors): If the input has :math:`N`
            tensors of size :math:`S_0 \ldots S_{N-1}`, then the
            output will also have :math:`N` tensors, where each tensor
            is of shape :math:`(S_0, ..., S_{N-1})`.

        Example::

            >>> x = torch.tensor([1, 2, 3])
            >>> y = torch.tensor([4, 5, 6])

            Observe the element-wise pairings across the grid, (1, 4),
            (1, 5), ..., (3, 6). This is the same thing as the
            cartesian product.
            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
            >>> grid_x
            tensor([[1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]])
            >>> grid_y
            tensor([[4, 5, 6],
                    [4, 5, 6],
                    [4, 5, 6]])

            This correspondence can be seen when these grids are
            stacked properly.
            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
            ...             torch.cartesian_prod(x, y))
            True

            `torch.meshgrid` is commonly used to produce a grid for
            plotting.
            >>> # xdoctest: +REQUIRES(module:matplotlib)
            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
            >>> import matplotlib.pyplot as plt
            >>> xs = torch.linspace(-5, 5, steps=100)
            >>> ys = torch.linspace(-5, 5, steps=100)
            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
            >>> z = torch.sin(torch.sqrt(x * x + y * y))
            >>> ax = plt.axes(projection='3d')
            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
            >>> plt.show()

        .. image:: ../_static/img/meshgrid.png
            :width: 512

        """
        # All argument handling lives in the shared _meshgrid helper.
        return _meshgrid(*tensors, indexing=indexing)
498

499

500
def _meshgrid(*tensors, indexing: Optional[str]):
    """Shared implementation of ``meshgrid``.

    Routes the call through ``__torch_function__`` overrides when any input has
    one; otherwise forwards to ``_VF.meshgrid``. Tensors may be given either
    variadically or as a single list/tuple (legacy calling convention).
    ``indexing`` is forwarded only when it is not ``None``, which leaves the
    choice of default to ``_VF.meshgrid``.
    """
    if has_torch_function(tensors):
        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)

    # Legacy calling convention: meshgrid([t0, t1, ...]).
    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
        tensors = tensors[0]  # type: ignore[assignment]

    # Omit the keyword entirely when it was not supplied, so callers of the old
    # signature (which had no `indexing` kwarg) keep working.
    if indexing is None:
        return _VF.meshgrid(tensors)  # type: ignore[attr-defined]
    return _VF.meshgrid(tensors, indexing=indexing)  # type: ignore[attr-defined]
513

514

515
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         win_length: Optional[int] = None, window: Optional[Tensor] = None,
         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
         onesided: Optional[bool] = None,
         return_complex: Optional[bool] = None) -> Tensor:
    r"""Short-time Fourier transform (STFT).

    .. warning::
        From version 1.8.0, :attr:`return_complex` must always be given
        explicitly for real inputs and `return_complex=False` has been
        deprecated. Strongly prefer `return_complex=True` as in a future
        pytorch release, this function will only return complex tensors.

        Note that :func:`torch.view_as_real` can be used to recover a real
        tensor with an extra last dimension for real and imaginary components.

    .. warning::
        From version 2.1, a warning will be provided if a :attr:`window` is
        not specified. In a future release, this attribute will be required.
        Not providing a window currently defaults to using a rectangular window,
        which may result in undesirable artifacts. Consider using tapered windows,
        such as :func:`torch.hann_window`.

    The STFT computes the Fourier transform of short overlapping windows of the
    input. This gives the frequency components of the signal as they change over
    time. The interface of this function is modeled after (but *not* a drop-in
    replacement for) librosa_ stft function.

    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html

    Ignoring the optional batch dimension, this method computes the following
    expression:

    .. math::
        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),

    where :math:`m` is the index of the sliding window, and :math:`\omega` is
    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.

    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
      sequences.

    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
      ``floor(n_fft / 4)``.

    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
      :attr:`n_fft`.

    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
      treated as if having :math:`1` everywhere in the window. If
      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
      both sides to length :attr:`n_fft` before being applied.

    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
      both sides so that the :math:`t`-th frame is centered at time
      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
      begins at time :math:`t \times \text{hop\_length}`.

    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
      all available options. Default is ``"reflect"``.

    * If :attr:`onesided` is ``True`` (default for real input), only values for
      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
      \frac{\text{n\_fft}}{2} \right\rfloor\right]` (that is,
      :math:`\left\lfloor \frac{\text{n\_fft}}{2} \right\rfloor + 1` frequencies)
      are returned because the real-to-complex Fourier transform satisfies the
      conjugate symmetry, i.e., :math:`X[\omega, m] = X[\text{n\_fft} - \omega, m]^*`.
      Note if the input or window tensors are complex, then :attr:`onesided`
      output is not possible.

    * If :attr:`normalized` is ``True`` (default is ``False``), the function
      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.

    * If :attr:`return_complex` is ``True`` (default if input is complex), the
      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
      the output is a ``input.dim() + 2`` dimensional real tensor where the last
      dimension represents the real and imaginary components.

    Returns either a complex tensor of size :math:`(* \times N \times T)` if
    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
    \times T \times 2)`. Where :math:`*` is the optional batch size of
    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
    and :math:`T` is the total number of frames used.

    .. warning::
      This function changed signature at version 0.4.1. Calling with the
      previous signature may cause error or return incorrect result.

    Args:
        input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
            batch dimension
        n_fft (int): size of Fourier transform
        hop_length (int, optional): the distance between neighboring sliding window
            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
        win_length (int, optional): the size of window frame and STFT filter.
            Default: ``None``  (treated as equal to :attr:`n_fft`)
        window (Tensor, optional): the optional window function.
            Must be 1-D with length at most :attr:`n_fft`.
            Default: ``None`` (treated as window of all :math:`1` s)
        center (bool, optional): whether to pad :attr:`input` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (str, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        normalized (bool, optional): controls whether to return the normalized STFT results
             Default: ``False``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy for real inputs.
            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
        return_complex (bool, optional): whether to return a complex tensor, or
            a real tensor with an extra last dimension for the real and
            imaginary components.

            .. versionchanged:: 2.0
               ``return_complex`` is now a required argument for real inputs,
               as the default is being transitioned to ``True``.

            .. deprecated:: 2.0
               ``return_complex=False`` is deprecated, instead use ``return_complex=True``
               Note that calling :func:`torch.view_as_real` on the output will
               recover the deprecated output format.

    Returns:
        Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
           - `B?` is an optional batch dimension from the input.
           - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
             `onesided=True`, or otherwise `n_fft`.
           - `T` is the number of frames, `1 + L // hop_length`
             for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
           - `C?` is an optional length-2 dimension of real and imaginary
             components, present when `return_complex=False`.

    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
            onesided=onesided, return_complex=return_complex)
    # NOTE: Do not edit. This code will be removed once the forward-compatibility
    #       period is over for PR #73432
    if center:
        # Centering pads n_fft // 2 samples on each side of the last dimension.
        # The input is temporarily viewed as 3-D (leading singleton dims added),
        # presumably because F.pad's non-constant modes expect a batched 1-D
        # signal layout -- TODO confirm; it is viewed back to its original rank
        # (with the padded last dim) before calling the kernel.
        signal_dim = input.dim()
        extended_shape = [1] * (3 - signal_dim) + list(input.size())
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
        input = input.view(input.shape[-signal_dim:])
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
                    normalized, onesided, return_complex)
667

668

669
# Attach the user-facing docstring to the natively implemented ``torch.istft``;
# the returned callable is re-exported from this module as ``istft``.
istft = _add_docstr(
    torch.istft,
    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
    r"""
Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.

.. warning::
    From version 2.1, a warning will be provided if a :attr:`window` is
    not specified. In a future release, this attribute will be required.
    Please provide the same window used in the stft call.

It has the same parameters (plus the additional optional parameter :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition
(nonzero overlap).

The parameters :attr:`window` and :attr:`center` must be chosen so that the envelope
created by the summation of all the windows is never zero at any point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \neq 0`.

Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
``istft`` may return a shorter signal than the original signal (this can occur if :attr:`center` is False
since the signal isn't padded). If :attr:`length` is given in the arguments and is longer than expected,
``istft`` will pad zeros to the end of the returned signal.

If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
calculated without additional information.

Example: Suppose the last window is:
``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``

The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
of right padding. These additional values could be zeros or a reflection of the signal so providing
:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
(some loss of signal).

[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

Args:
    input (Tensor): The input tensor. Expected to be in the format of the :func:`~torch.stft`
        output, that is, a complex tensor of shape `(B?, N, T)` where

        - `B?` is an optional batch dimension
        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
          for onesided input, or otherwise `n_fft`.
        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
          or `1 + (length - n_fft) // hop_length` otherwise.

        .. versionchanged:: 2.0
            Real datatype inputs are no longer supported. Input must now have a
            complex datatype, as returned by ``stft(..., return_complex=True)``.
    n_fft (int): Size of Fourier transform
    hop_length (Optional[int]): The distance between neighboring sliding window frames.
        (Default: ``n_fft // 4``)
    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
    window (Optional[torch.Tensor]): The optional window function.
        Shape must be 1d and `<= n_fft`
        (Default: ``torch.ones(win_length)``)
    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
        centered at time :math:`t \times \text{hop\_length}`.
        (Default: ``True``)
    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if `n_fft != N`, where `N` is the size of the input's
        frequency dimension)
    length (Optional[int]): The amount to trim the signal by (i.e. the
        original signal length). Defaults to `(T - 1) * hop_length` for
        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
        is the number of input frames.
    return_complex (Optional[bool]):
        Whether the output should be complex, or if the input should be
        assumed to derive from a real signal and window.
        Note that this is incompatible with ``onesided=True``.
        (Default: ``False``)

Returns:
    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
        `B?` is an optional batch dimension from the input tensor.
""")
749

750

751
# Return-type alias shared by ``_unique_impl`` and ``_unique_consecutive_impl``.
if not TYPE_CHECKING:
    # At runtime the alias is the concrete triple (output, inverse, counts)
    # used in the impl functions' return annotations.
    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
else:
    # For static type checkers the impl functions may hand back whatever
    # __torch_function__ produced (tuple unpacking has already happened
    # inside the impl rather than in its caller), so the alias is Any.
    _unique_impl_out = Any
758

759

760
def _unique_impl(input: Tensor, sorted: bool = True,
                 return_inverse: bool = False, return_counts: bool = False,
                 dim: Optional[int] = None) -> _unique_impl_out:
    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]

    Returns the unique elements of the input tensor.

    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
        this function also eliminates non-consecutive duplicate values.

    .. note:: Currently in the CUDA implementation and the CPU implementation,
        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
        :func:`torch.unique_consecutive` which avoids the sorting.

    Args:
        input (Tensor): the input tensor
        sorted (bool): Whether to sort the unique elements in ascending order
            before returning as output.
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int, optional): the dimension to operate upon. If ``None``, the
            unique of the flattened input is returned. Otherwise, each of the
            tensors indexed by the given dimension is treated as one of the
            elements to apply the unique operation upon. See examples for more
            details. Default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
        >>> output
        tensor([1, 2, 3])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([0, 2, 1, 2])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([[0, 2],
                [1, 2]])

        >>> a = torch.tensor([
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ...     [
        ...         [0, 0, 1, 1],
        ...         [0, 0, 1, 1],
        ...         [1, 1, 1, 1],
        ...     ],
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ... ])

        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
        >>> # each other, so one of them will be removed.
        >>> (a[0, :, :] == a[2, :, :]).all()
        tensor(True)
        >>> a_unique_dim0 = torch.unique(a, dim=0)
        >>> a_unique_dim0
        tensor([[[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 1, 1]]])

        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
        >>> # `a_unique_dim0`:
        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
        tensor(True)
        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
        tensor(True)

        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
        >>> # them will be removed.
        >>> (a[:, 0, :] == a[:, 1, :]).all()
        tensor(True)
        >>> torch.unique(a, dim=1)
        tensor([[[0, 0, 1, 1],
                 [1, 1, 0, 0]],
                [[1, 1, 1, 1],
                 [0, 0, 1, 1]],
                [[0, 0, 1, 1],
                 [1, 1, 0, 0]]])

        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
        >>> # sub-tensors will be removed.
        >>> (a[:, :, 0] == a[:, :, 1]).all()
        tensor(True)
        >>> (a[:, :, 2] == a[:, :, 3]).all()
        tensor(True)
        >>> torch.unique(a, dim=2)
        tensor([[[0, 1],
                 [0, 1],
                 [1, 0]],
                [[1, 0],
                 [1, 0],
                 [1, 1]],
                [[0, 1],
                 [0, 1],
                 [1, 0]]])
    """
    # Inputs that override __torch_function__ handle the call themselves,
    # routed through the public ``unique`` entry point.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
            return_counts=return_counts, dim=dim)

    # Both kernels always yield a full (output, inverse_indices, counts) triple;
    # the ``unique`` dispatch wrappers pick which members reach the caller.
    if dim is not None:
        # Each slice along ``dim`` is treated as one element.
        output, inverse_indices, counts = _VF.unique_dim(
            input,
            dim,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    else:
        # No ``dim``: unique over the flattened input.
        output, inverse_indices, counts = torch._unique2(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    return output, inverse_indices, counts
917

918

919
def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
                             return_counts: bool = False,
                             dim: Optional[int] = None) -> _unique_impl_out:
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    .. note:: This function is different from :func:`torch.unique` in the sense that this function
        only eliminates consecutive duplicate values. These semantics are similar to `std::unique`
        in C++.

    Args:
        input (Tensor): the input tensor
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int, optional): the dimension to apply unique. If ``None``, the unique of the
            flattened input is returned. Default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
        >>> output = torch.unique_consecutive(x)
        >>> output
        tensor([1, 2, 3, 1, 2])

        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> inverse_indices
        tensor([0, 0, 1, 1, 2, 3, 3, 4])

        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> counts
        tensor([2, 2, 1, 2, 1])
    """
    # Inputs that override __torch_function__ handle the call themselves,
    # routed through the public ``unique_consecutive`` entry point.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique_consecutive, (input,), input, return_inverse=return_inverse,
            return_counts=return_counts, dim=dim)
    # The kernel always yields the full (output, inverse_indices, counts)
    # triple; the dispatch wrappers select which members reach the caller.
    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
    return output, inverse_indices, counts
978

979

980
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
    """``unique`` variant used when ``return_counts`` is set but ``return_inverse`` is not.

    Returns ``(output, counts)``; the inverse indices produced by
    ``_unique_impl`` are discarded. For inputs overriding
    ``__torch_function__`` the impl's result is passed through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, _unused_inverse, occurrence_counts = _unique_impl(
        input, sorted, return_inverse, return_counts, dim)
    return uniques, occurrence_counts
988

989

990
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
    """``unique`` variant used when neither ``return_inverse`` nor ``return_counts`` is set.

    Returns only the tensor of unique values. For inputs overriding
    ``__torch_function__`` the impl's result is passed through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, _unused_inverse, _unused_counts = _unique_impl(
        input, sorted, return_inverse, return_counts, dim)
    return uniques
998

999

1000
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
    """``unique`` variant used when ``return_inverse`` is set but ``return_counts`` is not.

    Returns ``(output, inverse_indices)``; the counts produced by
    ``_unique_impl`` are discarded. For inputs overriding
    ``__torch_function__`` the impl's result is passed through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, inverse_map, _unused_counts = _unique_impl(
        input, sorted, return_inverse, return_counts, dim)
    return uniques, inverse_map
1008

1009

1010
_return_inverse_false = boolean_dispatch(
1011
    arg_name='return_counts',
1012
    arg_index=3,
1013
    default=False,
1014
    if_true=_return_counts,
1015
    if_false=_return_output,
1016
    module_name=__name__,
1017
    func_name='unique')
1018

1019
_return_inverse_true = boolean_dispatch(
1020
    arg_name='return_counts',
1021
    arg_index=3,
1022
    default=False,
1023
    if_true=_unique_impl,
1024
    if_false=_return_inverse,
1025
    module_name=__name__,
1026
    func_name='unique')
1027

1028
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
1029
# resolve the output type in TorchScript we need to statically know the value of both parameters
1030

1031
unique = boolean_dispatch(
1032
    arg_name='return_inverse',
1033
    arg_index=2,
1034
    default=False,
1035
    if_true=_return_inverse_true,
1036
    if_false=_return_inverse_false,
1037
    module_name=__name__,
1038
    func_name='unique')
1039
unique.__doc__ = _unique_impl.__doc__
1040

1041

1042
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
    """``unique_consecutive`` variant returning ``(output, counts)``.

    The inverse indices produced by ``_unique_consecutive_impl`` are
    discarded. For inputs overriding ``__torch_function__`` the impl's
    result is passed through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    runs, _unused_inverse, run_lengths = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim)
    return runs, run_lengths
1050

1051

1052
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> Tensor
    """``unique_consecutive`` variant returning only the deduplicated tensor.

    For inputs overriding ``__torch_function__`` the impl's result is passed
    through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    runs, _unused_inverse, _unused_counts = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim)
    return runs
1060

1061

1062
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
    """``unique_consecutive`` variant returning ``(output, inverse_indices)``.

    The counts produced by ``_unique_consecutive_impl`` are discarded. For
    inputs overriding ``__torch_function__`` the impl's result is passed
    through untouched.
    """
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    runs, run_index, _unused_counts = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim)
    return runs, run_index
1070

1071

1072
_consecutive_return_inverse_false = boolean_dispatch(
1073
    arg_name='return_counts',
1074
    arg_index=1,
1075
    default=False,
1076
    if_true=_consecutive_return_counts,
1077
    if_false=_consecutive_return_output,
1078
    module_name=__name__,
1079
    func_name='unique_consecutive')
1080

1081
_consecutive_return_inverse_true = boolean_dispatch(
1082
    arg_name='return_counts',
1083
    arg_index=1,
1084
    default=False,
1085
    if_true=_unique_consecutive_impl,
1086
    if_false=_consecutive_return_inverse,
1087
    module_name=__name__,
1088
    func_name='unique_consecutive')
1089

1090
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
1091
# resolve the output type in TorchScript we need to statically know the value of both parameters
1092

1093
unique_consecutive = boolean_dispatch(
1094
    arg_name='return_inverse',
1095
    arg_index=2,
1096
    default=False,
1097
    if_true=_consecutive_return_inverse_true,
1098
    if_false=_consecutive_return_inverse_false,
1099
    module_name=__name__,
1100
    func_name='unique_consecutive')
1101
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
1102

1103
# ``_overload`` (from ``torch._jit_internal``) records one signature per
# accepted ``dims`` form for the JIT. These declarations are hidden from
# static type checkers: there's no good way to use this annotation style
# without breaking the JIT overloads, so ``tensordot`` stays untyped for mypy.
# Overload bodies must stay a bare ``pass``.
if not TYPE_CHECKING:
    @overload
    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
        pass
1123

1124

1125
def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
      a (Tensor): Left tensor to contract
      b (Tensor): Right tensor to contract
      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
         contract or explicit lists of dimensions for :attr:`a` and
         :attr:`b` respectively
      out (Tensor, optional): the output tensor.

    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
    respectively, :func:`~torch.tensordot` computes

    .. math::
        r_{i_0,...,i_{m-d-1}, j_d,...,j_{n-1}}
          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d-1},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, j_d,...,j_{n-1}}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :attr:`b`. The sizes
    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
    dimensions.

    Examples::

        >>> a = torch.arange(60.).reshape(3, 4, 5)
        >>> b = torch.arange(24.).reshape(4, 3, 2)
        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
        tensor([[4400., 4730.],
                [4532., 4874.],
                [4664., 5018.],
                [4796., 5162.],
                [4928., 5306.]])

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> a = torch.randn(3, 4, 5, device='cuda')
        >>> b = torch.randn(4, 5, 6, device='cuda')
        >>> torch.tensordot(a, b, dims=2).cpu()
        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

        >>> a = torch.randn(3, 5, 4, 6)
        >>> b = torch.randn(6, 4, 5, 3)
        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
        tensor([[  7.7193,  -2.4867, -10.3204],
                [  1.5513, -14.4737,  -6.5113],
                [ -0.2850,   4.2573,  -3.5997]])
    """
    if has_torch_function_variadic(a, b):
        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)

    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
        raise RuntimeError("tensordot expects dims to be int or "
                           + "Tuple[List[int], List[int]] or "
                           + "List[List[int]] containing two lists, but got "
                           + f"dims={dims}")

    # Every accepted ``dims`` form is normalized into two explicit lists of
    # dimensions to contract in ``a`` and ``b``.
    dims_a: List[int] = []
    dims_b: List[int] = []

    if isinstance(dims, (tuple, list)):
        # Explicit form: ([dims of a], [dims of b]).
        dims_a, dims_b = dims

    if isinstance(dims, torch.Tensor):
        num_elements = dims.numel()
        if num_elements > 1:
            # Two rows holding the explicit dimension lists for a and b.
            assert dims.size()[0] == 2
            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
        else:
            # A single-element tensor is handled like an integer ``dims``.
            dims_val = int(dims.item())
            if dims_val < 0:
                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
            dims_a = list(range(-dims_val, 0))
            dims_b = list(range(dims_val))

    if isinstance(dims, (int, torch.SymInt)):
        if dims < 0:
            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
        if dims > min(a.dim(), b.dim()):
            raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
        # Contract the last ``dims`` dims of ``a`` with the first ``dims`` of ``b``.
        dims_a = list(range(-dims, 0))
        dims_b = list(range(dims))

    if out is None:
        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
1216

1217

1218
def cartesian_prod(*tensors: Tensor) -> Tensor:
    """Do cartesian product of the given sequence of tensors. The behavior is similar to
    python's `itertools.product`.

    Args:
        *tensors: any number of 1 dimensional tensors.

    Returns:
        Tensor: the result of turning every input tensor into a list, taking
        `itertools.product` of those lists, and converting the resulting
        list back into a tensor.

    Example::

        >>> import itertools
        >>> a = [1, 2, 3]
        >>> b = [4, 5]
        >>> list(itertools.product(a, b))
        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
        >>> tensor_a = torch.tensor(a)
        >>> tensor_b = torch.tensor(b)
        >>> torch.cartesian_prod(tensor_a, tensor_b)
        tensor([[1, 4],
                [1, 5],
                [2, 4],
                [2, 5],
                [3, 4],
                [3, 5]])
    """
    # Variadic wrapper: the positional tensors are gathered into one tuple and
    # handed to the kernel as a single sequence argument.
    if not has_torch_function(tensors):
        return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
    return handle_torch_function(cartesian_prod, tensors, *tensors)
1251

1252

1253
def block_diag(*tensors):
    """Create a block diagonal matrix from provided tensors.

    Args:
        *tensors: One or more tensors with 0, 1, or 2 dimensions.

    Returns:
        Tensor: A 2 dimensional tensor with the inputs placed along the
        diagonal in argument order, each block's upper left corner touching
        the previous block's lower right corner. Every other element is 0.

    Example::

        >>> import torch
        >>> A = torch.tensor([[0, 1], [1, 0]])
        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
        >>> C = torch.tensor(7)
        >>> D = torch.tensor([1, 2, 3])
        >>> E = torch.tensor([[4], [5], [6]])
        >>> torch.block_diag(A, B, C, D, E)
        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
    """
    # Variadic wrapper: the native kernel takes all blocks as one sequence.
    if not has_torch_function(tensors):
        return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
    return handle_torch_function(block_diag, tensors, *tensors)
1287

1288

1289
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
    # type: (Tensor, Tensor, float, str) -> (Tensor)
    r"""Computes the batched p-norm distance between each pair of the two collections of row vectors.

    Args:
        x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
        x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
        p: p value for the p-norm distance to calculate between each vector pair
            :math:`\in [0, \infty]`.
        compute_mode:
            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
            euclidean distance (p = 2) if P > 25 or R > 25
            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            Default: use_mm_for_euclid_dist_if_necessary.

    Raises:
        ValueError: if :attr:`compute_mode` is not one of the three values above.

    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
    output will have shape :math:`B \times P \times R`.

    For each batch, this function is equivalent to `scipy.spatial.distance.cdist(x1, x2, 'minkowski', p=p)`
    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
    `scipy.spatial.distance.cdist(x1, x2, 'hamming') * M`. When :math:`p = \infty`, the closest
    scipy function is `scipy.spatial.distance.cdist(x1, x2, lambda x, y: np.abs(x - y).max())`.

    Example::

        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])
        >>> a
        tensor([[ 0.9041,  0.0196],
                [-0.3108, -2.4423],
                [-0.4821,  1.0590]])
        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])
        >>> b
        tensor([[-2.1763, -0.4713],
                [-0.6986,  1.3702]])
        >>> torch.cdist(a, b, p=2)
        tensor([[3.1193, 2.0959],
                [2.7138, 3.8322],
                [2.2830, 0.3791]])
    """
    if has_torch_function_variadic(x1, x2):
        return handle_torch_function(
            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
    # The string mode is translated into the kernel's integer code:
    # None (use mm only if necessary), 1 (always use mm), 2 (never use mm).
    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
        return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
    elif compute_mode == 'use_mm_for_euclid_dist':
        return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
    elif compute_mode == 'donot_use_mm_for_euclid_dist':
        return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
    else:
        raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
1342

1343

1344
def atleast_1d(*tensors):
1345
    r"""
1346
    Returns a 1-dimensional view of each input tensor with zero dimensions.
1347
    Input tensors with one or more dimensions are returned as-is.
1348

1349
    Args:
1350
        input (Tensor or list of Tensors)
1351

1352
    Returns:
1353
        output (Tensor or tuple of Tensors)
1354

1355
    Example::
1356

1357
        >>> x = torch.arange(2)
1358
        >>> x
1359
        tensor([0, 1])
1360
        >>> torch.atleast_1d(x)
1361
        tensor([0, 1])
1362
        >>> x = torch.tensor(1.)
1363
        >>> x
1364
        tensor(1.)
1365
        >>> torch.atleast_1d(x)
1366
        tensor([1.])
1367
        >>> x = torch.tensor(0.5)
1368
        >>> y = torch.tensor(1.)
1369
        >>> torch.atleast_1d((x, y))
1370
        (tensor([0.5000]), tensor([1.]))
1371
    """
1372
    # This wrapper exists to support variadic args.
1373
    if has_torch_function(tensors):
1374
        return handle_torch_function(atleast_1d, tensors, *tensors)
1375
    if len(tensors) == 1:
1376
        tensors = tensors[0]
1377
    return _VF.atleast_1d(tensors)  # type: ignore[attr-defined]
1378

1379

1380
def atleast_2d(*tensors):
1381
    r"""
1382
    Returns a 2-dimensional view of each input tensor with zero dimensions.
1383
    Input tensors with two or more dimensions are returned as-is.
1384

1385
    Args:
1386
        input (Tensor or list of Tensors)
1387

1388
    Returns:
1389
        output (Tensor or tuple of Tensors)
1390

1391
    Example::
1392

1393
        >>> x = torch.tensor(1.)
1394
        >>> x
1395
        tensor(1.)
1396
        >>> torch.atleast_2d(x)
1397
        tensor([[1.]])
1398
        >>> x = torch.arange(4).view(2, 2)
1399
        >>> x
1400
        tensor([[0, 1],
1401
                [2, 3]])
1402
        >>> torch.atleast_2d(x)
1403
        tensor([[0, 1],
1404
                [2, 3]])
1405
        >>> x = torch.tensor(0.5)
1406
        >>> y = torch.tensor(1.)
1407
        >>> torch.atleast_2d((x, y))
1408
        (tensor([[0.5000]]), tensor([[1.]]))
1409
    """
1410
    # This wrapper exists to support variadic args.
1411
    if has_torch_function(tensors):
1412
        return handle_torch_function(atleast_2d, tensors, *tensors)
1413
    if len(tensors) == 1:
1414
        tensors = tensors[0]
1415
    return _VF.atleast_2d(tensors)  # type: ignore[attr-defined]
1416

1417

1418
def atleast_3d(*tensors):
1419
    r"""
1420
    Returns a 3-dimensional view of each input tensor with zero dimensions.
1421
    Input tensors with three or more dimensions are returned as-is.
1422

1423
    Args:
1424
        input (Tensor or list of Tensors)
1425

1426
    Returns:
1427
        output (Tensor or tuple of Tensors)
1428

1429
    Example:
1430

1431
        >>> x = torch.tensor(0.5)
1432
        >>> x
1433
        tensor(0.5000)
1434
        >>> torch.atleast_3d(x)
1435
        tensor([[[0.5000]]])
1436
        >>> y = torch.arange(4).view(2, 2)
1437
        >>> y
1438
        tensor([[0, 1],
1439
                [2, 3]])
1440
        >>> torch.atleast_3d(y)
1441
        tensor([[[0],
1442
                 [1]],
1443
                <BLANKLINE>
1444
                [[2],
1445
                 [3]]])
1446
        >>> x = torch.tensor(1).view(1, 1, 1)
1447
        >>> x
1448
        tensor([[[1]]])
1449
        >>> torch.atleast_3d(x)
1450
        tensor([[[1]]])
1451
        >>> x = torch.tensor(0.5)
1452
        >>> y = torch.tensor(1.)
1453
        >>> torch.atleast_3d((x, y))
1454
        (tensor([[[0.5000]]]), tensor([[[1.]]]))
1455
    """
1456
    # This wrapper exists to support variadic args.
1457
    if has_torch_function(tensors):
1458
        return handle_torch_function(atleast_3d, tensors, *tensors)
1459
    if len(tensors) == 1:
1460
        tensors = tensors[0]
1461
    return _VF.atleast_3d(tensors)  # type: ignore[attr-defined]
1462

1463

1464
if TYPE_CHECKING:
1465
    pass
1466
    # There's no good way to use this type annotation; cannot rename norm() to
1467
    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
1468
    # for mypy for now.
1469
    #    def norm(input: Tensor,
1470
    #             p: Optional[Union[str, Number]] = "fro",
1471
    #             dim: Optional[Union[int, List[int]]] = None,
1472
    #             keepdim: bool = False,
1473
    #             out: Optional[Tensor] = None,
1474
    #             dtype: _dtype = None) -> Tensor:
1475
    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
1476
else:
1477
    # TODO: type dim as BroadcastingList when
1478
    # https://github.com/pytorch/pytorch/issues/33782 is fixed
1479
    @overload
1480
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
1481
        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
1482
        pass
1483

1484
    @overload  # noqa: F811
1485
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
1486
        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
1487
        pass
1488

1489
    @overload  # noqa: F811
1490
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
1491
        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
1492
        pass
1493

1494
    @overload  # noqa: F811
1495
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
1496
        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
1497
        pass
1498

1499

1500
def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
1501
    r"""Returns the matrix norm or vector norm of a given tensor.
1502

1503
    .. warning::
1504

1505
        torch.norm is deprecated and may be removed in a future PyTorch release.
1506
        Its documentation and behavior may be incorrect, and it is no longer
1507
        actively maintained.
1508

1509
        Use :func:`torch.linalg.vector_norm` when computing vector norms and
1510
        :func:`torch.linalg.matrix_norm` when computing matrix norms.
1511
        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
1512
        Note, however, the signature for these functions is slightly different than the
1513
        signature for ``torch.norm``.
1514

1515
    Args:
1516
        input (Tensor): The input tensor. Its data type must be either a floating
1517
            point or complex type. For complex inputs, the norm is calculated using the
1518
            absolute value of each element. If the input is complex and neither
1519
            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
1520
            be the corresponding floating point type (e.g. float if :attr:`input` is
1521
            complexfloat).
1522

1523
        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
1524
            The following norms can be calculated:
1525

1526
            ======  ==============  ==========================
1527
            ord     matrix norm     vector norm
1528
            ======  ==============  ==========================
1529
            'fro'   Frobenius norm  --
1530
            'nuc'   nuclear norm    --
1531
            Number  --              sum(abs(x)**ord)**(1./ord)
1532
            ======  ==============  ==========================
1533

1534
            The vector norm can be calculated across any number of dimensions.
1535
            The corresponding dimensions of :attr:`input` are flattened into
1536
            one dimension, and the norm is calculated on the flattened
1537
            dimension.
1538

1539
            Frobenius norm produces the same result as ``p=2`` in all cases
1540
            except when :attr:`dim` is a list of three or more dims, in which
1541
            case Frobenius norm throws an error.
1542

1543
            Nuclear norm can only be calculated across exactly two dimensions.
1544

1545
        dim (int, tuple of ints, list of ints, optional):
1546
            Specifies which dimension or dimensions of :attr:`input` to
1547
            calculate the norm across. If :attr:`dim` is ``None``, the norm will
1548
            be calculated across all dimensions of :attr:`input`. If the norm
1549
            type indicated by :attr:`p` does not support the specified number of
1550
            dimensions, an error will occur.
1551
        keepdim (bool, optional): whether the output tensors have :attr:`dim`
1552
            retained or not. Ignored if :attr:`dim` = ``None`` and
1553
            :attr:`out` = ``None``. Default: ``False``
1554
        out (Tensor, optional): the output tensor. Ignored if
1555
            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
1556
        dtype (:class:`torch.dtype`, optional): the desired data type of
1557
            returned tensor. If specified, the input tensor is casted to
1558
            :attr:`dtype` while performing the operation. Default: None.
1559

1560
    .. note::
1561
        Even though ``p='fro'`` supports any number of dimensions, the true
1562
        mathematical definition of Frobenius norm only applies to tensors with
1563
        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
1564
        aligns with the mathematical definition, since it can only be applied across
1565
        exactly two dimensions.
1566

1567
    Example::
1568

1569
        >>> import torch
1570
        >>> a = torch.arange(9, dtype= torch.float) - 4
1571
        >>> b = a.reshape((3, 3))
1572
        >>> torch.norm(a)
1573
        tensor(7.7460)
1574
        >>> torch.norm(b)
1575
        tensor(7.7460)
1576
        >>> torch.norm(a, float('inf'))
1577
        tensor(4.)
1578
        >>> torch.norm(b, float('inf'))
1579
        tensor(4.)
1580
        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float)
1581
        >>> torch.norm(c, dim=0)
1582
        tensor([1.4142, 2.2361, 5.0000])
1583
        >>> torch.norm(c, dim=1)
1584
        tensor([3.7417, 4.2426])
1585
        >>> torch.norm(c, p=1, dim=1)
1586
        tensor([6., 6.])
1587
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
1588
        >>> torch.norm(d, dim=(1, 2))
1589
        tensor([ 3.7417, 11.2250])
1590
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
1591
        (tensor(3.7417), tensor(11.2250))
1592
    """
1593

1594
    if has_torch_function_unary(input):
1595
        return handle_torch_function(
1596
            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
1597

1598
    # NB. All the repeated code and weird python is to please TorchScript.
1599
    #     For a more compact implementation see the relevant function in `_refs/__init__.py`
1600

1601
    # We don't do this for MPS or sparse tensors
1602
    if input.layout == torch.strided and input.device.type in \
1603
            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
1604
        if dim is not None:
1605
            if isinstance(dim, (int, torch.SymInt)):
1606
                _dim = [dim]
1607
            else:
1608
                _dim = dim
1609
        else:
1610
            _dim = None  # type: ignore[assignment]
1611

1612
        if isinstance(p, str):
1613
            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
1614
                if out is None:
1615
                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
1616
                else:
1617
                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
1618

1619
            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
1620
            # that will throw an error
1621
            if _dim is None:
1622
                _dim = list(range(input.ndim))
1623
            if out is None:
1624
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
1625
            else:
1626
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
1627
        else:
1628
            # NB. p should be Union[str, number], not Optional!
1629
            _p = 2.0 if p is None else p
1630
            if out is None:
1631
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
1632
            else:
1633
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
1634

1635
    ndim = input.dim()
1636

1637
    # catch default case
1638
    if dim is None and out is None and dtype is None and p is not None:
1639
        if isinstance(p, str):
1640
            if p == "fro":
1641
                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
1642
        if not isinstance(p, str):
1643
            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
1644
            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
1645

1646
    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
1647
    # remove the overloads where dim is an int and replace with BraodcastingList1
1648
    # and remove next four lines, replace _dim with dim
1649
    if dim is not None:
1650
        if isinstance(dim, (int, torch.SymInt)):
1651
            _dim = [dim]
1652
        else:
1653
            _dim = dim
1654
    else:
1655
        _dim = None  # type: ignore[assignment]
1656

1657
    if isinstance(p, str):
1658
        if p == "fro":
1659
            if dtype is not None:
1660
                raise ValueError("dtype argument is not supported in frobenius norm")
1661

1662
            if _dim is None:
1663
                _dim = list(range(ndim))
1664
            if out is None:
1665
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
1666
            else:
1667
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1668
        elif p == "nuc":
1669
            if dtype is not None:
1670
                raise ValueError("dtype argument is not supported in nuclear norm")
1671
            if _dim is None:
1672
                if out is None:
1673
                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
1674
                else:
1675
                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1676
            else:
1677
                if out is None:
1678
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
1679
                else:
1680
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1681
        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
1682
    else:
1683
        if _dim is None:
1684
            _dim = list(range(ndim))
1685

1686
        if out is None:
1687
            if dtype is None:
1688
                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
1689
            else:
1690
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
1691
        else:
1692
            if dtype is None:
1693
                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
1694
            else:
1695
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
1696

1697
def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
1698
    r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
1699
    index into an arbitrary tensor of the specified shape.
1700

1701
    Args:
1702
        indices (Tensor): An integer tensor containing indices into the
1703
            flattened version of an arbitrary tensor of shape :attr:`shape`.
1704
            All elements must be in the range ``[0, prod(shape) - 1]``.
1705

1706
        shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
1707
            tensor. All elements must be non-negative.
1708

1709
    Returns:
1710
        tuple of Tensors: Each ``i``-th tensor in the output corresponds with
1711
        dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
1712
        ``indices`` and contains one index into dimension ``i`` for each of the
1713
        flat indices given by ``indices``.
1714

1715
    Example::
1716

1717
        >>> import torch
1718
        >>> torch.unravel_index(torch.tensor(4), (3, 2))
1719
        (tensor(2),
1720
         tensor(0))
1721

1722
        >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
1723
        (tensor([2, 0]),
1724
         tensor([0, 1]))
1725

1726
        >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
1727
        (tensor([0, 0, 1, 1, 2, 2]),
1728
         tensor([0, 1, 0, 1, 0, 1]))
1729

1730
        >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
1731
        (tensor([1, 5]),
1732
         tensor([2, 6]),
1733
         tensor([3, 7]),
1734
         tensor([4, 8]))
1735

1736
        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
1737
        (tensor([[1], [5]]),
1738
         tensor([[2], [6]]),
1739
         tensor([[3], [7]]),
1740
         tensor([[4], [8]]))
1741

1742
        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
1743
        (tensor([[12], [56]]),
1744
         tensor([[34], [78]]))
1745
    """
1746
    if has_torch_function_unary(indices):
1747
        return handle_torch_function(
1748
            unravel_index, (indices,), indices, shape=shape)
1749
    res_tensor = _unravel_index(indices, shape)
1750
    return res_tensor.unbind(-1)
1751

1752
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
1753
    torch._check_type(
1754
        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
1755
        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
1756

1757
    torch._check_type(
1758
        isinstance(shape, (int, torch.SymInt, Sequence)),
1759
        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
1760

1761
    if isinstance(shape, (int, torch.SymInt)):
1762
        shape = torch.Size([shape])
1763
    else:
1764
        for dim in shape:
1765
            torch._check_type(
1766
                isinstance(dim, (int, torch.SymInt)),
1767
                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
1768
        shape = torch.Size(shape)
1769

1770
    torch._check_value(
1771
        all(dim >= 0 for dim in shape),
1772
        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
1773

1774
    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
1775
    return indices.unsqueeze(-1).floor_divide(
1776
        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
1777
    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
1778

1779
def chain_matmul(*matrices, out=None):
1780
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
1781
    using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
1782
    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
1783
    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
1784
    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
1785

1786
    .. warning::
1787

1788
        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
1789
        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
1790
        rather than multiple arguments.
1791

1792
    Args:
1793
        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
1794
        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
1795

1796
    Returns:
1797
        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
1798
        would be of dimensions :math:`p_{1} \times p_{N + 1}`.
1799

1800
    Example::
1801

1802
        >>> # xdoctest: +SKIP
1803
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1804
        >>> a = torch.randn(3, 4)
1805
        >>> b = torch.randn(4, 5)
1806
        >>> c = torch.randn(5, 6)
1807
        >>> d = torch.randn(6, 7)
1808
        >>> # will raise a deprecation warning
1809
        >>> torch.chain_matmul(a, b, c, d)
1810
        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
1811
                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
1812
                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])
1813

1814
    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
1815
    """
1816
    # This wrapper exists to support variadic args.
1817
    if has_torch_function(matrices):
1818
        return handle_torch_function(chain_matmul, matrices, *matrices)
1819

1820
    if out is None:
1821
        return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
1822
    else:
1823
        return _VF.chain_matmul(matrices, out=out)  # type: ignore[attr-defined]
1824

1825

1826
def _lu_impl(A, pivot=True, get_infos=False, out=None):
1827
    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
1828
    r"""Computes the LU factorization of a matrix or batches of matrices
1829
    :attr:`A`. Returns a tuple containing the LU factorization and
1830
    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to
1831
    ``True``.
1832

1833
    .. warning::
1834

1835
        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
1836
        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
1837
        future PyTorch release.
1838
        ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with
1839

1840
        .. code:: python
1841

1842
            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
1843

1844
        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
1845

1846
        .. code:: python
1847

1848
            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
1849

1850
    .. note::
1851
        * The returned permutation matrix for every matrix in the batch is
1852
          represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
1853
          ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
1854
          the ``i``-th row was permuted with the ``j-1``-th row.
1855
        * LU factorization with :attr:`pivot` = ``False`` is not available
1856
          for CPU, and attempting to do so will throw an error. However,
1857
          LU factorization with :attr:`pivot` = ``False`` is available for
1858
          CUDA.
1859
        * This function does not check if the factorization was successful
1860
          or not if :attr:`get_infos` is ``True`` since the status of the
1861
          factorization is present in the third element of the return tuple.
1862
        * In the case of batches of square matrices with size less or equal
1863
          to 32 on a CUDA device, the LU factorization is repeated for
1864
          singular matrices due to the bug in the MAGMA library
1865
          (see magma issue 13).
1866
        * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
1867

1868
    .. warning::
1869
        The gradients of this function will only be finite when :attr:`A` is full rank.
1870
        This is because the LU decomposition is just differentiable at full rank matrices.
1871
        Furthermore, if :attr:`A` is close to not being full rank,
1872
        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
1873

1874
    Args:
1875
        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
1876
        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
1877
        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
1878
                                    Default: ``False``
1879
        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
1880
                               then the elements in the tuple are Tensor, IntTensor,
1881
                               and IntTensor. If :attr:`get_infos` is ``False``, then the
1882
                               elements in the tuple are Tensor, IntTensor. Default: ``None``
1883

1884
    Returns:
1885
        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
1886

1887
            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
1888

1889
            - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
1890
              ``pivots`` stores all the intermediate transpositions of rows.
1891
              The final permutation ``perm`` could be reconstructed by
1892
              applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
1893
              where ``perm`` is initially the identity permutation of :math:`m` elements
1894
              (essentially this is what :func:`torch.lu_unpack` is doing).
1895

1896
            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
1897
              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
1898
              each minibatch has succeeded or failed
1899

1900
    Example::
1901

1902
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
1903
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1904
        >>> A = torch.randn(2, 3, 3)
1905
        >>> A_LU, pivots = torch.lu(A)
1906
        >>> A_LU
1907
        tensor([[[ 1.3506,  2.5558, -0.0816],
1908
                 [ 0.1684,  1.1551,  0.1940],
1909
                 [ 0.1193,  0.6189, -0.5497]],
1910

1911
                [[ 0.4526,  1.2526, -0.3285],
1912
                 [-0.7988,  0.7175, -0.9701],
1913
                 [ 0.2634, -0.9255, -0.3459]]])
1914
        >>> pivots
1915
        tensor([[ 3,  3,  3],
1916
                [ 3,  3,  3]], dtype=torch.int32)
1917
        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
1918
        >>> if info.nonzero().size(0) == 0:
1919
        ...     print('LU factorization succeeded for all samples!')
1920
        LU factorization succeeded for all samples!
1921
    """
1922
    # If get_infos is True, then we don't need to check for errors and vice versa
1923
    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
1924

1925
if TYPE_CHECKING:
1926
    _ListOrSeq = Sequence[Tensor]
1927
else:
1928
    _ListOrSeq = List[Tensor]
1929

1930

1931
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
1932
    get_infos_int = 1 if get_infos else 0
1933
    if out_len - get_infos_int != 2:
1934
        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
1935
    if not isinstance(out, (tuple, list)):
1936
        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
1937

1938

1939
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
1940
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
1941
    if has_torch_function_unary(A):
1942
        return handle_torch_function(
1943
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
1944
    result = _lu_impl(A, pivot, get_infos, out)
1945
    if out is not None:
1946
        _check_list_size(len(out), get_infos, out)
1947
        for i in range(len(out)):
1948
            out[i].resize_as_(result[i]).copy_(result[i])
1949
        return out
1950
    else:
1951
        return result  # A_LU, pivots, infos
1952

1953

1954
def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
1955
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
1956
    # need to check for torch_function here so that we exit if
1957
    if has_torch_function_unary(A):
1958
        return handle_torch_function(
1959
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
1960
    result = _lu_impl(A, pivot, get_infos, out)
1961
    if out is not None:
1962
        _check_list_size(len(out), get_infos, out)
1963
        for i in range(len(out)):
1964
            out[i].resize_as_(result[i]).copy_(result[i])
1965
        return out
1966
    else:
1967
        return result[0], result[1]  # A_LU, pivots
1968

1969
# The return type of lu depends on `get_infos`, so in order to resolve the output type
1970
# of lu in TorchScript we need to statically know the value of `get_infos`
1971
lu = boolean_dispatch(
1972
    arg_name='get_infos',
1973
    arg_index=2,
1974
    default=False,
1975
    if_true=_lu_with_infos,
1976
    if_false=_lu_no_infos,
1977
    module_name=__name__,
1978
    func_name='lu')
1979
lu.__doc__ = _lu_impl.__doc__
1980

1981

1982
def align_tensors(*tensors):
1983
    raise RuntimeError('`align_tensors` not yet implemented.')
1984

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.