pytorch

Форк
0
517 строк · 20.8 Кб
1
import warnings
2
from typing import Iterable, List, NamedTuple, Tuple, Union
3

4
import torch
5
from torch import Tensor
6
from ... import _VF
7
from ..._jit_internal import Optional
8

9

10
# Public API of this module.
__all__ = [
    'PackedSequence',
    'invert_permutation',
    'pack_padded_sequence',
    'pad_packed_sequence',
    'pad_sequence',
    'unpad_sequence',
    'pack_sequence',
    'unpack_sequence',
]
12

13

14
class PackedSequence_(NamedTuple):
    # Plain NamedTuple record underlying PackedSequence; the field order below
    # defines the tuple layout and must not change.

    # Packed elements of all sequences.
    data: Tensor
    # Number of batch elements present at each sequence step.
    batch_sizes: Tensor
    # Permutation applied to sort the batch by length, or None if the batch
    # was not re-sorted.
    sorted_indices: Optional[Tensor]
    # Inverse permutation of ``sorted_indices`` (restores the original order), or None.
    unsorted_indices: Optional[Tensor]
19

20

21
def bind(optional, fn):
    """Apply ``fn`` to ``optional`` unless it is ``None``.

    Args:
        optional: a value that may be ``None``.
        fn: a one-argument callable; it is only invoked for non-``None`` input.

    Returns:
        ``None`` when ``optional is None``, otherwise ``fn(optional)``.
    """
    return None if optional is None else fn(optional)
25

26

