pytorch

Форк
0
/
test_segment_reductions.py 
577 строк · 22.7 Кб
1
# Owner(s): ["module: scatter & gather ops"]
2

3
from itertools import product
4
from functools import partial
5

6
import numpy as np
7
import torch
8
from torch.testing._internal.common_device_type import (
9
    instantiate_device_type_tests,
10
    dtypes,
11
)
12
from torch.testing._internal.common_utils import (
13
    TestCase,
14
    run_tests,
15
    gradcheck,
16
    parametrize,
17
    skipIfRocm,
18
)
19

20

21
# Names of every reduction exercised by the tests in this module, in the
# order the test loops visit them.
reductions = [
    "max",
    "mean",
    "min",
    "sum",
    "prod",
]
22

23

24
def get_default_value(initial_value, reduction):
    """Return the expected output of an empty segment for ``reduction``.

    An explicit ``initial_value`` (anything other than ``None``) is returned
    unchanged. Otherwise a per-reduction fallback is used: ``-inf`` for
    "max", NaN for "mean", ``+inf`` for "min", ``0.0`` for "sum" and ``1.0``
    for "prod". An unrecognised reduction name yields ``None``.
    """
    if initial_value is None:
        fallbacks = {
            "max": -float("Inf"),
            "mean": float("nan"),
            "min": float("Inf"),
            "sum": 0.0,
            "prod": 1.0,
        }
        return fallbacks.get(reduction)
    return initial_value
37

38

39
class TestSegmentReductions(TestCase):
    """Tests for ``torch._segment_reduce`` with lengths- and offsets-described segments."""

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Check ``torch._segment_reduce`` in both ``lengths`` and ``offsets`` mode.

        Offsets are derived from ``lengths_arr``, so both modes describe the
        same segments and must both produce ``expected_arr``.

        Args:
            reduction: reduction name forwarded as ``reduce``.
            device: device on which every tensor is created.
            dtype: dtype of the data, the expected result and the expected gradient.
            unsafe: forwarded as ``unsafe``.
            axis: forwarded as ``axis``.
            initial_value: forwarded as ``initial`` (may be ``None``).
            data_arr: nested list (or array-like) with the input data.
            lengths_arr: segment lengths.
            expected_arr: expected forward result.
            expected_grad_arr: expected gradient of ``result.sum()`` w.r.t. the data.
            check_backward: when true, also compare the gradient with
                ``expected_grad_arr`` and, for dtypes other than half, bfloat16
                and float, run ``gradcheck``.
            lengths_dtype: dtype of the lengths and of the derived offsets.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # Offsets: prepend a zero along the last dim and cumsum, giving the
        # segment boundaries [0, l0, l0 + l1, ...].
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value)
            if mode == 'lengths':
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            # `continue`, not `return`: the offsets mode must still be checked
            # even when the backward pass is not being verified.
            if not check_backward:
                continue

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            # Fresh leaf so the next mode does not accumulate into this grad.
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types
            # also there is small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, so replace NaNs with 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch._segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        """1-D data containing a NaN, lengths [1, 2, 3, 0], along axis 0 and -1."""
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                # Gradients are only verified when an explicit initial value is given.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_zero_length(self, device, dtypes):
        """Every segment is empty (lengths [0, 0] over empty data).

        Each output slot therefore equals the reduction's default value, or
        the explicit initial value when one is supplied.
        """
        val_dtype, length_type = dtypes
        lengths = [0, 0]
        # A plain list like the other tests; _test_common feeds it to torch.tensor().
        data = []
        axis = 0

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                if initial is not None:
                    if reduction == "min":
                        initial_value = 1000  # some high number
                    elif reduction == "prod":
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                default_value = get_default_value(initial_value, reduction)
                expected_result = [default_value, default_value]
                expected_grad = []
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @skipIfRocm
    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        """2-D data with NaNs in different columns, reduced along axis 0."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    # length_type is passed through so the int64-lengths
                    # parametrization actually exercises int64 lengths.
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        """Segment cases given as ``indptr``, checked via both lengths and offsets."""
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            # Segments run along the last dimension of indptr.
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            # The same segments described as lengths and as offsets must agree.
            for segment_kwargs in ({'lengths': lengths}, {'offsets': indptr}):
                actual_result = torch._segment_reduce(
                    data=data,
                    reduce=reduce,
                    axis=dim,
                    unsafe=True,
                    **segment_kwargs,
                )
                self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    # supply initial values to prevent gradcheck from failing for 0 length segments
                    # where nan/inf are reduction identities that produce nans when calculating
                    # the numerical jacobian
                    initial = 1
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    # A tuple, not a set: positional arguments must keep (data, reduce) order.
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)

                for mode in ('lengths', 'offsets'):
                    self.assertTrue(
                        gradcheck(
                            partial(fn, mode=mode),
                            (data.clone().detach().requires_grad_(True),),
                        )
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        """3-D data of shape (5, 2, 5) with empty first and last segments, reduced along axis 0.

        Expected values for the non-empty segments are computed with NumPy.
        """
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        np_reducers = {
            "max": np.max,
            "mean": np.mean,
            "min": np.min,
            "sum": np.sum,
            "prod": np.prod,
        }
        for reduction in reductions:
            if reduction == "min":
                initial_value = 1000  # some high number (data tops out at 49)
            elif reduction == "prod":
                initial_value = 1
            else:
                initial_value = 0
            np_reduce = np_reducers[reduction]
            fill = np.full((2, 5), initial_value)
            # Segments are [], data[0:2], data[2:5], []; the empty ones hold initial_value.
            expected_result = [
                fill.tolist(),
                np_reduce(data[:2], axis=0).tolist(),
                np_reduce(data[2:], axis=0).tolist(),
                fill.tolist(),
            ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    length_type,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        """With unsafe=False, lengths that do not cover the data along axis are rejected."""
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths (they sum to 5, data has 6 elements)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths (second row sums to 5, not 6)
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
570

571

572

573

574
# Expand the device-generic test class into per-device test classes, which are
# registered in this module's namespace (hence globals()).
instantiate_device_type_tests(TestSegmentReductions, globals())

if __name__ == "__main__":
    # Run directly rather than imported: hand off to PyTorch's test runner.
    run_tests()
578

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.