# Owner(s): ["module: nn"]

import itertools
import random
from typing import List

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.testing._internal.common_utils import run_tests, TestCase


class PackedSequenceTest(TestCase):
    _type_by_name = {
        "torch.DoubleTensor": (torch.DoubleTensor, "double"),
        "torch.FloatTensor": (torch.FloatTensor, "float"),
        # We leave out `'torch.HalfTensor': (torch.HalfTensor, 'half'),`
        # because of an error in `pad_packed_sequence`
        # > AttributeError: 'torch.HalfTensor' object has no attribute 'fill_'
        "torch.LongTensor": (torch.LongTensor, "long"),
        "torch.IntTensor": (torch.IntTensor, "int"),
        "torch.ShortTensor": (torch.ShortTensor, "short"),
        "torch.CharTensor": (torch.CharTensor, "char"),
        "torch.ByteTensor": (torch.ByteTensor, "byte"),
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = 5
        self.max_length = 6

    def _ordered_sequence(self, tensor_type):
        """Create ordered list of random sequences"""
        seqs = [
            tensor_type(random.randint(1, self.max_length))
            for _ in range(self.batch_size)
        ]
        if tensor_type == torch.ByteTensor:
            seqs = [s.random_(0, 256) for s in seqs]
        else:
            seqs = [s.random_(-128, 128) for s in seqs]
        ordered = sorted(seqs, key=len, reverse=True)
        return ordered

    def _padded_sequence(self, tensor_type):
        """Create Tensor of random padded sequences"""
        ordered = self._ordered_sequence(tensor_type)
        lengths = [len(i) for i in ordered]
        padded_tensor = rnn_utils.pad_sequence(ordered)
        return padded_tensor, lengths
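
    # Note: `_padded_sequence` returns a time-major padded tensor
    # (`pad_sequence` defaults to batch_first=False, i.e. shape
    # [max_len, batch_size]) together with the per-sequence lengths, which is
    # the pair that `pack_padded_sequence` consumes in the tests below.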
    def test_type_casts(self):
        """Test type casting of `PackedSequence` against type casting of tensor"""
        for input_type, _ in self._type_by_name.values():
            for expected_type_str, (_, cast_str) in self._type_by_name.items():
                for enforce_sorted in [True, False]:
                    padded, lengths = self._padded_sequence(input_type)
                    packed = rnn_utils.pack_padded_sequence(
                        padded, lengths, enforce_sorted=enforce_sorted
                    )
                    # Apply cast to `PackedSequence` instance and unpack
                    masked = getattr(packed, cast_str)()
                    unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
                    self.assertEqual(unpacked.type(), expected_type_str)
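
    # With enforce_sorted=True (the default), `pack_padded_sequence` expects
    # the batch to be sorted by decreasing sequence length and raises a
    # RuntimeError otherwise; enforce_sorted=False lets it sort the batch
    # internally and remember how to undo the permutation.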
    def test_wrong_order(self):
        a = torch.ones(25, 300)
        b = torch.ones(22, 300)
        b_a = rnn_utils.pad_sequence([b, a])
        self.assertRaises(
            RuntimeError,
            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True),
        )

    def test_pad_sequence_with_tensor_sequences(self):
        seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
            (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
        )
        seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
            torch.tensor([[[7, 6]], [[-7, -1]]])
        )
        self.assertEqual(seq_tuple_input, seq_tensor_input)
        self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))

    def test_pad_sequence_with_non_iterable_sequences(self):
        msg = r"Expected iterable for input sequences, but got arg of type"
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.utils.rnn.pad_sequence(5)
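
    # `total_length` asks `pad_packed_sequence` to pad its output out to a
    # fixed size beyond the longest sequence in the batch; requesting less
    # than the longest sequence is rejected with a ValueError, which is
    # checked first below.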
    def test_total_length(self):
        padded, lengths = self._padded_sequence(torch.FloatTensor)
        max_length = max(lengths)
        packed = rnn_utils.pack_padded_sequence(padded, lengths)
        # test ValueError if total_length < max_length
        for total_length in (-1, 0, max_length - 1):
            for batch_first in (True, False):

                def err_fn():
                    rnn_utils.pad_packed_sequence(
                        packed, batch_first=batch_first, total_length=total_length
                    )

            self.assertRaisesRegex(
                ValueError,
                r"Expected total_length to be at least the "
                r"length of the longest sequence in input",
                err_fn,
            )
        # test that pad_packed_sequence returns results of correct length
        for batch_first in (True, False):
            no_extra_pad, _ = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            for total_length_delta in (0, 1, 8):
                total_length = max_length + total_length_delta
                unpacked, lengths_out = rnn_utils.pad_packed_sequence(
                    packed, batch_first=batch_first, total_length=total_length
                )
                self.assertEqual(lengths, lengths_out)
                self.assertEqual(unpacked.size(1 if batch_first else 0), total_length)
                if total_length_delta == 0:
                    ref_output = no_extra_pad
                elif batch_first:
                    extra_pad = no_extra_pad.new_zeros(
                        self.batch_size, total_length_delta
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 1)
                else:
                    extra_pad = no_extra_pad.new_zeros(
                        total_length_delta, self.batch_size
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 0)
                self.assertEqual(unpacked, ref_output)

    def test_to(self):
        for enforce_sorted in (True, False):
            padded, lengths = self._padded_sequence(torch.IntTensor)
            a = rnn_utils.pack_padded_sequence(
                padded, lengths, enforce_sorted=enforce_sorted
            ).cpu()

            self.assertIs(a, a.to("cpu"))
            self.assertIs(a, a.cpu())
            self.assertIs(a, a.to("cpu", dtype=torch.int32))
            self.assertEqual(a.long(), a.to(torch.int64))

            if torch.cuda.is_available():
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    b = a.cuda(device=cuda)
                    self.assertIs(b, b.to(cuda))
                    self.assertIs(b, b.cuda())
                    self.assertEqual(a, b.to("cpu"))
                    self.assertEqual(b, a.to(cuda))
                    self.assertEqual(a, b.to("cpu", dtype=torch.int32))
                    self.assertIs(b, b.to(dtype=torch.int32))
                    self.assertEqual(b.long(), b.to(dtype=torch.int64))

    def test_to_memory_format(self):
        m = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, bias=True)
        m = m.to(memory_format=torch.channels_last)
        for param in m.parameters():
            if param.dim() == 4:
                self.assertTrue(param.is_contiguous(memory_format=torch.channels_last))
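
    # `pad_sequence` stacks variable-length sequences into one padded tensor:
    # batch_first selects [B, T, ...] instead of the default [T, B, ...]
    # layout, padding_value fills the unused tail, and padding_side="left"
    # places the padding in front of each sequence instead of after it.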
    def test_pad_sequence(self):
        def pad(tensor, length):
            return torch.cat(
                [
                    tensor.data,
                    tensor.data.new(
                        length - tensor.size(0), *tensor.size()[1:]
                    ).zero_(),
                ]
            )

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])

        # batch_first = true
        expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([b, a, c], True)
        self.assertEqual(padded, expected)

        # batch_first = false
        padded = rnn_utils.pad_sequence([b, a, c])
        self.assertEqual(padded, expected.transpose(0, 1))

        # padding_side = "left", batch_first=True
        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=True,
            padding_side="left",
        )
        self.assertEqual(padded, expected)

        # padding_side = "left", batch_first=False
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=False,
            padding_side="left",
        )
        self.assertEqual(padded, expected.transpose(0, 1))

        # pad with non-zero value
        expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
        padded = rnn_utils.pad_sequence([b, a, c], True, 1)
        self.assertEqual(padded, expected)

        # Test pad sorted sequence
        expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([a, b, c], True)
        self.assertEqual(padded, expected)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences: List[torch.Tensor] = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)
            # batch first = true
            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
            padded = rnn_utils.pad_sequence(sequences, True)
            self.assertEqual(padded, expected)

            # batch first = false
            padded = rnn_utils.pad_sequence(sequences)
            self.assertEqual(padded, expected.transpose(0, 1))

            # padding_side = "left", batch_first=True
            expected = torch.stack(
                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
            )
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=True,
                padding_side="left",
            )
            self.assertEqual(padded, expected)

            # padding_side = "left", batch_first=False
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=False,
                padding_side="left",
            )
            self.assertEqual(padded, expected.transpose(0, 1))
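
    # `unpad_sequence` is the inverse of `pad_sequence`: given the padded
    # batch and the original lengths, it recovers the list of variable-length
    # sequences.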
    def test_unpad_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]

        lengths = torch.as_tensor([v.size(0) for v in sequences])
        for batch_first in [True, False]:
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)

            lengths = torch.as_tensor([v.size(0) for v in sequences])
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)
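
    # A PackedSequence stores its data time-step major: step 0 of every
    # sequence, then step 1 of the sequences still active, and so on, with
    # batch_sizes[t] giving the number of sequences active at step t. For
    # a=[1, 2, 3], b=[4, 5], c=[6] this yields data [1, 4, 6, 2, 5, 3] and
    # batch_sizes [3, 2, 1], as asserted below.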
    def test_pack_sequence(self):
        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
            padded = rnn_utils.pad_sequence(sequences, batch_first)
            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
            unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
            self.assertEqual(padded, unpacked[0])
            pack_padded = rnn_utils.pack_padded_sequence(
                padded, lengths, batch_first, enforce_sorted
            )
            self.assertEqual(packed, pack_padded)

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
        expected = torch.tensor([1, 4, 6, 2, 5, 3])
        self.assertEqual(packed.batch_sizes, [3, 2, 1])
        self.assertEqual(packed.data.data, expected)
        self.assertEqual(packed.sorted_indices, [0, 1, 2])
        self.assertEqual(packed.unsorted_indices, [0, 1, 2])

        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_unsorted.data.data, expected)
        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])

        # single dimensional, enforce_sorted = True
        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_enforce_sorted.data.data, expected)
        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)

        with self.assertRaisesRegex(RuntimeError, "must be sorted in decreasing order"):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            lengths = []
            trailing_dims = [4] * num_dim
            for i in range(maxlen, 0, -1):
                seq_len = i * i
                lengths.append(seq_len)
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            unsorted_sequences = [s.clone() for s in sequences]
            random.shuffle(unsorted_sequences)
            unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]

            # compatibility with other utilities
            for batch_first in (True, False):
                for enforce_sorted in (True, False):
                    _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
                _compatibility_test(
                    unsorted_sequences, unsorted_sequences_lengths, batch_first
                )

    def test_unpack_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]

        packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
        self.assertEqual(sequences, unpacked_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)

            packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
            unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
            self.assertEqual(sequences, unpacked_sequences)
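
    # The helper below builds a padded batch with known contents (batch entry
    # i, counted from 1, holds i * 100 plus consecutive values), so the
    # expected packed data, batch_sizes and unsorted_indices can be computed
    # in closed form and compared against `pack_padded_sequence`.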
    def test_pack_padded_sequence(self):
        def generate_test_case(sorted_lengths, should_shuffle):
            def pad(tensor, length):
                return torch.cat(
                    [
                        tensor,
                        tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_(),
                    ]
                )

            max_length = sorted_lengths[0]
            batch_sizes = [
                sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
                for i in range(1, max_length + 1)
            ]
            offset = 0
            padded = torch.cat(
                [
                    pad(
                        i * 100 + torch.arange(1.0, 5 * l + 1).view(l, 1, 5), max_length
                    )
                    for i, l in enumerate(sorted_lengths, 1)
                ],
                1,
            )
            expected_data = [
                [
                    torch.arange(1.0, 6) + (i + 1) * 100 + 5 * n
                    for i in range(batch_size)
                ]
                for n, batch_size in enumerate(batch_sizes)
            ]
            expected_data = list(itertools.chain.from_iterable(expected_data))
            expected_data = torch.stack(expected_data, dim=0)

            if should_shuffle:
                # Shuffle the padded sequence to create an unsorted sequence
                permutation = list(range(len(sorted_lengths)))
                random.shuffle(permutation)

                unsorted_indices = torch.tensor(permutation)
                padded = padded.index_select(1, unsorted_indices)
                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
            else:
                unsorted_indices = None
                lengths = sorted_lengths

            return (
                padded.requires_grad_(),
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            )

        test_cases = [
            # sorted_lengths, should_shuffle
            [[10, 8, 4, 2, 2, 2, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], True],
        ]

        for test_case, batch_first in itertools.product(test_cases, (True, False)):
            sorted_lengths, should_shuffle = test_case
            (
                padded,
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            ) = generate_test_case(sorted_lengths, should_shuffle)

            src = padded
            if batch_first:
                src = src.transpose(0, 1)

            # check output
            packed = rnn_utils.pack_padded_sequence(
                src, lengths, batch_first=batch_first, enforce_sorted=not should_shuffle
            )
            self.assertEqual(packed.data.data, expected_data)
            self.assertEqual(packed.batch_sizes, batch_sizes)
            self.assertEqual(packed.unsorted_indices, unsorted_indices)

            # test inverse
            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            self.assertEqual(unpacked, src)
            self.assertEqual(unpacked_len, lengths)

            # check grad
            if padded.grad is not None:
                padded.grad.data.zero_()
            grad_output = unpacked.data.clone().normal_()
            unpacked.backward(grad_output)
            if batch_first:
                grad_output.transpose_(0, 1)
            for i, l in enumerate(lengths):
                self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
                if l < 10:
                    self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)

        # test error messages
        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(
                torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
            )


if __name__ == "__main__":
    run_tests()