27
class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`.  For instance, given data ``abc`` and ``x``
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how
            to recover the original sequences with correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout :class:`PackedSequence` class,
        and all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).

    """

    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
        # Argument normalization (tuple form, deriving unsorted_indices, CPU
        # check on batch_sizes) lives in a helper that is also usable from
        # TorchScript, which cannot call named-tuple constructors.
        return super().__new__(
            cls,
            *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
                                        unsorted_indices))

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in doc string (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self):
        """Return a copy whose data and index tensors are in pinned memory.

        ``batch_sizes`` is passed through unchanged.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(self.data.pin_memory(), self.batch_sizes,
                          bind(self.sorted_indices, lambda t: t.pin_memory()),
                          bind(self.unsorted_indices, lambda t: t.pin_memory()))

    def cuda(self, *args, **kwargs):
        """Move to a CUDA device via :meth:`to`, forwarding ``args``/``kwargs``.

        If the given arguments do not already select a CUDA device,
        ``device='cuda'`` is added.
        """
        # Tests to see if 'cuda' should be added to kwargs: convert an empty
        # probe tensor with the caller's arguments and inspect where it lands.
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        return self.to(*args, device='cuda', **kwargs)

    def cpu(self, *args, **kwargs):
        """Move to the CPU via :meth:`to`, forwarding ``args``/``kwargs``.

        If the given arguments do not already select the CPU,
        ``device='cpu'`` is added.
        """
        # Tests to see if 'cpu' should be added to kwargs (same probe trick
        # as in `cuda`).
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)

    def double(self):
        """Cast :attr:`data` to ``torch.double`` (see :meth:`to`)."""
        return self.to(dtype=torch.double)

    def float(self):
        """Cast :attr:`data` to ``torch.float`` (see :meth:`to`)."""
        return self.to(dtype=torch.float)

    def half(self):
        """Cast :attr:`data` to ``torch.half`` (see :meth:`to`)."""
        return self.to(dtype=torch.half)

    def long(self):
        """Cast :attr:`data` to ``torch.long`` (see :meth:`to`)."""
        return self.to(dtype=torch.long)

    def int(self):
        """Cast :attr:`data` to ``torch.int`` (see :meth:`to`)."""
        return self.to(dtype=torch.int)

    def short(self):
        """Cast :attr:`data` to ``torch.short`` (see :meth:`to`)."""
        return self.to(dtype=torch.short)

    def char(self):
        """Cast :attr:`data` to ``torch.int8`` (see :meth:`to`)."""
        return self.to(dtype=torch.int8)

    def byte(self):
        """Cast :attr:`data` to ``torch.uint8`` (see :meth:`to`)."""
        return self.to(dtype=torch.uint8)

    def to(self, *args, **kwargs):
        r"""Perform dtype and/or device conversion on `self.data`.

        It has similar signature as :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        The index tensors (:attr:`sorted_indices`, :attr:`unsorted_indices`)
        are moved to the device of the converted :attr:`data` but keep their
        own dtype; :attr:`batch_sizes` is never converted.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            return self
        # Forward options such as non_blocking/copy, but not device or dtype:
        # the index tensors follow data.device and keep their own dtype.
        kwargs = {k: v for k, v in kwargs.items() if k not in ('device', 'dtype')}
        sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
        unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
        return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)

    @property
    def is_cuda(self):
        r"""Return ``True`` if :attr:`data` is stored on a CUDA device."""
        return self.data.is_cuda

    def is_pinned(self):
        r"""Return ``True`` if :attr:`data` is stored in pinned memory."""
        return self.data.is_pinned()
152

153

154
# TorchScript doesn't support constructors on named tuples, so we use this helper
155
# method to construct PackedSequence
156
def _packed_sequence_init_args(
157
    data: Tensor,
158
    batch_sizes: Optional[Tensor] = None,
159
    sorted_indices: Optional[Tensor] = None,
160
    unsorted_indices: Optional[Tensor] = None,
161
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
162
    # NB: if unsorted_indices is provided, it should be the inverse permutation
163
    # to sorted_indices. Don't assert it here because the PackedSequence ctor
164
    # should only be used internally.
165

166
    if unsorted_indices is None:
167
        unsorted_indices = invert_permutation(sorted_indices)
168

169
    # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
170
    if batch_sizes is not None:
171
        # TODO: Re-enable this check (.type isn't supported in TorchScript)
172
        if batch_sizes.device.type != 'cpu':
173
            raise ValueError(
174
                "batch_sizes should always be on CPU. "
175
                "Instances of PackedSequence should never be created manually. "
176
                "They should be instantiated by functions like pack_sequence "
177
                "and pack_padded_sequences in nn.utils.rnn. "
178
                "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
179
        return data, batch_sizes, sorted_indices, unsorted_indices
180

181
    # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
182
    else:
183
        assert isinstance(data, (list, tuple)) and len(data) == 2
184
        return data[0], data[1], sorted_indices, unsorted_indices
185

186

187
def _packed_sequence_init(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> PackedSequence:
    """Build a :class:`PackedSequence` after normalizing its arguments.

    Runs :func:`_packed_sequence_init_args` explicitly and then passes the
    four resulting fields to :class:`PackedSequence`. This is the construction
    path to use where named-tuple constructors are unavailable (TorchScript).

    Raises:
        ValueError: propagated from :func:`_packed_sequence_init_args` when
            ``batch_sizes`` is not on the CPU.
    """
    fields = _packed_sequence_init_args(data, batch_sizes, sorted_indices, unsorted_indices)
    packed_data, packed_batch_sizes, packed_sorted, packed_unsorted = fields
    return PackedSequence(packed_data, packed_batch_sizes, packed_sorted, packed_unsorted)
196

197

198
def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
    """Return the inverse of a 1-D index permutation, or ``None`` for ``None``.

    The result ``inv`` satisfies ``inv[permutation[i]] == i`` for every ``i``,
    and has the same dtype and device as ``permutation``.
    """
    if permutation is None:
        return None
    # 0, 1, ..., n-1 on the permutation's device; scattered into the slots
    # named by `permutation`.
    positions = torch.arange(0, permutation.numel(), device=permutation.device)
    inverse = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
    inverse.scatter_(0, permutation, positions)
    return inverse
205

206

207
def pack_padded_sequence(
    input: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
    longest sequence (equal to ``lengths[0]`` when the sequences are sorted by
    decreasing length), ``B`` is the batch size, and ``*`` is any number of
    dimensions (including 0). If ``batch_first`` is ``True``, ``B x T x *``
    :attr:`input` is expected.

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if isinstance(lengths, torch.Tensor):
        lengths = lengths.to(dtype=torch.int64)
    else:
        # A plain Python list is baked into a trace as a constant, so warn
        # when tracing.
        if torch._C._get_tracing_state():
            warnings.warn('pack_padded_sequence has been called with a Python list of '
                          'sequence lengths. The tracer cannot track the data flow of Python '
                          'values, and it will treat them as constants, likely rendering '
                          'the trace incorrect for any other combination of lengths.',
                          stacklevel=2)
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device='cpu')

    if not enforce_sorted:
        # Sort by decreasing length and reorder the batch dimension to match;
        # the permutation is kept so the original order can be restored later.
        lengths, order = torch.sort(lengths, descending=True)
        order = order.to(input.device)
        input = input.index_select(0 if batch_first else 1, order)
        sorted_indices = order
    else:
        sorted_indices = None

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
266

267

268
def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
    the data will be transposed into ``B x T x *`` format.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    # One batch-size entry per time step, so this is the longest sequence length.
    target_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < target_length:
            raise ValueError("Expected total_length to be at least the length "
                             "of the longest sequence in input, but got "
                             f"total_length={total_length} and max sequence length being {target_length}"
                             )
        target_length = total_length

    padded, seq_lengths = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, target_length)

    # Bind to a local so the Optional can be refined (TorchScript-friendly).
    restore_order = sequence.unsorted_indices
    if restore_order is not None:
        # Undo the length-sorting done at packing time; lengths are indexed with
        # a CPU copy of the permutation.
        batch_dim = 0 if batch_first else 1
        return padded.index_select(batch_dim, restore_order), seq_lengths[restore_order.cpu()]
    return padded, seq_lengths
340

341
# NOTE: .pyi stub allows Iterable[Tensor], but for JIT-compatibility we need to be more restrictive here.
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
) -> Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *`` and ``batch_first`` is False, the output is
    of size ``T x B x *``.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences. In eager
            mode any iterable of Tensors is accepted.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise. Default: False.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise

    Raises:
        RuntimeError: in eager mode, if ``sequences`` is not iterable.
    """
    if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        # Eager mode: validate against `Iterable`, which JIT cannot express.
        if not isinstance(sequences, Iterable):
            raise RuntimeError(
                'pad_sequence: Expected iterable for input sequences, but got arg of type: '
                f'{type(sequences)}')
        # Materialize generators and other one-shot iterables. Not done under
        # JIT, where it fails with "cannot statically infer the expected size
        # of a list in this context".
        sequences = tuple(sequences)
    elif isinstance(sequences, torch.Tensor):
        # Under JIT the input is a Tensor or a List[Tensor]; a Tensor is split
        # along dim 0 into its rows.
        sequences = sequences.unbind(0)

    # The native kernel takes the trailing dims and dtype from sequences[0].
    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
400

401

402
def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
    The input tensor is left unmodified.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> sequences = [a, b, c]
        >>> padded_sequences = pad_sequence(sequences)
        >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
        >>> torch.allclose(sequences[0], unpadded_sequences[0])
        True
        >>> torch.allclose(sequences[1], unpadded_sequences[1])
        True
        >>> torch.allclose(sequences[2], unpadded_sequences[2])
        True

    Args:
        padded_sequences (Tensor): padded sequences, ``T x B x *`` or, if
            ``batch_first`` is ``True``, ``B x T x *``.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether batch dimension first or not. Default: False.

    Returns:
        a list of :class:`Tensor` objects
    """
    if not batch_first:
        # Out-of-place transpose to ``B x T x *``. An in-place ``transpose_``
        # would silently reshape the caller's tensor.
        padded_sequences = padded_sequences.transpose(0, 1)

    max_length = padded_sequences.shape[1]
    idx = torch.arange(max_length, device=lengths.device)

    unpadded_sequences = []
    for seq, length in zip(padded_sequences, lengths):
        # Keep only the first `length` time steps of this batch element.
        mask = idx < length
        unpadded_sequences.append(seq[mask])
    return unpadded_sequences
449

450

451
def pack_sequence(sequences: List[Tensor], enforce_sorted: bool = True) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    Equivalent to calling ``pad_sequence`` and then ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
    the length of a sequence and `*` is any number of trailing dimensions,
    including zero.

    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
    is ``True``, the sequences should be sorted in the order of decreasing length.
    ``enforce_sorted = True`` is only necessary for ONNX export.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> pack_sequence([a, b, c])
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

    Args:
        sequences (list[Tensor]): A list of variable length sequences (of
            decreasing length when ``enforce_sorted`` is ``True``).
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the sequences are sorted internally by
            ``pack_padded_sequence``. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    seq_lengths = torch.as_tensor([seq.size(0) for seq in sequences])
    padded = pad_sequence(sequences)
    return pack_padded_sequence(padded, seq_lengths, enforce_sorted=enforce_sorted)
485

486

487
def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a PackedSequence object. The result is
    produced by padding with ``pad_packed_sequence`` (batch-first) and then
    splitting with ``unpad_sequence``, so sequences come back in their
    original batch order.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> sequences = [a, b, c]
        >>> print(sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
        >>> packed_sequences = pack_sequence(sequences)
        >>> print(packed_sequences)
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
        >>> unpacked_sequences = unpack_sequence(packed_sequences)
        >>> print(unpacked_sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    padded, seq_lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    return unpad_sequence(padded, seq_lengths, batch_first=True)
518

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.