from typing import (
    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
)

from torch._C import _add_docstr
import torch.nn.functional as F

from ._lowrank import svd_lowrank, pca_lowrank
from .overrides import (
    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
    handle_torch_function)
from ._jit_internal import boolean_dispatch
from ._jit_internal import _overload as overload
def broadcast_tensors(*tensors):
    r"""broadcast_tensors(*tensors) -> List of Tensors

    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.

    Args:
        *tensors: any number of tensors of the same type

    .. warning::

        More than one element of a broadcasted tensor may refer to a single
        memory location. As a result, in-place operations (especially ones that
        are vectorized) may result in incorrect behavior. If you need to write
        to the tensors, please clone them first.

    Example::

        >>> x = torch.arange(3).view(1, 3)
        >>> y = torch.arange(2).view(2, 1)
        >>> a, b = torch.broadcast_tensors(x, y)
    """
    if has_torch_function(tensors):
        return handle_torch_function(broadcast_tensors, tensors, *tensors)
    return _VF.broadcast_tensors(tensors)
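

# Illustrative sketch, not part of the original module (the helper name
# `_example_broadcast_tensors` is hypothetical): broadcasting a (1, 3) tensor
# against a (2, 1) tensor yields two (2, 3) views, and the expanded views
# alias memory, which is why the docstring warns against in-place writes.
def _example_broadcast_tensors():
    x = torch.arange(3).view(1, 3)
    y = torch.arange(2).view(2, 1)
    a, b = torch.broadcast_tensors(x, y)
    assert a.shape == b.shape == torch.Size([2, 3])
    # The broadcast rows of `a` are views of one storage: stride 0 along the
    # expanded dimension means no copy was made.
    assert a.stride(0) == 0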
def broadcast_shapes(*shapes):
    r"""broadcast_shapes(*shapes) -> Size

    Similar to :func:`broadcast_tensors` but for shapes.

    This is equivalent to
    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but avoids the need to create intermediate tensors. This is useful for
    broadcasting tensors of common batch shape but different rightmost shape,
    e.g. to broadcast mean vectors with covariance matrices.

    Example::

        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
        torch.Size([1, 3, 2])

    Args:
        \*shapes (torch.Size): Shapes of tensors.

    Returns:
        shape (torch.Size): A shape compatible with all input shapes.

    Raises:
        RuntimeError: If shapes are incompatible.
    """
    if not torch.jit.is_tracing():
        max_len = 0
        for shape in shapes:
            if isinstance(shape, (int, torch.SymInt)):
                max_len = max(max_len, 1)
            elif isinstance(shape, (tuple, list)):
                max_len = max(max_len, len(shape))
        result = [1] * max_len

        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        for shape in shapes:
            if isinstance(shape, (int, torch.SymInt)):
                shape = (shape,)
            if isinstance(shape, (tuple, list)):
                for i in range(-1, -1 - len(shape), -1):
                    if shape[i] < 0:
                        raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
                        continue
                    if result[i] != 1:
                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
                    result[i] = shape[i]
            else:
                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
        return torch.Size(result)
    else:
        # While tracing, fall back to broadcasting actual (scalar-expanded) tensors.
        with torch.no_grad():
            scalar = torch.zeros((), device="cpu")
            tensors = [scalar.expand(shape) for shape in shapes]
            tensors = broadcast_tensors(*tensors)
            return tensors[0].shape
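

# Illustrative sketch, not part of the original module (the helper name
# `_example_broadcast_shapes` is hypothetical): broadcast_shapes agrees with
# actually broadcasting empty tensors, but allocates nothing.
def _example_broadcast_shapes():
    shapes = [(2,), (3, 1), (1, 1, 1)]
    fast = torch.broadcast_shapes(*shapes)
    slow = torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape
    assert fast == slow == torch.Size([1, 3, 2])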
def split(
    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size_or_sections`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor.

    Example::

        >>> a = torch.arange(10).reshape(5, 2)
        >>> torch.split(a, 2)
        >>> torch.split(a, [1, 4])
    """
    if has_torch_function_unary(tensor):
        return handle_torch_function(
            split, (tensor,), tensor, split_size_or_sections, dim=dim)
    return tensor.split(split_size_or_sections, dim)
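

# Illustrative sketch, not part of the original module (the helper name
# `_example_split` is hypothetical): an int chunk size may leave a smaller
# trailing chunk, while a list of sizes must sum to the dimension.
def _example_split():
    a = torch.arange(10).reshape(5, 2)
    chunks = torch.split(a, 2)            # shapes (2, 2), (2, 2), (1, 2)
    assert [c.shape[0] for c in chunks] == [2, 2, 1]
    parts = torch.split(a, [1, 4])        # 1 + 4 == 5 along dim 0
    assert [p.shape[0] for p in parts] == [1, 4]
    # Chunks are views: writing through a chunk mutates `a`.
    chunks[0][0, 0] = -1
    assert a[0, 0] == -1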
def einsum(*args: Any) -> Tensor:
    r"""einsum(equation, *operands) -> Tensor

    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
    based on the Einstein summation convention.

    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
    with some subscript and define which subscripts are part of the output. The output is then computed by summing
    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).

    Equation:

        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
        comma (','), e.g. `'ij,jk'` specifies subscripts for two 2D operands. The dimensions labeled with the same subscript
        must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
        repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
        must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
        appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
        The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
        on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.

        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
        followed by the subscripts for the output. For instance, the following equation computes the transpose of a
        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
        at most once for the output.

        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
        before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
        batch matrix multiplication `'...ij,...jk'`.

        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.

    .. note::

        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.

    .. note::

        This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to
        consume less memory by optimizing contraction order. This optimization occurs when there are at least three
        inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem,
        thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available,
        the default order is to contract from left to right.

        To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path
        calculation: `torch.backends.opt_einsum.enabled = False`

        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
        `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and
        'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
        the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).

    .. note::

        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
        subscripts for each operand are specified by sublists, lists of integers in the range [0, 52). These sublists
        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
        subscripts, e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out])`. Python's `Ellipsis` object
        may be provided in a sublist to enable broadcasting as described in the Equation section above.
    Args:
        equation (str): The subscripts for the Einstein summation.
        operands (List[Tensor]): The tensors to compute the Einstein summation of.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> torch.einsum('ii', torch.randn(4, 4))

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> torch.einsum('ii->i', torch.randn(4, 4))
        tensor([-0.1034,  0.7952, -0.2433,  0.4545])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> x = torch.randn(5)
        >>> y = torch.randn(4)
        >>> torch.einsum('i,j->ij', x, y)
        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
                [-0.3744,  0.9381,  1.2685, -1.6070],
                [ 0.7208, -1.8058, -2.4419,  3.0936],
                [ 0.1713, -0.4291, -0.5802,  0.7350],
                [ 0.5704, -1.4290, -1.9323,  2.4480]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # batch matrix multiplication
        >>> As = torch.randn(3, 2, 5)
        >>> Bs = torch.randn(3, 5, 4)
        >>> torch.einsum('bij,bjk->bik', As, Bs)
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                 [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                 [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                 [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # with sublist format and ellipsis
        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                 [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                 [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                 [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> A = torch.randn(2, 3, 4, 5)
        >>> torch.einsum('...ij->...ji', A).shape
        torch.Size([2, 3, 5, 4])

        >>> # equivalent to torch.nn.functional.bilinear
        >>> A = torch.randn(3, 5, 4)
        >>> l = torch.randn(2, 5)
        >>> r = torch.randn(2, 4)
        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
        tensor([[-0.3430, -5.2405,  0.4494],
                [ 0.3311,  5.5201, -3.0356]])
    """
    import torch.backends.opt_einsum as opt_einsum

    if len(args) < 2:
        raise ValueError('einsum(): must specify the equation string and at least one operand, '
                         'or at least one operand and its subscripts list')

    equation = None
    operands = None

    if isinstance(args[0], torch.Tensor):
        # Convert the sublist format to the equation-string format: map each
        # integer subscript to a letter and gather the tensors into a list.
        def parse_subscript(n: int) -> str:
            if n == Ellipsis:
                return '...'
            if n >= 0 and n < 26:
                return chr(ord('A') + n)
            if n >= 26 and n < 52:
                return chr(ord('a') + n - 26)
            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')

        # Parse subscripts for input operands
        equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])

        # Parse optional output subscripts (provided when the number of arguments is odd)
        if len(args) % 2 == 1:
            equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
            operands = args[:-1:2]
        else:
            operands = args[::2]
    else:
        equation = args[0]
        operands = args[1:]

    if has_torch_function(operands):
        return handle_torch_function(einsum, operands, equation, *operands)

    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        # The old interface of passing the operands as one list argument.
        _operands = operands[0]
        return einsum(equation, *_operands)

    if len(operands) <= 2 or not opt_einsum.enabled:
        # The contraction order only matters for three or more operands, and
        # opt_einsum may be disabled explicitly.
        return _VF.einsum(equation, operands)

    path = None
    if opt_einsum.is_available():
        _opt_einsum = opt_einsum.get_opt_einsum()
        tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
        # Flatten the list of contraction pairs for the C++ kernel.
        path = [item for pair in tupled_path for item in pair]
    return _VF.einsum(equation, operands, path=path)
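

# Illustrative sketch, not part of the original module (the helper name
# `_example_einsum` is hypothetical): the equation 'ij,jk->ik' sums over the
# shared subscript j, which is exactly a matrix multiply; the sublist format
# spells the same contraction with integer labels.
def _example_einsum():
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    eq = torch.einsum('ij,jk->ik', A, B)
    sub = torch.einsum(A, [0, 1], B, [1, 2], [0, 2])
    assert torch.allclose(eq, A @ B, atol=1e-6)
    assert torch.allclose(sub, A @ B, atol=1e-6)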
if TYPE_CHECKING:
    @overload
    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        return _meshgrid(*tensors, indexing=indexing)
else:
    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.

        This is helpful when you want to visualize data over some
        range of inputs. See below for a plotting example.

        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
        the output :math:`G_i` is constructed by expanding :math:`T_i`
        to the result shape.

        .. note::

            0D inputs are treated equivalently to 1D inputs of a
            single element.

        .. warning::

            `torch.meshgrid(*tensors)` currently has the same behavior
            as calling `numpy.meshgrid(*arrays, indexing='ij')`.

            In the future `torch.meshgrid` will transition to
            `indexing='xy'` as the default.

            https://github.com/pytorch/pytorch/issues/50276 tracks
            this issue with the goal of migrating to NumPy's behavior.

        .. seealso::

            :func:`torch.cartesian_prod` has the same effect but it
            collects the data in a tensor of vectors.

        Args:
            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
                treated as tensors of size :math:`(1,)` automatically

            indexing (str, optional): the indexing mode, either "xy"
                or "ij", defaults to "ij". See warning for future changes.

                If "xy" is selected, the first dimension corresponds
                to the cardinality of the second input and the second
                dimension corresponds to the cardinality of the first
                input.

                If "ij" is selected, the dimensions are in the same
                order as the cardinality of the inputs.

        Returns:
            seq (sequence of Tensors): If the input has :math:`N`
            tensors of size :math:`S_0 \ldots S_{N-1}`, then the
            output will also have :math:`N` tensors, where each tensor
            is of shape :math:`(S_0, ..., S_{N-1})`.

        Example::

            >>> x = torch.tensor([1, 2, 3])
            >>> y = torch.tensor([4, 5, 6])

            Observe the element-wise pairings across the grid, (1, 4),
            (1, 5), ..., (3, 6). This is the same thing as the
            cartesian product.

            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

            This correspondence can be seen when these grids are
            stacked properly.

            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
            ...             torch.cartesian_prod(x, y))
            True

            `torch.meshgrid` is commonly used to produce a grid for
            plotting.

            >>> # xdoctest: +REQUIRES(module:matplotlib)
            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
            >>> import matplotlib.pyplot as plt
            >>> xs = torch.linspace(-5, 5, steps=100)
            >>> ys = torch.linspace(-5, 5, steps=100)
            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
            >>> z = torch.sin(torch.sqrt(x * x + y * y))
            >>> ax = plt.axes(projection='3d')
            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())

        .. image:: ../_static/img/meshgrid.png
        """
        return _meshgrid(*tensors, indexing=indexing)
def _meshgrid(*tensors, indexing: Optional[str]):
    if has_torch_function(tensors):
        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
        # The old interface of passing the operands as one list argument.
        tensors = tensors[0]
    # Continue to allow calls without the indexing kwarg for forward
    # compatibility reasons.
    kwargs = {} if indexing is None else {'indexing': indexing}
    return _VF.meshgrid(tensors, **kwargs)
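

# Illustrative sketch, not part of the original module (the helper name
# `_example_meshgrid_indexing` is hypothetical): the "xy" grids are exactly
# the transposed "ij" grids, which is the whole difference between the modes.
def _example_meshgrid_indexing():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5])
    gi, gj = torch.meshgrid(x, y, indexing='ij')   # shapes (3, 2)
    gx, gy = torch.meshgrid(x, y, indexing='xy')   # shapes (2, 3)
    assert torch.equal(gx, gi.t()) and torch.equal(gy, gj.t())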
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         win_length: Optional[int] = None, window: Optional[Tensor] = None,
         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
         onesided: Optional[bool] = None,
         return_complex: Optional[bool] = None) -> Tensor:
    r"""Short-time Fourier transform (STFT).

    .. warning::

        From version 1.8.0, :attr:`return_complex` must always be given
        explicitly for real inputs and `return_complex=False` has been
        deprecated. Strongly prefer `return_complex=True` as in a future
        PyTorch release, this function will only return complex tensors.

        Note that :func:`torch.view_as_real` can be used to recover a real
        tensor with an extra last dimension for real and imaginary components.

    .. warning::

        From version 2.1, a warning will be provided if a :attr:`window` is
        not specified. In a future release, this attribute will be required.
        Not providing a window currently defaults to using a rectangular window,
        which may result in undesirable artifacts. Consider using tapered windows,
        such as :func:`torch.hann_window`.

    The STFT computes the Fourier transform of short overlapping windows of the
    input, giving the frequency components of the signal as they change over
    time. The interface of this function is modeled after (but *not* a drop-in
    replacement for) the librosa_ stft function.

    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
    Ignoring the optional batch dimension, this method computes the following
    expression:

    .. math::
        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),

    where :math:`m` is the index of the sliding window, and :math:`\omega` is
    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.

    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
      sequences.

    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
      ``floor(n_fft / 4)``.

    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
      :attr:`n_fft`.

    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
      treated as if having :math:`1` everywhere in the window. If
      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
      both sides to length :attr:`n_fft` before being applied.

    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
      both sides so that the :math:`t`-th frame is centered at time
      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
      begins at time :math:`t \times \text{hop\_length}`.

    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
      all available options. Default is ``"reflect"``.

    * If :attr:`onesided` is ``True`` (default for real input), only values for
      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
      the real-to-complex Fourier transform satisfies the conjugate symmetry,
      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
      Note if the input or window tensors are complex, then :attr:`onesided`
      output is not possible.

    * If :attr:`normalized` is ``True`` (default is ``False``), the function
      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.

    * If :attr:`return_complex` is ``True`` (default if input is complex), the
      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
      the output is a ``input.dim() + 2`` dimensional real tensor where the last
      dimension represents the real and imaginary components.

    Returns either a complex tensor of size :math:`(* \times N \times T)` if
    :attr:`return_complex` is ``True``, or a real tensor of size :math:`(* \times N
    \times T \times 2)`, where :math:`*` is the optional batch size of
    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
    and :math:`T` is the total number of frames used.
    .. warning::
        This function changed signature at version 0.4.1. Calling with the
        previous signature may cause error or return incorrect result.

    Args:
        input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
            batch dimension
        n_fft (int): size of Fourier transform
        hop_length (int, optional): the distance between neighboring sliding window
            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
        win_length (int, optional): the size of window frame and STFT filter.
            Default: ``None`` (treated as equal to :attr:`n_fft`)
        window (Tensor, optional): the optional window function.
            Shape must be 1d and `<= n_fft`
            Default: ``None`` (treated as window of all :math:`1` s)
        center (bool, optional): whether to pad :attr:`input` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (str, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        normalized (bool, optional): controls whether to return the normalized STFT results
            Default: ``False``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy for real inputs.
            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
        return_complex (bool, optional): whether to return a complex tensor, or
            a real tensor with an extra last dimension for the real and
            imaginary components.

            .. versionchanged:: 2.0
                ``return_complex`` is now a required argument for real inputs,
                as the default is being transitioned to ``True``.

            .. deprecated:: 2.0
                ``return_complex=False`` is deprecated, instead use ``return_complex=True``
                Note that calling :func:`torch.view_as_real` on the output will
                recover the deprecated output format.

    Returns:
        Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
            - `B?` is an optional batch dimension from the input.
            - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
              `onesided=True`, or otherwise `n_fft`.
            - `T` is the number of frames, `1 + L // hop_length`
              for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
            - `C?` is an optional length-2 dimension of real and imaginary
              components, present when `return_complex=False`.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
            onesided=onesided, return_complex=return_complex)
    # When centering, pad the input on both sides so that each frame is
    # centered at t * hop_length; the temporary view keeps F.pad happy for
    # 1-D signals.
    if center:
        signal_dim = input.dim()
        extended_shape = [1] * (3 - signal_dim) + list(input.size())
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
        input = input.view(input.shape[-signal_dim:])
    return _VF.stft(input, n_fft, hop_length, win_length, window,
                    normalized, onesided, return_complex)
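

# Illustrative sketch, not part of the original module (the helper name
# `_example_stft_shape` is hypothetical): for a centered STFT of a length-L
# real signal, the output has (n_fft // 2) + 1 frequency bins (the onesided
# default) and 1 + L // hop_length frames, as described in the docstring.
def _example_stft_shape():
    L, n_fft, hop = 1000, 400, 100
    signal = torch.randn(L)
    window = torch.hann_window(n_fft)
    spec = torch.stft(signal, n_fft, hop_length=hop, window=window,
                      return_complex=True)
    assert spec.shape == (n_fft // 2 + 1, 1 + L // hop)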
istft = _add_docstr(
    torch.istft,
    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
    r"""
Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.

.. warning::
    From version 2.1, a warning will be provided if a :attr:`window` is
    not specified. In a future release, this attribute will be required.
    Please provide the same window used in the stft call.

It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition
(nonzero overlap add).

Importantly, the parameters :attr:`window` and :attr:`center` must be chosen so that the envelope
created by the summation of all the windows is never zero at any point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \neq 0`.

Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
``istft`` will pad zeros to the end of the returned signal.

If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
calculated without additional information.

Example: Suppose the last window is:
``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``

The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same, which prevents the calculation
of right padding. These additional values could be zeros or a reflection of the signal so providing
:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
(some loss of signal).

[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
Args:
    input (Tensor): The input tensor. Expected to be in the format of the :func:`~torch.stft`
        output. That is, a complex tensor of shape `(B?, N, T)` where

        - `B?` is an optional batch dimension
        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
          for onesided input, or otherwise `n_fft`.
        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
          or `1 + (length - n_fft) // hop_length` otherwise.

        .. versionchanged:: 2.0
            Real datatype inputs are no longer supported. Input must now have a
            complex datatype, as returned by ``stft(..., return_complex=True)``.
    n_fft (int): Size of Fourier transform
    hop_length (Optional[int]): The distance between neighboring sliding window frames.
        (Default: ``n_fft // 4``)
    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
    window (Optional[torch.Tensor]): The optional window function.
        Shape must be 1d and `<= n_fft`
        (Default: ``torch.ones(win_length)``)
    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
        centered at time :math:`t \times \text{hop\_length}`.
        (Default: ``True``)
    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if `n_fft != fft_size` in the input size)
    length (Optional[int]): The amount to trim the signal by (i.e. the
        original signal length). Defaults to `(T - 1) * hop_length` for
        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
        is the number of input frames.
    return_complex (Optional[bool]):
        Whether the output should be complex, or if the input should be
        assumed to derive from a real signal and window.
        Note that this is incompatible with ``onesided=True``.
        (Default: ``False``)

Returns:
    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
        `B?` is an optional batch dimension from the input tensor.
""")
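

# Illustrative sketch, not part of the original module (the helper name
# `_example_istft_round_trip` is hypothetical): with a tapered window that
# satisfies the NOLA condition, istft(stft(x)) reconstructs x up to numerical
# error; passing `length` trims the result to the original length exactly.
def _example_istft_round_trip():
    x = torch.randn(1000)
    n_fft, hop = 400, 100
    window = torch.hann_window(n_fft)
    spec = torch.stft(x, n_fft, hop_length=hop, window=window,
                      return_complex=True)
    y = torch.istft(spec, n_fft, hop_length=hop, window=window,
                    length=x.numel())
    assert torch.allclose(x, y, atol=1e-4)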
if TYPE_CHECKING:
    # The JIT doesn't understand Union, nor torch.dtype here
    _unique_impl_out = Any
else:
    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
def _unique_impl(input: Tensor, sorted: bool = True,
                 return_inverse: bool = False, return_counts: bool = False,
                 dim: Optional[int] = None) -> _unique_impl_out:
    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]

    Returns the unique elements of the input tensor.

    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
        this function also eliminates non-consecutive duplicate values.

    .. note:: Currently in the CUDA implementation and the CPU implementation,
        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
        :func:`torch.unique_consecutive` which avoids the sorting.

    Args:
        input (Tensor): the input tensor
        sorted (bool): Whether to sort the unique elements in ascending order
            before returning as output.
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int, optional): the dimension to operate upon. If ``None``, the
            unique of the flattened input is returned. Otherwise, each of the
            tensors indexed by the given dimension is treated as one of the
            elements to apply the unique operation upon. See examples for more
            details. Default: ``None``
    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.
    Example::

        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)

        >>> a = torch.tensor([

        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
        >>> # each other, so one of them will be removed.
        >>> (a[0, :, :] == a[2, :, :]).all()

        >>> a_unique_dim0 = torch.unique(a, dim=0)
        tensor([[[0, 0, 1, 1],

        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
        >>> # `a_unique_dim0`:
        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()

        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()

        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
        >>> # them will be removed.
        >>> (a[:, 0, :] == a[:, 1, :]).all()

        >>> torch.unique(a, dim=1)
        tensor([[[0, 0, 1, 1],

        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
        >>> # sub-tensors will be removed.
        >>> (a[:, :, 0] == a[:, :, 1]).all()

        >>> (a[:, :, 2] == a[:, :, 3]).all()

        >>> torch.unique(a, dim=2)
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
            return_counts=return_counts, dim=dim)

    if dim is not None:
        output, inverse_indices, counts = _VF.unique_dim(
            input,
            dim,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    else:
        output, inverse_indices, counts = torch._unique2(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    return output, inverse_indices, counts
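

# Illustrative sketch, not part of the original module (the helper name
# `_example_unique` is hypothetical): inverse_indices maps every original
# element back to its slot in the unique list, so the input can be
# reconstructed from the output.
def _example_unique():
    x = torch.tensor([1, 3, 2, 3])
    output, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True)
    assert torch.equal(output, torch.tensor([1, 2, 3]))
    assert torch.equal(output[inverse], x)   # reconstruction
    assert torch.equal(counts, torch.tensor([1, 1, 2]))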
def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
                             return_counts: bool = False,
                             dim: Optional[int] = None) -> _unique_impl_out:
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    .. note:: This function is different from :func:`torch.unique` in the sense that this function
        only eliminates consecutive duplicate values. Its semantics are similar to `std::unique`
        in C++.

    Args:
        input (Tensor): the input tensor
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int): the dimension to apply unique. If ``None``, the unique of the
            flattened input is returned. default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
        >>> output = torch.unique_consecutive(x)
        >>> output
        tensor([1, 2, 3, 1, 2])

        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> inverse_indices
        tensor([0, 0, 1, 1, 2, 3, 3, 4])

        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> counts
        tensor([2, 2, 1, 2, 1])
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique_consecutive, (input,), input, return_inverse=return_inverse,
            return_counts=return_counts, dim=dim)
    output, inverse_indices, counts = _VF.unique_consecutive(
        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
    return output, inverse_indices, counts
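

# Illustrative sketch, not part of the original module (the helper name
# `_example_unique_consecutive` is hypothetical): only *consecutive*
# duplicates collapse, so the value 1 appears twice in the output below.
def _example_unique_consecutive():
    x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
    output, counts = torch.unique_consecutive(x, return_counts=True)
    assert torch.equal(output, torch.tensor([1, 2, 3, 1, 2]))
    assert torch.equal(counts, torch.tensor([2, 2, 1, 2, 1]))
    assert counts.sum() == x.numel()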
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return output, counts


def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return output


def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return output, inverse_indices
_return_inverse_false = boolean_dispatch(
    arg_name='return_counts',
    arg_index=3,
    default=False,
    if_true=_return_counts,
    if_false=_return_output,
    module_name=__name__,
    func_name='unique')

_return_inverse_true = boolean_dispatch(
    arg_name='return_counts',
    arg_index=3,
    default=False,
    if_true=_unique_impl,
    if_false=_return_inverse,
    module_name=__name__,
    func_name='unique')

# The return type of unique depends on `return_inverse` and `return_counts`, so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters.
unique = boolean_dispatch(
    arg_name='return_inverse',
    arg_index=2,
    default=False,
    if_true=_return_inverse_true,
    if_false=_return_inverse_false,
    module_name=__name__,
    func_name='unique')
unique.__doc__ = _unique_impl.__doc__
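

# Illustrative sketch, not part of the original module (the helper name
# `_example_unique_arity` is hypothetical): boolean_dispatch makes the
# runtime return arity of torch.unique follow the two boolean flags, which
# is exactly what the dispatch table above encodes.
def _example_unique_arity():
    x = torch.tensor([2, 1, 2])
    out = torch.unique(x)                                  # Tensor
    out_i = torch.unique(x, return_inverse=True)           # 2-tuple
    out_ic = torch.unique(x, return_inverse=True,
                          return_counts=True)              # 3-tuple
    assert isinstance(out, torch.Tensor)
    assert len(out_i) == 2 and len(out_ic) == 3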
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    return output, counts


def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    return output


def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
    return output, inverse_indices


_consecutive_return_inverse_false = boolean_dispatch(
    arg_name='return_counts',
    arg_index=2,
    default=False,
    if_true=_consecutive_return_counts,
    if_false=_consecutive_return_output,
    module_name=__name__,
    func_name='unique_consecutive')

_consecutive_return_inverse_true = boolean_dispatch(
    arg_name='return_counts',
    arg_index=2,
    default=False,
    if_true=_unique_consecutive_impl,
    if_false=_consecutive_return_inverse,
    module_name=__name__,
    func_name='unique_consecutive')

unique_consecutive = boolean_dispatch(
    arg_name='return_inverse',
    arg_index=1,
    default=False,
    if_true=_consecutive_return_inverse_true,
    if_false=_consecutive_return_inverse_false,
    module_name=__name__,
    func_name='unique_consecutive')
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
@overload
def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
    ...

@overload
def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):
    ...

@overload
def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):
    ...

@overload
def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):
    ...

def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
        a (Tensor): Left tensor to contract
        b (Tensor): Right tensor to contract
        dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
            contract or explicit lists of dimensions for :attr:`a` and
            :attr:`b` respectively

    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
    respectively, :func:`~torch.tensordot` computes

    .. math::
        r_{i_0,...,i_{m-d}, i_d,...,i_n}
          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :attr:`b`. The sizes
    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
    dimensions.
    Examples::

        >>> a = torch.arange(60.).reshape(3, 4, 5)
        >>> b = torch.arange(24.).reshape(4, 3, 2)
        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
        tensor([[4400., 4730.],
                [4532., 4874.],
                [4664., 5018.],
                [4796., 5162.],
                [4928., 5306.]])

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> a = torch.randn(3, 4, 5, device='cuda')
        >>> b = torch.randn(4, 5, 6, device='cuda')
        >>> c = torch.tensordot(a, b, dims=2).cpu()
        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

        >>> a = torch.randn(3, 5, 4, 6)
        >>> b = torch.randn(6, 4, 5, 3)
        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
        tensor([[  7.7193,  -2.4867, -10.3204],
                [  1.5513, -14.4737,  -6.5113],
                [ -0.2850,   4.2573,  -3.5997]])
    """
    if has_torch_function_variadic(a, b):
        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)

    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
        raise RuntimeError("tensordot expects dims to be int or "
                           + "Tuple[List[int], List[int]] or "
                           + "List[List[int]] containing two lists, but got "
                           + f"dims={dims}")

    dims_a: List[int] = []
    dims_b: List[int] = []

    if isinstance(dims, (tuple, list)):
        dims_a, dims_b = dims

    if isinstance(dims, torch.Tensor):
        num_elements = dims.numel()
        if num_elements > 1:
            assert dims.size()[0] == 2
            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
        else:
            dims_val = int(dims.item())
            if dims_val < 0:
                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
            dims_a = list(range(-dims_val, 0))
            dims_b = list(range(dims_val))

    if isinstance(dims, (int, torch.SymInt)):
        if dims < 0:
            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
        if dims > min(a.dim(), b.dim()):
            raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
        dims_a = list(range(-dims, 0))
        dims_b = list(range(dims))

    if out is None:
        return _VF.tensordot(a, b, dims_a, dims_b)
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)
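

# Illustrative sketch, not part of the original module (the helper name
# `_example_tensordot` is hypothetical): dims=2 contracts the last two
# dimensions of `a` with the first two of `b`, which matches a reshaped
# matrix multiply.
def _example_tensordot():
    a = torch.randn(3, 4, 5)
    b = torch.randn(4, 5, 6)
    c = torch.tensordot(a, b, dims=2)
    ref = a.reshape(3, 20) @ b.reshape(20, 6)
    assert c.shape == (3, 6)
    assert torch.allclose(c, ref, atol=1e-5)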
def cartesian_prod(*tensors: Tensor) -> Tensor:
    """Does the Cartesian product of the given sequence of tensors. The behavior is similar to
    Python's `itertools.product`.

    Args:
        *tensors: any number of 1 dimensional tensors.

    Returns:
        Tensor: A tensor equivalent to converting all the input tensors into lists,
        doing `itertools.product` on these lists, and finally converting the resulting list
        into a tensor.

    Example::

        >>> import itertools
        >>> a = [1, 2, 3]
        >>> b = [4, 5]
        >>> list(itertools.product(a, b))
        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
        >>> tensor_a = torch.tensor(a)
        >>> tensor_b = torch.tensor(b)
        >>> torch.cartesian_prod(tensor_a, tensor_b)
        tensor([[1, 4],
                [1, 5],
                [2, 4],
                [2, 5],
                [3, 4],
                [3, 5]])
    """
    if has_torch_function(tensors):
        return handle_torch_function(cartesian_prod, tensors, *tensors)
    return _VF.cartesian_prod(tensors)
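

# Illustrative sketch, not part of the original module (the helper name
# `_example_cartesian_prod` is hypothetical): the result enumerates rows in
# the same order as itertools.product.
def _example_cartesian_prod():
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5])
    prod = torch.cartesian_prod(a, b)
    assert prod.shape == (6, 2)           # 3 * 2 rows, one column per input
    assert torch.equal(prod[0], torch.tensor([1, 4]))
    assert torch.equal(prod[-1], torch.tensor([3, 5]))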
def block_diag(*tensors):
    """Create a block diagonal matrix from provided tensors.

    Args:
        *tensors: One or more tensors with 0, 1, or 2 dimensions.

    Returns:
        Tensor: A 2 dimensional tensor with all the input tensors arranged in
        order such that their upper left and lower right corners are
        diagonally adjacent. All other elements are set to 0.

    Example::

        >>> A = torch.tensor([[0, 1], [1, 0]])
        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
        >>> C = torch.tensor(7)
        >>> D = torch.tensor([1, 2, 3])
        >>> E = torch.tensor([[4], [5], [6]])
        >>> torch.block_diag(A, B, C, D, E)
        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
    """
    if has_torch_function(tensors):
        return handle_torch_function(block_diag, tensors, *tensors)
    return torch._C._VariableFunctions.block_diag(tensors)
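

# Illustrative sketch, not part of the original module (the helper name
# `_example_block_diag` is hypothetical): the output shape is the sum of the
# per-input row and column counts, with 0-D and 1-D inputs counting as
# 1 x 1 and 1 x n blocks.
def _example_block_diag():
    A = torch.ones(2, 2)
    B = torch.ones(2, 3)
    C = torch.tensor(7.)                   # 0-D -> 1 x 1 block
    out = torch.block_diag(A, B, C)
    assert out.shape == (2 + 2 + 1, 2 + 3 + 1)
    assert out[4, 5] == 7 and out[0, 4] == 0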
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
    r"""Computes the batched p-norm distance between each pair of the two collections of row vectors.

    Args:
        x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
        x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
        p: p value for the p-norm distance to calculate between each vector pair
            :math:`\in [0, \infty]`.
        compute_mode:
            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
            euclidean distance (p = 2) if P > 25 or R > 25
            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            Default: use_mm_for_euclid_dist_if_necessary.

    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
    output will have shape :math:`B \times P \times R`.

    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`
    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.

    Example:

        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])
        >>> a
        tensor([[ 0.9041,  0.0196],
                [-0.3108, -2.4423],
                [-0.4821,  1.0590]])
        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])
        >>> b
        tensor([[-2.1763, -0.4713],
                [-0.6986,  1.3702]])
        >>> torch.cdist(a, b, p=2)
        tensor([[3.1193, 2.0959],
                [2.7138, 3.8322],
                [2.2830, 0.3791]])
    """
    if has_torch_function_variadic(x1, x2):
        return handle_torch_function(
            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
        return _VF.cdist(x1, x2, p, None)
    elif compute_mode == 'use_mm_for_euclid_dist':
        return _VF.cdist(x1, x2, p, 1)
    elif compute_mode == 'donot_use_mm_for_euclid_dist':
        return _VF.cdist(x1, x2, p, 2)
    else:
        raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
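

# Illustrative sketch, not part of the original module (the helper name
# `_example_cdist` is hypothetical): cdist with p=2 matches the pairwise
# Euclidean distances computed directly via torch.norm.
def _example_cdist():
    x1 = torch.randn(3, 4)
    x2 = torch.randn(5, 4)
    d = torch.cdist(x1, x2, p=2)
    ref = torch.norm(x1[:, None, :] - x2[None, :, :], dim=-1)
    assert d.shape == (3, 5)
    assert torch.allclose(d, ref, atol=1e-5)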
def atleast_1d(*tensors):
    r"""
    Returns a 1-dimensional view of each input tensor with zero dimensions.
    Input tensors with one or more dimensions are returned as-is.

    Args:
        input (Tensor or list of Tensors)

    Returns:
        output (Tensor or tuple of Tensors)

    Example::

        >>> x = torch.arange(2)
        >>> torch.atleast_1d(x)
        tensor([0, 1])
        >>> x = torch.tensor(1.)
        >>> torch.atleast_1d(x)
        tensor([1.])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_1d((x, y))
        (tensor([0.5000]), tensor([1.]))
    """
    if has_torch_function(tensors):
        return handle_torch_function(atleast_1d, tensors, *tensors)
    if len(tensors) == 1:
        tensors = tensors[0]
    return _VF.atleast_1d(tensors)
def atleast_2d(*tensors):
    r"""
    Returns a 2-dimensional view of each input tensor with fewer than two dimensions.
    Input tensors with two or more dimensions are returned as-is.

    Args:
        input (Tensor or list of Tensors)

    Returns:
        output (Tensor or tuple of Tensors)

    Example::

        >>> x = torch.tensor(1.)
        >>> torch.atleast_2d(x)
        tensor([[1.]])
        >>> x = torch.arange(4).view(2, 2)
        >>> torch.atleast_2d(x)
        tensor([[0, 1],
                [2, 3]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_2d((x, y))
        (tensor([[0.5000]]), tensor([[1.]]))
    """
    if has_torch_function(tensors):
        return handle_torch_function(atleast_2d, tensors, *tensors)
    if len(tensors) == 1:
        tensors = tensors[0]
    return _VF.atleast_2d(tensors)
def atleast_3d(*tensors):
    r"""
    Returns a 3-dimensional view of each input tensor with fewer than three dimensions.
    Input tensors with three or more dimensions are returned as-is.

    Args:
        input (Tensor or list of Tensors)

    Returns:
        output (Tensor or tuple of Tensors)

    Example::

        >>> x = torch.tensor(0.5)
        >>> torch.atleast_3d(x)
        tensor([[[0.5000]]])
        >>> y = torch.arange(4).view(2, 2)
        >>> torch.atleast_3d(y)
        tensor([[[0],
                 [1]],

                [[2],
                 [3]]])
        >>> x = torch.tensor(1).view(1, 1, 1)
        >>> torch.atleast_3d(x)
        tensor([[[1]]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_3d((x, y))
        (tensor([[[0.5000]]]), tensor([[[1.]]]))
    """
    if has_torch_function(tensors):
        return handle_torch_function(atleast_3d, tensors, *tensors)
    if len(tensors) == 1:
        tensors = tensors[0]
    return _VF.atleast_3d(tensors)
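

# Illustrative sketch, not part of the original module (the helper name
# `_example_atleast` is hypothetical): each atleast_* call only adds leading
# (and, for atleast_3d, trailing) singleton dimensions until the minimum
# rank is reached; higher-rank inputs pass through unchanged.
def _example_atleast():
    s = torch.tensor(0.5)                            # 0-D scalar
    assert torch.atleast_1d(s).shape == (1,)
    assert torch.atleast_2d(s).shape == (1, 1)
    assert torch.atleast_3d(s).shape == (1, 1, 1)
    m = torch.zeros(2, 3)
    assert torch.atleast_3d(m).shape == (2, 3, 1)    # appended at the end
    assert torch.atleast_2d(m).shape == (2, 3)       # already 2-D: unchanged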
@overload
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    ...

@overload
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    ...

@overload
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    ...

@overload
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    ...

def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):
    r"""Returns the matrix norm or vector norm of a given tensor.

    .. warning::

        torch.norm is deprecated and may be removed in a future PyTorch release.
        Its documentation and behavior may be incorrect, and it is no longer
        actively maintained.

        Use :func:`torch.linalg.vector_norm` when computing vector norms and
        :func:`torch.linalg.matrix_norm` when computing matrix norms.
        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
        Note, however, the signature for these functions is slightly different than the
        signature for ``torch.norm``.
    Args:
        input (Tensor): The input tensor. Its data type must be either a floating
            point or complex type. For complex inputs, the norm is calculated using the
            absolute value of each element. If the input is complex and neither
            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
            be the corresponding floating point type (e.g. float if :attr:`input` is
            complexfloat).

        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
            The following norms can be calculated:

            ======  ==============  ==========================
            ord     matrix norm     vector norm
            ======  ==============  ==========================
            'fro'   Frobenius norm  --
            'nuc'   nuclear norm    --
            Number  --              sum(abs(x)**ord)**(1./ord)
            ======  ==============  ==========================

            The vector norm can be calculated across any number of dimensions.
            The corresponding dimensions of :attr:`input` are flattened into
            one dimension, and the norm is calculated on the flattened
            dimension.

            Frobenius norm produces the same result as ``p=2`` in all cases
            except when :attr:`dim` is a list of three or more dims, in which
            case Frobenius norm throws an error.

            Nuclear norm can only be calculated across exactly two dimensions.

        dim (int, tuple of ints, list of ints, optional):
            Specifies which dimension or dimensions of :attr:`input` to
            calculate the norm across. If :attr:`dim` is ``None``, the norm will
            be calculated across all dimensions of :attr:`input`. If the norm
            type indicated by :attr:`p` does not support the specified number of
            dimensions, an error will occur.
        keepdim (bool, optional): whether the output tensors have :attr:`dim`
            retained or not. Ignored if :attr:`dim` = ``None`` and
            :attr:`out` = ``None``. Default: ``False``
        out (Tensor, optional): the output tensor. Ignored if
            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
        dtype (:class:`torch.dtype`, optional): the desired data type of
            returned tensor. If specified, the input tensor is cast to
            :attr:`dtype` while performing the operation. Default: None.

    .. note::
        Even though ``p='fro'`` supports any number of dimensions, the true
        mathematical definition of Frobenius norm only applies to tensors with
        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
        aligns with the mathematical definition, since it can only be applied across
        exactly two dimensions.
    Example::

        >>> a = torch.arange(9, dtype=torch.float) - 4
        >>> b = a.reshape((3, 3))
        >>> torch.norm(a, float('inf'))
        tensor(4.)
        >>> torch.norm(b, float('inf'))
        tensor(4.)
        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]], dtype=torch.float)
        >>> torch.norm(c, dim=0)
        tensor([1.4142, 2.2361, 5.0000])
        >>> torch.norm(c, dim=1)
        tensor([3.7417, 4.2426])
        >>> torch.norm(c, p=1, dim=1)
        tensor([6., 6.])
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
        >>> torch.norm(d, dim=(1, 2))
        tensor([ 3.7417, 11.2250])
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        (tensor(3.7417), tensor(11.2250))
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
    # For strided tensors on these backends, dispatch directly to torch.linalg.
    if input.layout == torch.strided and input.device.type in \
            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
        if dim is not None:
            if isinstance(dim, (int, torch.SymInt)):
                _dim = [dim]
            else:
                _dim = dim
        else:
            _dim = None
        if isinstance(p, str):
            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
                if out is None:
                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
                else:
                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
            # Otherwise call matrix_norm; invalid arguments will raise there.
            if _dim is None:
                _dim = list(range(input.ndim))
            if out is None:
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
        else:
            _p = 2.0 if p is None else p
            if out is None:
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
    ndim = input.dim()
    # Catch the default case.
    if dim is None and out is None and dtype is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
        if not isinstance(p, str):
            _dim = [i for i in range(ndim)]
            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)
    if dim is not None:
        if isinstance(dim, (int, torch.SymInt)):
            _dim = [dim]
        else:
            _dim = dim
    else:
        _dim = None
    if isinstance(p, str):
        if p == "fro":
            if dtype is not None:
                raise ValueError("dtype argument is not supported in frobenius norm")
            if _dim is None:
                _dim = list(range(ndim))
            if out is None:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
            else:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
        elif p == "nuc":
            if dtype is not None:
                raise ValueError("dtype argument is not supported in nuclear norm")
            if _dim is None:
                if out is None:
                    return _VF.nuclear_norm(input, keepdim=keepdim)
                else:
                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
            else:
                if out is None:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
                else:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
    else:
        if _dim is None:
            _dim = list(range(ndim))
        if out is None:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim)
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)
        else:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)
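

# Illustrative sketch, not part of the original module (the helper name
# `_example_norm_migration` is hypothetical): the deprecated entry points
# map onto torch.linalg as the docstring recommends.
def _example_norm_migration():
    t = torch.randn(3, 4)
    assert torch.allclose(torch.norm(t), torch.linalg.matrix_norm(t, 'fro'))
    assert torch.allclose(torch.norm(t, dim=1),
                          torch.linalg.vector_norm(t, 2, dim=1))
    assert torch.allclose(torch.norm(t, p='nuc'),
                          torch.linalg.matrix_norm(t, 'nuc'))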
def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
    r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
    index into an arbitrary tensor of the specified shape.

    Args:
        indices (Tensor): An integer tensor containing indices into the
            flattened version of an arbitrary tensor of shape :attr:`shape`.
            All elements must be in the range ``[0, prod(shape) - 1]``.

        shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
            tensor. All elements must be non-negative.

    Returns:
        tuple of Tensors: Each ``i``-th tensor in the output corresponds with
        dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
        ``indices`` and contains one index into dimension ``i`` for each of the
        flat indices given by ``indices``.

    Example::

        >>> torch.unravel_index(torch.tensor(4), (3, 2))
        (tensor(2), tensor(0))

        >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
        (tensor([2, 0]), tensor([0, 1]))

        >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
        (tensor([0, 0, 1, 1, 2, 2]),
         tensor([0, 1, 0, 1, 0, 1]))

        >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
        (tensor([1, 5]),
         tensor([2, 6]),
         tensor([3, 7]),
         tensor([4, 8]))

        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
        (tensor([[1], [5]]),
         tensor([[2], [6]]),
         tensor([[3], [7]]),
         tensor([[4], [8]]))

        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
        (tensor([[12], [56]]),
         tensor([[34], [78]]))
    """
    if has_torch_function_unary(indices):
        return handle_torch_function(
            unravel_index, (indices,), indices, shape=shape)
    res_tensor = _unravel_index(indices, shape)
    return res_tensor.unbind(-1)
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
    torch._check_type(
        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")

    torch._check_type(
        isinstance(shape, (int, torch.SymInt, Sequence)),
        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")

    if isinstance(shape, (int, torch.SymInt)):
        shape = torch.Size([shape])
    else:
        for dim in shape:
            torch._check_type(
                isinstance(dim, (int, torch.SymInt)),
                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
        shape = torch.Size(shape)

    torch._check_value(
        all(dim >= 0 for dim in shape),
        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")

    # Row-major strides: coefs[i] = prod(shape[i+1:]), via a reversed cumulative product.
    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
    return indices.unsqueeze(-1).floor_divide(
        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
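

# Illustrative sketch, not part of the original module (the helper name
# `_example_unravel_index` is hypothetical): unravel_index inverts flat
# (row-major) indexing, so re-raveling the coordinates recovers the input.
def _example_unravel_index():
    shape = (3, 4, 5)
    flat = torch.tensor([0, 17, 59])
    coords = torch.unravel_index(flat, shape)
    assert len(coords) == 3               # one coordinate tensor per dim
    re_flat = coords[0] * 20 + coords[1] * 5 + coords[2]
    assert torch.equal(re_flat, flat)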


def chain_matmul(*matrices, out=None):
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm, which selects the multiplication order that incurs the lowest
    cost in terms of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the
    product, :math:`N` needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix
    product is returned. If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

    .. warning::

        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
        rather than multiple arguments.

    Args:
        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.

    Returns:
        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the
        product would be of dimensions :math:`p_{1} \times p_{N + 1}`.

    Example::

        >>> # xdoctest: +SKIP
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> a = torch.randn(3, 4)
        >>> b = torch.randn(4, 5)
        >>> c = torch.randn(5, 6)
        >>> d = torch.randn(6, 7)
        >>> # will raise a deprecation warning
        >>> torch.chain_matmul(a, b, c, d)
        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])

    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
    """
    if has_torch_function(matrices):
        return handle_torch_function(chain_matmul, matrices, *matrices)

    if out is None:
        return _VF.chain_matmul(matrices)
    return _VF.chain_matmul(matrices, out=out)
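
# NOTE: migration sketch for the deprecation warning above; same product,
# new API (``a``..``d`` as in the docstring example):
#
#   >>> torch.linalg.multi_dot([a, b, c, d])  # replaces torch.chain_matmul(a, b, c, d)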


def _lu_impl(A, pivot=True, get_infos=False, out=None):
    r"""Computes the LU factorization of a matrix or batches of matrices
    :attr:`A`. Returns a tuple containing the LU factorization and
    pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
    ``True``.

    .. warning::

        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
        future PyTorch release.
        ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with

        .. code-block:: python

            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with

        .. code-block:: python

            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

    .. note::
        * The returned permutation matrix for every matrix in the batch is
          represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
          ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
          the ``i``-th row was permuted with the ``j-1``-th row.
        * LU factorization with :attr:`pivot` = ``False`` is not available
          for CPU, and attempting to do so will throw an error. However,
          LU factorization with :attr:`pivot` = ``False`` is available for
          CUDA.
        * This function does not check whether the factorization was successful
          when :attr:`get_infos` is ``True``, since the status of the
          factorization is present in the third element of the return tuple.
        * In the case of batches of square matrices with size less than or equal
          to 32 on a CUDA device, the LU factorization is repeated for
          singular matrices due to a bug in the MAGMA library
          (see magma issue 13).
        * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.

    .. warning::
        The gradients of this function will only be finite when :attr:`A` is full rank.
        This is because the LU decomposition is only differentiable at full-rank matrices.
        Furthermore, if :attr:`A` is close to not being full rank,
        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.

    Args:
        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
                                    Default: ``False``
        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
                               then the elements in the tuple are Tensor, IntTensor,
                               and IntTensor. If :attr:`get_infos` is ``False``, then the
                               elements in the tuple are Tensor, IntTensor. Default: ``None``

    Returns:
        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing

            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`

            - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
              ``pivots`` stores all the intermediate transpositions of rows.
              The final permutation ``perm`` could be reconstructed by
              applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
              where ``perm`` is initially the identity permutation of :math:`m` elements
              (essentially this is what :func:`torch.lu_unpack` is doing); see the
              sketch after this function.

            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
              size :math:`(*)` where a non-zero value indicates that the factorization for the
              corresponding matrix or minibatch failed

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = torch.lu(A)
        >>> A_LU
        tensor([[[ 1.3506,  2.5558, -0.0816],
                 [ 0.1684,  1.1551,  0.1940],
                 [ 0.1193,  0.6189, -0.5497]],

                [[ 0.4526,  1.2526, -0.3285],
                 [-0.7988,  0.7175, -0.9701],
                 [ 0.2634, -0.9255, -0.3459]]])
        >>> pivots
        tensor([[ 3,  3,  3],
                [ 3,  3,  3]], dtype=torch.int32)
        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
        >>> if info.nonzero().size(0) == 0:
        ...     print('LU factorization succeeded for all samples!')
        LU factorization succeeded for all samples!
    """
    # If get_infos is True, we skip error checking: the caller inspects infos.
    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
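
# NOTE: a minimal sketch of the ``perm`` reconstruction described in the
# docstring above (``pivots`` is 1-indexed; the helper name is illustrative):
#
#   def _perm_from_pivots(pivots_row, m):
#       perm = list(range(m))
#       for i, piv in enumerate(pivots_row.tolist()):
#           perm[i], perm[piv - 1] = perm[piv - 1], perm[i]
#       return perm  # e.g. pivots_row == [3, 3, 3], m == 3 -> [2, 0, 1]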


# ``Sequence`` keeps the annotation permissive for static type checking,
# while TorchScript needs the concrete ``List`` type at runtime.
if TYPE_CHECKING:
    _ListOrSeq = Sequence[Tensor]
else:
    _ListOrSeq = List[Tensor]


def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
    get_infos_int = 1 if get_infos else 0
    if out_len - get_infos_int != 2:
        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
    if not isinstance(out, (tuple, list)):
        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")


def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
    if has_torch_function_unary(A):
        return handle_torch_function(
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
    result = _lu_impl(A, pivot, get_infos, out)
    if out is not None:
        _check_list_size(len(out), get_infos, out)
        for i in range(len(out)):
            out[i].resize_as_(result[i]).copy_(result[i])
        return out
    else:
        return result  # A_LU, pivots, infos


def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
    if has_torch_function_unary(A):
        return handle_torch_function(
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
    result = _lu_impl(A, pivot, get_infos, out)
    if out is not None:
        _check_list_size(len(out), get_infos, out)
        for i in range(len(out)):
            out[i].resize_as_(result[i]).copy_(result[i])
        return out
    else:
        return result[0], result[1]  # A_LU, pivots
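
# NOTE: illustrative sketch of the ``out=`` path handled by both wrappers
# above; a pre-allocated tuple is resized and filled in place (names local):
#
#   >>> A = torch.randn(2, 3, 3)
#   >>> LU = torch.empty(0)
#   >>> piv = torch.empty(0, dtype=torch.int32)
#   >>> torch.lu(A, out=(LU, piv))  # fills LU and piv via resize_as_()/copy_()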


# The return type of lu depends on ``get_infos``, so to resolve the output
# type in TorchScript we need to know its value statically; boolean_dispatch
# selects between the two wrappers above based on that flag.
lu = boolean_dispatch(
    arg_name='get_infos',
    arg_index=2,
    default=False,
    if_true=_lu_with_infos,
    if_false=_lu_no_infos,
    module_name=__name__,
    func_name='lu')
lu.__doc__ = _lu_impl.__doc__
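
# NOTE: a minimal sketch of how the dispatch above behaves at call sites
# (``A`` is illustrative; ``torch.lu`` is deprecated, see the docstring above):
#
#   >>> A = torch.randn(2, 3, 3)
#   >>> torch.lu(A)                  # -> _lu_no_infos: (A_LU, pivots)
#   >>> torch.lu(A, get_infos=True)  # -> _lu_with_infos: (A_LU, pivots, infos)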


def align_tensors(*tensors):
    raise RuntimeError('`align_tensors` not yet implemented.')