pytorch

Форк
0
/
sparse_ops_test.py 
84 строки · 3.4 Кб
1

2

3

4

5

6
from caffe2.python import core
7
from caffe2.python.test_util import rand_array
8
import caffe2.python.hypothesis_test_util as hu
9
import caffe2.python.serialized_test.serialized_test_util as serial
10
from hypothesis import given, settings
11
import hypothesis.strategies as st
12
import numpy as np
13

14
class TestScatterOps(serial.SerializedTestCase):
15
    # TODO(dzhulgakov): add test cases for failure scenarios
16
    @given(num_args=st.integers(1, 5),
17
           first_dim=st.integers(1, 20),
18
           index_dim=st.integers(1, 10),
19
           extra_dims=st.lists(st.integers(1, 4), min_size=0, max_size=3),
20
           ind_type=st.sampled_from([np.int32, np.int64]),
21
           data_type=st.sampled_from([np.float32, np.float64]),
22
           **hu.gcs)
23
    @settings(deadline=10000)
24
    def testScatterWeightedSum(
25
            self, num_args, first_dim, index_dim, extra_dims, ind_type, data_type, gc, dc):
26
        ins = ['data', 'w0', 'indices']
27
        for i in range(1, num_args + 1):
28
            ins.extend(['x' + str(i), 'w' + str(i)])
29
        op = core.CreateOperator(
30
            'ScatterWeightedSum',
31
            ins,
32
            ['data'],
33
            device_option=gc)
34
        def ref(d, w0, ind, *args):
35
            r = d.copy()
36
            for i in ind:
37
                r[i] *= w0
38
            for i in range(0, len(args), 2):
39
                x = args[i]
40
                w = args[i+1]
41
                for i, j in enumerate(ind):
42
                    r[j] += w * x[i]
43
            return [r]
44

45
        d = rand_array(first_dim, *extra_dims)
46
        ind = np.random.randint(0, first_dim, index_dim).astype(ind_type)
47
        # ScatterWeightedSumOp only supports w0=1.0 in CUDAContext
48
        # And it only support float32 data in CUDAContext
49
        if(gc == hu.gpu_do or gc == hu.hip_do):
50
            w0 = np.array(1.0).astype(np.float32)
51
            data_type = np.float32
52
        else:
53
            w0 = rand_array()
54
        d = d.astype(data_type)
55
        inputs = [d, w0, ind]
56
        for _ in range(1, num_args + 1):
57
            x = rand_array(index_dim, *extra_dims).astype(data_type)
58
            w = rand_array()
59
            inputs.extend([x,w])
60
        self.assertReferenceChecks(gc, op, inputs, ref, threshold=1e-3)
61

62
    @given(first_dim=st.integers(1, 20),
63
           index_dim=st.integers(1, 10),
64
           extra_dims=st.lists(st.integers(1, 4), min_size=0, max_size=3),
65
           data_type=st.sampled_from([np.float16, np.float32, np.int32, np.int64]),
66
           ind_type=st.sampled_from([np.int32, np.int64]),
67
           **hu.gcs)
68
    @settings(deadline=10000)
69
    def testScatterAssign(
70
            self, first_dim, index_dim, extra_dims, data_type, ind_type, gc, dc):
71
        op = core.CreateOperator('ScatterAssign',
72
                                 ['data', 'indices', 'slices'], ['data'])
73
        def ref(d, ind, x):
74
            r = d.copy()
75
            r[ind] = x
76
            return [r]
77

78
        # let's have indices unique
79
        if first_dim < index_dim:
80
            first_dim, index_dim = index_dim, first_dim
81
        d = (rand_array(first_dim, *extra_dims) * 10).astype(data_type)
82
        ind = np.random.choice(first_dim, index_dim,
83
                               replace=False).astype(ind_type)
84
        x = (rand_array(index_dim, *extra_dims) * 10).astype(data_type)
85
        self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3, ensure_outputs_are_inferred=True)
86

87
if __name__ == "__main__":
88
    import unittest
89
    unittest.main()
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.