import warnings
from typing import Iterable, List, NamedTuple, Tuple, Union

import torch
from torch import Tensor
from ... import _VF
from ..._jit_internal import Optional


__all__ = ['PackedSequence', 'invert_permutation', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence',
           'unpad_sequence', 'pack_sequence', 'unpack_sequence']


class PackedSequence_(NamedTuple):
    data: torch.Tensor
    batch_sizes: torch.Tensor
    sorted_indices: Optional[torch.Tensor]
    unsorted_indices: Optional[torch.Tensor]


def bind(optional, fn):
    # Apply `fn` only when `optional` is not None (helper for Optional tensors).
    if optional is None:
        return None
    return fn(optional)
class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how
            to recover the original sequences with correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout the :class:`PackedSequence` class,
        and by all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).
    """
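    # Illustrative sketch (editor's addition, not part of the original docstring):
    # packing a padded batch of ``abc`` (length 3) and ``x`` (length 1) walks the
    # time steps and keeps every still-active batch element, which is what gives
    # ``data=axbc`` and ``batch_sizes=[2,1,1]`` in the note above. Roughly:
    #
    #     >>> seq = torch.tensor([[1, 2, 3], [4, 0, 0]])   # "abc" and "x", padded
    #     >>> torch.nn.utils.rnn.pack_padded_sequence(seq, lengths=[3, 1], batch_first=True)
    #     PackedSequence(data=tensor([1, 4, 2, 3]), batch_sizes=tensor([2, 1, 1]), sorted_indices=None, unsorted_indices=None)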
    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
        return super().__new__(
            cls,
            *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
                                        unsorted_indices))

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in doc string (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self):
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(self.data.pin_memory(), self.batch_sizes,
                          bind(self.sorted_indices, lambda t: t.pin_memory()),
                          bind(self.unsorted_indices, lambda t: t.pin_memory()))
    def cuda(self, *args, **kwargs):
        # Tests to see if 'cuda' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        return self.to(*args, device='cuda', **kwargs)

    def cpu(self, *args, **kwargs):
        # Tests to see if 'cpu' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)
    def double(self):
        return self.to(dtype=torch.double)

    def float(self):
        return self.to(dtype=torch.float)

    def half(self):
        return self.to(dtype=torch.half)

    def long(self):
        return self.to(dtype=torch.long)

    def int(self):
        return self.to(dtype=torch.int)

    def short(self):
        return self.to(dtype=torch.short)

    def char(self):
        return self.to(dtype=torch.int8)

    def byte(self):
        return self.to(dtype=torch.uint8)
    def to(self, *args, **kwargs):
        r"""Perform dtype and/or device conversion on `self.data`.

        It has a similar signature to :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            return self
        else:
            # Does not forward device or dtype arg/kwargs, device is set from data.device
            kwargs = dict(filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items()))
            sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
            unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
            return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
    @property
    def is_cuda(self):
        r"""Return true if `self.data` is stored on a GPU."""
        return self.data.is_cuda

    def is_pinned(self):
        r"""Return true if `self.data` is stored in pinned memory."""
        return self.data.is_pinned()
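
# Usage sketch (editor's illustration, not part of the original file):
# `PackedSequence.to()` converts `data` and the index tensors but, per
# NOTE [ device and dtype of a PackedSequence ], never touches `batch_sizes`,
# which stays a CPU int64 tensor even after `.cuda()`:
#
#     >>> packed = pack_sequence([torch.tensor([1., 2.]), torch.tensor([3.])])
#     >>> packed.to(torch.float64).data.dtype
#     torch.float64
#     >>> packed.batch_sizes.device.type
#     'cpu'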
# TorchScript doesn't support constructors on named tuples, so we use this helper
# method to construct PackedSequence
def _packed_sequence_init_args(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    # NB: if unsorted_indices is provided, it should be the inverse permutation
    # to sorted_indices. Don't assert it here because the PackedSequence ctor
    # should only be used internally.
    if unsorted_indices is None:
        unsorted_indices = invert_permutation(sorted_indices)

    # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
    if batch_sizes is not None:
        # TODO: Re-enable this check (.type isn't supported in TorchScript)
        if batch_sizes.device.type != 'cpu':
            raise ValueError(
                "batch_sizes should always be on CPU. "
                "Instances of PackedSequence should never be created manually. "
                "They should be instantiated by functions like pack_sequence "
                "and pack_padded_sequences in nn.utils.rnn. "
                "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
        return data, batch_sizes, sorted_indices, unsorted_indices

    # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
    else:
        assert isinstance(data, (list, tuple)) and len(data) == 2
        return data[0], data[1], sorted_indices, unsorted_indices
def _packed_sequence_init(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> PackedSequence:
    data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
        data, batch_sizes, sorted_indices, unsorted_indices)
    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
    if permutation is None:
        return None
    output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
    output.scatter_(0, permutation,
                    torch.arange(0, permutation.numel(), device=permutation.device))
    return output
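
# Quick sanity check (editor's illustration): scattering ``arange(n)`` into the
# positions given by `permutation` produces its inverse, so applying one after the
# other restores the identity ordering. For example:
#
#     >>> invert_permutation(torch.tensor([2, 0, 1]))
#     tensor([1, 2, 0])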
def pack_padded_sequence(
    input: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
    longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
    ``*`` is any number of dimensions (including 0). If ``batch_first`` is
    ``True``, ``B x T x *`` :attr:`input` is expected.

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format. Default: ``False``.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if not isinstance(lengths, torch.Tensor):
        if torch._C._get_tracing_state():
            warnings.warn('pack_padded_sequence has been called with a Python list of '
                          'sequence lengths. The tracer cannot track the data flow of Python '
                          'values, and it will treat them as constants, likely rendering '
                          'the trace incorrect for any other combination of lengths.',
                          stacklevel=2)
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device='cpu')
    else:
        lengths = lengths.to(dtype=torch.int64)

    if enforce_sorted:
        sorted_indices = None
    else:
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = \
        _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
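
# Editor's sketch of the sorted case (``enforce_sorted`` left at its default of True);
# the values here are illustrative, not taken from the original docs:
#
#     >>> x = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
#     >>> pack_padded_sequence(x, lengths=torch.tensor([3, 2, 1]), batch_first=True)
#     PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)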
def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
    the data will be transposed into ``B x T x *`` format.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format. Default: ``False``.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < max_seq_length:
            raise ValueError("Expected total_length to be at least the length "
                             "of the longest sequence in input, but got "
                             f"total_length={total_length} and max sequence length being {max_seq_length}")
        max_seq_length = total_length
    padded_output, lengths = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
    unsorted_indices = sequence.unsorted_indices
    if unsorted_indices is not None:
        batch_dim = 0 if batch_first else 1
        return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices.cpu()]
    return padded_output, lengths
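
# Editor's note, illustrative only: `total_length` simply right-pads the result to a
# fixed time dimension, which is handy when downstream code expects a constant shape.
# Reusing `packed` from the docstring example above (values assumed, not from the
# original docs):
#
#     >>> out, lens_out = pad_packed_sequence(packed, batch_first=True, total_length=5)
#     >>> out.shape
#     torch.Size([3, 5])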
# NOTE: .pyi stub allows Iterable[Tensor], but for JIT-compatibility we need to be more restrictive here.
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
) -> Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *`` and ``batch_first`` is False, the output is
    of size ``T x B x *``.

    `B` is the batch size. It is equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise. Default: False.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """
    if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        # JIT doesn't support `Iterable`
        if not isinstance(sequences, Iterable):
            msg = ('pad_sequence: Expected iterable for input sequences, but got arg of type: '
                   f'{type(sequences)}')
            raise RuntimeError(msg)

        # In JIT context this leads to,
        # RuntimeError: cannot statically infer the expected size of a list in this context
        sequences = tuple(sequences)
    else:
        # For JIT, we only support Union[Tensor, Tuple[Tensor]]
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.unbind(0)

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
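
# Editor's sketch of the batch-first layout (values assumed, not from the original
# docs): with ``batch_first=True`` the same three sequences stack along dim 0 instead:
#
#     >>> pad_sequence([a, b, c], batch_first=True).size()
#     torch.Size([3, 25, 300])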
def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks a padded Tensor into a list of variable length Tensors.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> sequences = [a, b, c]
        >>> padded_sequences = pad_sequence(sequences)
        >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
        >>> torch.allclose(sequences[0], unpadded_sequences[0])
        True
        >>> torch.allclose(sequences[1], unpadded_sequences[1])
        True
        >>> torch.allclose(sequences[2], unpadded_sequences[2])
        True

    Args:
        padded_sequences (Tensor): padded sequences.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether batch dimension first or not. Default: False.

    Returns:
        a list of :class:`Tensor` objects
    """
    unpadded_sequences = []

    if not batch_first:
        padded_sequences.transpose_(0, 1)

    max_length = padded_sequences.shape[1]
    idx = torch.arange(max_length, device=lengths.device)

    for seq, length in zip(padded_sequences, lengths):
        # Keep only the time steps that fall within this sequence's true length.
        mask = idx < length
        unpacked_seq = seq[mask]
        unpadded_sequences.append(unpacked_seq)

    return unpadded_sequences
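
# Editor's note (illustrative, not from the original docs): when the padded batch is
# laid out batch-first, pass ``batch_first=True`` so the batch dimension is iterated
# directly instead of being transposed into place first:
#
#     >>> padded = pad_sequence(sequences, batch_first=True)
#     >>> unpadded = unpad_sequence(padded, lengths, batch_first=True)
#     >>> torch.allclose(sequences[0], unpadded[0])
#     True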
def pack_sequence(sequences: List[Tensor], enforce_sorted: bool = True) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    Equivalent to calling ``pad_sequence`` followed by ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
    the length of a sequence and `*` is any number of trailing dimensions,
    including zero.

    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
    is ``True``, the sequences should be sorted in the order of decreasing length.
    ``enforce_sorted = True`` is only necessary for ONNX export.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> pack_sequence([a, b, c])
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

    Args:
        sequences (list[Tensor]): A list of sequences of decreasing length.
        enforce_sorted (bool, optional): if ``True``, checks that the input
            contains sequences sorted by length in a decreasing order. If
            ``False``, this condition is not checked. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    lengths = torch.as_tensor([v.size(0) for v in sequences])
    return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
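
# Editor's sketch of the unsorted case (values assumed, not from the original docs):
# with ``enforce_sorted=False`` the input order does not have to be by decreasing
# length; the sort/unsort bookkeeping is kept in the index tensors. Reusing the
# tensors from the docstring example above:
#
#     >>> pack_sequence([b, a, c], enforce_sorted=False)
#     PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]),
#                    sorted_indices=tensor([1, 0, 2]), unsorted_indices=tensor([1, 0, 2]))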
def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a PackedSequence object.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> sequences = [a, b, c]
        >>> print(sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
        >>> packed_sequences = pack_sequence(sequences)
        >>> print(packed_sequences)
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
        >>> unpacked_sequences = unpack_sequence(packed_sequences)
        >>> print(unpacked_sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
    return unpacked_sequences