import itertools
import random
from typing import List

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.testing._internal.common_utils import run_tests, TestCase
class PackedSequenceTest(TestCase):
14
"torch.DoubleTensor": (torch.DoubleTensor, "double"),
15
"torch.FloatTensor": (torch.FloatTensor, "float"),
19
"torch.LongTensor": (torch.LongTensor, "long"),
20
"torch.IntTensor": (torch.IntTensor, "int"),
21
"torch.ShortTensor": (torch.ShortTensor, "short"),
22
"torch.CharTensor": (torch.CharTensor, "char"),
23
"torch.ByteTensor": (torch.ByteTensor, "byte"),
26
def __init__(self, *args, **kwargs):
27
super().__init__(*args, **kwargs)
31
def _ordered_sequence(self, tensor_type):
32
"""Create ordered list of random sequences"""
34
tensor_type(random.randint(1, self.max_length))
35
for _ in range(self.batch_size)
37
if tensor_type == torch.ByteTensor:
38
seqs = [s.random_(0, 256) for s in seqs]
40
seqs = [s.random_(-128, 128) for s in seqs]
41
ordered = sorted(seqs, key=len, reverse=True)
44
def _padded_sequence(self, tensor_type):
45
"""Create Tensor of random padded sequences"""
46
ordered = self._ordered_sequence(tensor_type)
47
lengths = [len(i) for i in ordered]
48
padded_tensor = rnn_utils.pad_sequence(ordered)
49
return padded_tensor, lengths
51
def test_type_casts(self):
52
"""Test type casting of `PackedSequence` against type casting of tensor"""
53
for input_type, _ in self._type_by_name.values():
54
for expected_type_str, (_, cast_str) in self._type_by_name.items():
55
for enforce_sorted in [True, False]:
56
padded, lengths = self._padded_sequence(input_type)
57
packed = rnn_utils.pack_padded_sequence(
58
padded, lengths, enforce_sorted=enforce_sorted
61
masked = getattr(packed, cast_str)()
62
unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
63
self.assertEqual(unpacked.type(), expected_type_str)
65
def test_wrong_order(self):
66
a = torch.ones(25, 300)
67
b = torch.ones(22, 300)
68
b_a = rnn_utils.pad_sequence([b, a])
71
lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True),
74
def test_pad_sequence_with_tensor_sequences(self):
75
seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
76
(torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
78
seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
79
torch.tensor([[[7, 6]], [[-7, -1]]])
81
self.assertEqual(seq_tuple_input, seq_tensor_input)
82
self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))
84
def test_pad_sequence_with_non_iterable_sequences(self):
85
msg = r"Expected iterable for input sequences, but got arg of type"
86
with self.assertRaisesRegex(RuntimeError, msg):
87
torch.nn.utils.rnn.pad_sequence(5)
89
def test_total_length(self):
90
padded, lengths = self._padded_sequence(torch.FloatTensor)
91
max_length = max(lengths)
92
packed = rnn_utils.pack_padded_sequence(padded, lengths)
94
for total_length in (-1, 0, max_length - 1):
95
for batch_first in (True, False):
98
rnn_utils.pad_packed_sequence(
99
packed, batch_first=batch_first, total_length=total_length
102
self.assertRaisesRegex(
104
r"Expected total_length to be at least the "
105
r"length of the longest sequence in input",
109
for batch_first in (True, False):
110
no_extra_pad, _ = rnn_utils.pad_packed_sequence(
111
packed, batch_first=batch_first
113
for total_length_delta in (0, 1, 8):
114
total_length = max_length + total_length_delta
115
unpacked, lengths_out = rnn_utils.pad_packed_sequence(
116
packed, batch_first=batch_first, total_length=total_length
118
self.assertEqual(lengths, lengths_out)
119
self.assertEqual(unpacked.size(1 if batch_first else 0), total_length)
120
if total_length_delta == 0:
121
ref_output = no_extra_pad
123
extra_pad = no_extra_pad.new_zeros(
124
self.batch_size, total_length_delta
126
ref_output = torch.cat([no_extra_pad, extra_pad], 1)
128
extra_pad = no_extra_pad.new_zeros(
129
total_length_delta, self.batch_size
131
ref_output = torch.cat([no_extra_pad, extra_pad], 0)
132
self.assertEqual(unpacked, ref_output)
135
for enforce_sorted in (True, False):
136
padded, lengths = self._padded_sequence(torch.IntTensor)
137
a = rnn_utils.pack_padded_sequence(
138
padded, lengths, enforce_sorted=enforce_sorted
141
self.assertIs(a, a.to("cpu"))
142
self.assertIs(a, a.cpu())
143
self.assertIs(a, a.to("cpu", dtype=torch.int32))
144
self.assertEqual(a.long(), a.to(torch.int64))
146
if torch.cuda.is_available():
149
"cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
151
b = a.cuda(device=cuda)
152
self.assertIs(b, b.to(cuda))
153
self.assertIs(b, b.cuda())
154
self.assertEqual(a, b.to("cpu"))
155
self.assertEqual(b, a.to(cuda))
156
self.assertEqual(a, b.to("cpu", dtype=torch.int32))
157
self.assertIs(b, b.to(dtype=torch.int32))
158
self.assertEqual(b.long(), b.to(dtype=torch.int64))
160
def test_to_memory_format(self):
161
m = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, bias=True)
162
m = m.to(memory_format=torch.channels_last)
163
for param in m.parameters():
165
self.assertTrue(param.is_contiguous(memory_format=torch.channels_last))
167
def test_pad_sequence(self):
168
def pad(tensor, length):
173
length - tensor.size(0), *tensor.size()[1:]
179
a = torch.tensor([1, 2, 3])
180
b = torch.tensor([4, 5])
181
c = torch.tensor([6])
184
expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
185
padded = rnn_utils.pad_sequence([b, a, c], True)
186
self.assertEqual(padded, expected)
189
padded = rnn_utils.pad_sequence([b, a, c])
190
self.assertEqual(padded, expected.transpose(0, 1))
193
expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
194
padded = rnn_utils.pad_sequence(
199
self.assertEqual(padded, expected)
202
padded = rnn_utils.pad_sequence(
207
self.assertEqual(padded, expected.transpose(0, 1))
210
expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
211
padded = rnn_utils.pad_sequence([b, a, c], True, 1)
212
self.assertEqual(padded, expected)
215
expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
216
padded = rnn_utils.pad_sequence([a, b, c], True)
217
self.assertEqual(padded, expected)
221
for num_dim in (0, 1, 2, 3):
222
sequences: List[torch.Tensor] = []
223
trailing_dims = [4] * num_dim
224
for i in range(1, maxlen + 1):
226
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
227
random.shuffle(sequences)
229
expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
230
padded = rnn_utils.pad_sequence(sequences, True)
231
self.assertEqual(padded, expected)
234
padded = rnn_utils.pad_sequence(sequences)
235
self.assertEqual(padded, expected.transpose(0, 1))
238
expected = torch.stack(
239
[pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
241
padded = rnn_utils.pad_sequence(
246
self.assertEqual(padded, expected)
249
padded = rnn_utils.pad_sequence(
254
self.assertEqual(padded, expected.transpose(0, 1))
256
def test_unpad_sequence(self):
258
a = torch.tensor([1, 2, 3])
259
b = torch.tensor([4, 5])
260
c = torch.tensor([6])
261
sequences = [a, b, c]
263
lengths = torch.as_tensor([v.size(0) for v in sequences])
264
for batch_first in [True, False]:
265
padded_sequences = rnn_utils.pad_sequence(
266
sequences, batch_first=batch_first
268
unpadded_sequences = rnn_utils.unpad_sequence(
269
padded_sequences, lengths, batch_first=batch_first
271
self.assertEqual(sequences, unpadded_sequences)
275
for num_dim in (0, 1, 2, 3):
277
trailing_dims = [4] * num_dim
278
for i in range(1, maxlen + 1):
280
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
281
random.shuffle(sequences)
283
lengths = torch.as_tensor([v.size(0) for v in sequences])
284
padded_sequences = rnn_utils.pad_sequence(
285
sequences, batch_first=batch_first
287
unpadded_sequences = rnn_utils.unpad_sequence(
288
padded_sequences, lengths, batch_first=batch_first
290
self.assertEqual(sequences, unpadded_sequences)
292
def test_pack_sequence(self):
293
def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
294
padded = rnn_utils.pad_sequence(sequences, batch_first)
295
packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
296
unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
297
self.assertEqual(padded, unpacked[0])
298
pack_padded = rnn_utils.pack_padded_sequence(
299
padded, lengths, batch_first, enforce_sorted
301
self.assertEqual(packed, pack_padded)
304
a = torch.tensor([1, 2, 3])
305
b = torch.tensor([4, 5])
306
c = torch.tensor([6])
307
packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
308
expected = torch.tensor([1, 4, 6, 2, 5, 3])
309
self.assertEqual(packed.batch_sizes, [3, 2, 1])
310
self.assertEqual(packed.data.data, expected)
311
self.assertEqual(packed.sorted_indices, [0, 1, 2])
312
self.assertEqual(packed.unsorted_indices, [0, 1, 2])
314
packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
315
self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
316
self.assertEqual(packed_unsorted.data.data, expected)
317
self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
318
self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])
321
packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
322
self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
323
self.assertEqual(packed_enforce_sorted.data.data, expected)
324
self.assertTrue(packed_enforce_sorted.sorted_indices is None)
325
self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
327
with self.assertRaisesRegex(RuntimeError, "must be sorted in decreasing order"):
328
rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
330
with self.assertRaisesRegex(
331
RuntimeError, "You can pass `enforce_sorted=False`"
333
rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
337
for num_dim in (0, 1, 2, 3):
340
trailing_dims = [4] * num_dim
341
for i in range(maxlen, 0, -1):
343
lengths.append(seq_len)
344
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
345
unsorted_sequences = [s.clone() for s in sequences]
346
random.shuffle(unsorted_sequences)
347
unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]
350
for batch_first in (True, False):
351
for enforce_sorted in (True, False):
352
_compatibility_test(sequences, lengths, batch_first, enforce_sorted)
354
unsorted_sequences, unsorted_sequences_lengths, batch_first
357
def test_unpack_sequence(self):
359
a = torch.tensor([1, 2, 3])
360
b = torch.tensor([4, 5])
361
c = torch.tensor([6])
362
sequences = [a, b, c]
364
packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
365
unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
366
self.assertEqual(sequences, unpacked_sequences)
370
for num_dim in (0, 1, 2, 3):
372
trailing_dims = [4] * num_dim
373
for i in range(1, maxlen + 1):
375
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
376
random.shuffle(sequences)
378
packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
379
unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
380
self.assertEqual(sequences, unpacked_sequences)
382
def test_pack_padded_sequence(self):
383
def generate_test_case(sorted_lengths, should_shuffle):
384
def pad(tensor, length):
388
tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_(),
392
max_length = sorted_lengths[0]
394
sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
395
for i in range(1, max_length + 1)
401
i * 100 + torch.arange(1.0, 5 * l + 1).view(l, 1, 5), max_length
403
for i, l in enumerate(sorted_lengths, 1)
409
torch.arange(1.0, 6) + (i + 1) * 100 + 5 * n
410
for i in range(batch_size)
412
for n, batch_size in enumerate(batch_sizes)
414
expected_data = list(itertools.chain.from_iterable(expected_data))
415
expected_data = torch.stack(expected_data, dim=0)
419
permutation = list(range(len(sorted_lengths)))
420
random.shuffle(permutation)
422
unsorted_indices = torch.tensor(permutation)
423
padded = padded.index_select(1, unsorted_indices)
424
lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
426
unsorted_indices = None
427
lengths = sorted_lengths
430
padded.requires_grad_(),
439
[[10, 8, 4, 2, 2, 2, 1], False],
440
[[11, 10, 8, 6, 4, 3, 1], False],
441
[[11, 10, 8, 6, 4, 3, 1], True],
444
for test_case, batch_first in itertools.product(test_cases, (True, False)):
445
sorted_lengths, should_shuffle = test_case
452
) = generate_test_case(sorted_lengths, should_shuffle)
456
src = src.transpose(0, 1)
459
packed = rnn_utils.pack_padded_sequence(
460
src, lengths, batch_first=batch_first, enforce_sorted=not should_shuffle
462
self.assertEqual(packed.data.data, expected_data)
463
self.assertEqual(packed.batch_sizes, batch_sizes)
464
self.assertEqual(packed.unsorted_indices, unsorted_indices)
467
unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
468
packed, batch_first=batch_first
470
self.assertEqual(unpacked, src)
471
self.assertEqual(unpacked_len, lengths)
474
if padded.grad is not None:
475
padded.grad.data.zero_()
476
grad_output = unpacked.data.clone().normal_()
477
unpacked.backward(grad_output)
479
grad_output.transpose_(0, 1)
480
for i, l in enumerate(lengths):
481
self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
483
self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
486
with self.assertRaisesRegex(
487
RuntimeError, "You can pass `enforce_sorted=False`"
489
packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
490
with self.assertRaisesRegex(RuntimeError, "empty tensor"):
491
packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])
492
with self.assertRaisesRegex(RuntimeError, "empty tensor"):
493
packed = rnn_utils.pack_padded_sequence(
494
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
if __name__ == "__main__":
    # Entry point for running this file directly; the file imports
    # run_tests from torch.testing._internal.common_utils.
    run_tests()