3
from itertools import product
4
from functools import partial
8
from torch.testing._internal.common_device_type import (
9
instantiate_device_type_tests,
12
from torch.testing._internal.common_utils import (
21
# All reduction modes of torch._segment_reduce exercised by the tests below.
reductions = ["max", "mean", "min", "sum", "prod"]
24
def get_default_value(initial_value, reduction):
    """Return the value an empty (zero-length) segment reduces to.

    An explicitly supplied ``initial_value`` always wins; otherwise the
    reduction's mathematical identity is used.  The identities match the
    expected outputs asserted elsewhere in this file for empty segments
    (sum -> 0, prod -> 1, mean -> nan, min -> inf, max -> -inf).

    NOTE(review): the extraction of this chunk dropped the original return
    statements; they were reconstructed from the per-reduction expected
    values in the scatter test cases — confirm against file history.

    Args:
        initial_value: user-supplied initial value, or None.
        reduction: one of the strings in ``reductions``.
    """
    if initial_value is not None:
        return initial_value
    if reduction == "max":
        # Identity for max: anything beats -inf.
        return -float("Inf")
    elif reduction == "mean":
        # Mean of no elements is undefined.
        return float("nan")
    elif reduction == "min":
        return float("Inf")
    elif reduction == "sum":
        return 0.0
    elif reduction == "prod":
        return 1.0
# NOTE(review): this chunk is a sparse/garbled extraction — the bare numeric
# lines below are original-file line numbers fused into the text, and many
# statements of _test_common are missing.  Code is left byte-identical;
# comments only describe what the visible lines establish.
class TestSegmentReductions(TestCase):
53
lengths_dtype=torch.int,
55
# Build lengths tensor from the caller-supplied python list.
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
57
zeros_shape = list(lengths.shape)
59
# Offsets are the exclusive prefix sum of lengths (leading zero prepended),
# so both calling conventions describe the same segmentation.
offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
67
expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
68
expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
69
# Exercise both the `lengths=` and `offsets=` keyword forms.
for mode in ['lengths', 'offsets']:
70
segment_reduce_kwargs = dict(
73
initial=initial_value)
74
if (mode == 'lengths'):
75
segment_reduce_kwargs['lengths'] = lengths
77
segment_reduce_kwargs['offsets'] = offsets
78
actual_result = torch._segment_reduce(
81
**segment_reduce_kwargs
84
# Loose tolerances: low-precision dtypes (half/bfloat16) are tested too.
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
87
if not check_backward:
91
actual_result.sum().backward()
93
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
95
# Fresh leaf tensor so gradcheck below sees clean autograd state.
data = data.clone().detach().requires_grad_(True)
99
# gradcheck only runs for double-ish precision; skip low-precision dtypes.
if dtype not in [torch.half, torch.bfloat16, torch.float]:
101
# Replace NaNs for the numerical-jacobian path — TODO confirm intent.
d_non_nan = np.nan_to_num(data_arr, nan=10)
102
new_data = torch.tensor(
111
lambda x: torch._segment_reduce(
114
**segment_reduce_kwargs
122
# NOTE(review): fragments of @dtypes(*product(...)) decorator + test body;
# bare numeric lines are original line numbers from a garbled extraction.
# Code left byte-identical.
(torch.half, torch.bfloat16, torch.float, torch.double),
123
(torch.int, torch.int64),
126
def test_simple_1d(self, device, dtypes):
127
val_dtype, length_type = dtypes
128
lengths = [1, 2, 3, 0]
129
# Segments: [1], [nan, 3], [4, 5, 5], [] — NaN propagation is tested.
data = [1, float("nan"), 3, 4, 5, 5]
131
for reduction in reductions:
132
for initial in [0, None]:
133
# Backward is only checked when an explicit initial value is given.
check_backward = True if initial is not None else False
134
initial_value = initial
135
default_value = get_default_value(initial_value, reduction)
136
if reduction == "max":
137
expected_result = [1, float("nan"), 5, default_value]
138
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
139
elif reduction == "mean":
140
expected_result = [1, float("nan"), 4.666, default_value]
141
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
142
elif reduction == "min":
143
if initial is not None:
145
default_value = get_default_value(initial_value, reduction)
146
expected_result = [1, float("nan"), 4, default_value]
147
expected_grad = [1.0, 1.0, 0, 1, 0, 0]
148
elif reduction == "sum":
149
expected_result = [1, float("nan"), 14, default_value]
150
expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
151
elif reduction == "prod":
152
if initial is not None:
154
default_value = get_default_value(initial_value, reduction)
155
# With initial=0 the product expectations differ — see the else arm.
expected_result = [2, float("nan"), 200, default_value]
156
expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
158
expected_result = [1, float("nan"), 100, default_value]
159
expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
161
for unsafe in [True, False]:
179
# NOTE(review): fragments of test_simple_zero_length — every segment has
# length zero, so every reduction must produce its default/identity value.
# Bare numeric lines are original line numbers; code left byte-identical.
(torch.half, torch.bfloat16, torch.float, torch.double),
180
(torch.int, torch.int64),
183
def test_simple_zero_length(self, device, dtypes):
184
val_dtype, length_type = dtypes
188
for reduction in reductions:
189
for initial in [0, None]:
190
check_backward = True if initial is not None else False
191
initial_value = initial
192
default_value = get_default_value(initial_value, reduction)
193
if reduction == "max":
194
expected_result = [default_value, default_value]
196
elif reduction == "mean":
197
expected_result = [default_value, default_value]
199
elif reduction == "min":
200
if initial is not None:
202
default_value = get_default_value(initial_value, reduction)
203
expected_result = [default_value, default_value]
205
elif reduction == "sum":
206
expected_result = [default_value, default_value]
208
elif reduction == "prod":
209
if initial is not None:
211
default_value = get_default_value(initial_value, reduction)
212
expected_result = [default_value, default_value]
215
expected_result = [default_value, default_value]
218
for unsafe in [True, False]:
237
# NOTE(review): fragments of test_multi_d_simple — 2-D rows reduced per
# segment along dim 0.  Bare numeric lines are original line numbers from a
# garbled extraction; code left byte-identical.
(torch.half, torch.bfloat16, torch.float, torch.double),
238
(torch.int, torch.int64),
241
def test_multi_d_simple(self, device, dtypes):
242
val_dtype, length_type = dtypes
244
lengths = [1, 2, 3, 0]
245
# NaNs placed in different columns to test per-element propagation.
data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]
247
for reduction in reductions:
248
for initial in [0, None]:
249
check_backward = True if initial is not None else False
250
initial_value = initial
251
default_value = get_default_value(initial_value, reduction)
252
if reduction == "max":
255
[float("nan"), float("nan")],
257
[default_value, default_value],
267
elif reduction == "mean":
270
[float("nan"), float("nan")],
272
[default_value, default_value],
282
elif reduction == "min":
283
if initial is not None:
285
default_value = get_default_value(initial_value, reduction)
288
[float("nan"), float("nan")],
290
[default_value, default_value],
300
elif reduction == "sum":
303
[float("nan"), float("nan")],
305
[default_value, default_value],
315
elif reduction == "prod":
316
if initial is not None:
318
default_value = get_default_value(initial_value, reduction)
321
[float("nan"), float("nan")],
323
[default_value, default_value],
336
[float("nan"), float("nan")],
338
[default_value, default_value],
348
for unsafe in [True, False]:
365
# NOTE(review): fragments of test_pytorch_scatter_test_cases — test vectors
# ported from the pytorch_scatter project.  Each dict gives 'src' data,
# 'index'/'indptr' segmentations, and the expected output per reduction.
# Empty segments show the identities: sum->0, prod->1, mean->nan,
# min->inf, max->-inf.  Bare numeric lines are original line numbers from a
# garbled extraction; code left byte-identical.
(torch.half, torch.bfloat16, torch.float, torch.double),
366
(torch.int, torch.int64),
369
@parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
370
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
371
val_dtype, length_dtype = dtypes
375
'src': [1, 2, 3, 4, 5, 6],
376
'index': [0, 0, 1, 1, 1, 3],
377
'indptr': [0, 2, 5, 5, 6],
378
'sum': [3, 12, 0, 6],
379
'prod': [2, 60, 1, 6],
380
'mean': [1.5, 4, float('nan'), 6],
381
'min': [1, 3, float('inf'), 6],
382
'max': [2, 5, -float('inf'), 6],
385
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
386
'index': [0, 0, 1, 1, 1, 3],
387
'indptr': [0, 2, 5, 5, 6],
388
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
389
'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
390
'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
391
'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
392
'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
395
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
396
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
397
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
398
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
399
'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
400
'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
401
'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
402
'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
405
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
406
'index': [[0, 0, 1], [0, 2, 2]],
407
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
408
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
409
'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
410
'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
411
[[7, 9], [float('nan'), float('nan')], [11, 12]]],
412
'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
413
[[7, 9], [float('inf'), float('inf')], [10, 11]]],
414
'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
415
[[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
418
'src': [[1, 3], [2, 4]],
419
'index': [[0, 0], [0, 0]],
420
'indptr': [[0, 2], [0, 2]],
428
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
429
'index': [[0, 0], [0, 0]],
430
'indptr': [[0, 2], [0, 2]],
431
'sum': [[[4, 4]], [[6, 6]]],
432
'prod': [[[3, 3]], [[8, 8]]],
433
'mean': [[[2, 2]], [[3, 3]]],
434
'min': [[[1, 1]], [[2, 2]]],
435
'max': [[[3, 3]], [[4, 4]]],
439
data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
440
indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
441
# Segmentation runs along the last indptr dimension.
dim = indptr.ndim - 1
443
# lengths are recovered from indptr; adjacent differences give segment sizes.
lengths = torch.diff(indptr, dim=dim)
444
expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)
446
actual_result = torch._segment_reduce(
453
self.assertEqual(actual_result, expected)
456
# Second call checks the other (lengths vs offsets) calling convention.
actual_result = torch._segment_reduce(
463
self.assertEqual(actual_result, expected)
465
# gradcheck requires double precision.
if val_dtype == torch.float64:
466
def fn(x, mode='lengths'):
    # Closure used by gradcheck below; `reduce`, `dim`, `lengths` and
    # `indptr` are captured from the enclosing test body.
    # Finite initial values keep gradcheck's numerical jacobian sane for
    # zero-length segments (the inf/nan reduction identities would
    # otherwise poison it).
    # NOTE(review): the default/min branches were lost in this garbled
    # chunk; 1/0 and +/-1000 reconstructed — confirm against history.
    initial = 1 if reduce == 'prod' else 0
    if reduce == 'min':
        initial = 1000
    elif reduce == 'max':
        initial = -1000
    # BUG FIX: this was the set literal `{x, reduce}`.  Set iteration
    # order is arbitrary, so `*segment_reduce_args` could pass the
    # arguments to torch._segment_reduce as (reduce, x) instead of
    # (x, reduce).  A list keeps the positional order deterministic.
    segment_reduce_args = [x, reduce]
    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
    if mode == 'lengths':
        segment_reduce_kwargs[mode] = lengths
    elif mode == 'offsets':
        segment_reduce_kwargs[mode] = indptr
    return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
481
# Gradcheck both calling conventions of the closure above (data is a
# float64 leaf tensor here; gradcheck accepts a single-tensor input).
self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
482
self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
487
# NOTE(review): fragments of test_multi_d — 3-D input (5, 2, 5) reduced per
# segment along dim 0, expectations built with numpy.  `initial_value` is
# referenced but its assignment was lost in this garbled chunk — presumably
# set just above the loop; confirm against history.  Bare numeric lines are
# original line numbers; code left byte-identical.
(torch.half, torch.bfloat16, torch.float, torch.double),
488
(torch.int, torch.int64),
491
def test_multi_d(self, device, dtypes):
492
val_dtype, length_type = dtypes
494
# Segments of sizes 0, 2, 3, 0 over the leading axis of a (5, 2, 5) block.
lengths = [0, 2, 3, 0]
495
data = np.arange(50).reshape(5, 2, 5).tolist()
499
check_backward = False
501
for reduction in reductions:
503
if reduction == "max":
505
# Empty leading/trailing segments reduce to full (2, 5) initial blocks.
np.full((2, 5), initial_value).tolist(),
506
np.max(data[:2], axis=0).tolist(),
507
np.max(data[2:], axis=0).tolist(),
508
np.full((2, 5), initial_value).tolist(),
510
elif reduction == "mean":
512
np.full((2, 5), initial_value).tolist(),
513
np.mean(data[:2], axis=0).tolist(),
514
np.mean(data[2:], axis=0).tolist(),
515
np.full((2, 5), initial_value).tolist(),
517
elif reduction == "min":
520
np.full((2, 5), initial_value).tolist(),
521
np.min(data[:2], axis=0).tolist(),
522
np.min(data[2:], axis=0).tolist(),
523
np.full((2, 5), initial_value).tolist(),
525
elif reduction == "sum":
527
np.full((2, 5), initial_value).tolist(),
528
np.sum(data[:2], axis=0).tolist(),
529
np.sum(data[2:], axis=0).tolist(),
530
np.full((2, 5), initial_value).tolist(),
532
elif reduction == "prod":
535
np.full((2, 5), initial_value).tolist(),
536
np.prod(data[:2], axis=0).tolist(),
537
np.prod(data[2:], axis=0).tolist(),
538
np.full((2, 5), initial_value).tolist(),
540
for unsafe in [True, False]:
555
@dtypes(torch.int, torch.int64)
def test_unsafe_flag(self, device, dtype):
    """unsafe=False must validate lengths against the reduced dimension.

    With ``unsafe=False`` torch._segment_reduce checks that each row of
    ``lengths`` sums to the size of ``data`` along ``axis`` and raises a
    RuntimeError otherwise; this test triggers that check for both 1-D
    and multi-D lengths.
    """
    # NOTE(review): this binding was lost in the garbled chunk; `dtype`
    # is the lengths dtype here (the parametrized @dtypes values are int
    # types) — confirm against history.
    length_type = dtype

    # lengths sum to 5 but data has 6 elements along axis 0 -> must raise.
    lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
    data = torch.arange(6, dtype=torch.float, device=device)
    with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
        torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

    # Multi-D lengths: row sums are 6 and 5, but each must equal
    # data.size(1) == 6, so the second row fails validation.
    nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
    nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
    with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
        torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
574
# Generate per-device concrete test classes (TestSegmentReductionsCPU, ...).
instantiate_device_type_tests(TestSegmentReductions, globals())
576
if __name__ == "__main__":