pytorch

Форк
0
/
aten_test.py 
131 строка · 3.7 Кб
1
from caffe2.python import core
2
from hypothesis import given
3

4
import caffe2.python.hypothesis_test_util as hu
5
import hypothesis.strategies as st
6
import numpy as np
7

8

9
class TestATen(hu.HypothesisTestCase):
10

11
    @given(inputs=hu.tensors(n=2), **hu.gcs)
12
    def test_add(self, inputs, gc, dc):
13
        op = core.CreateOperator(
14
            "ATen",
15
            ["X", "Y"],
16
            ["Z"],
17
            operator="add")
18

19
        def ref(X, Y):
20
            return [X + Y]
21
        self.assertReferenceChecks(gc, op, inputs, ref)
22

23
    @given(inputs=hu.tensors(n=2, dtype=np.float16), **hu.gcs_gpu_only)
24
    def test_add_half(self, inputs, gc, dc):
25
        op = core.CreateOperator(
26
            "ATen",
27
            ["X", "Y"],
28
            ["Z"],
29
            operator="add")
30

31
        def ref(X, Y):
32
            return [X + Y]
33
        self.assertReferenceChecks(gc, op, inputs, ref)
34

35
    @given(inputs=hu.tensors(n=1), **hu.gcs)
36
    def test_pow(self, inputs, gc, dc):
37
        op = core.CreateOperator(
38
            "ATen",
39
            ["S"],
40
            ["Z"],
41
            operator="pow", exponent=2.0)
42

43
        def ref(X):
44
            return [np.square(X)]
45

46
        self.assertReferenceChecks(gc, op, inputs, ref)
47

48
    @given(x=st.integers(min_value=2, max_value=8), **hu.gcs)
49
    def test_sort(self, x, gc, dc):
50
        inputs = [np.random.permutation(x)]
51
        op = core.CreateOperator(
52
            "ATen",
53
            ["S"],
54
            ["Z", "I"],
55
            operator="sort")
56

57
        def ref(X):
58
            return [np.sort(X), np.argsort(X)]
59
        self.assertReferenceChecks(gc, op, inputs, ref)
60

61
    @given(inputs=hu.tensors(n=1), **hu.gcs)
62
    def test_sum(self, inputs, gc, dc):
63
        op = core.CreateOperator(
64
            "ATen",
65
            ["S"],
66
            ["Z"],
67
            operator="sum")
68

69
        def ref(X):
70
            return [np.sum(X)]
71

72
        self.assertReferenceChecks(gc, op, inputs, ref)
73

74
    @given(**hu.gcs)
75
    def test_index_uint8(self, gc, dc):
76
        # Indexing with uint8 is deprecated, but we need to provide backward compatibility for some old models exported through ONNX
77
        op = core.CreateOperator(
78
            "ATen",
79
            ['self', 'mask'],
80
            ["Z"],
81
            operator="index")
82

83
        def ref(self, mask):
84
            return (self[mask.astype(np.bool_)],)
85

86
        tensor = np.random.randn(2, 3, 4).astype(np.float32)
87
        mask = np.array([[1, 0, 0], [1, 1, 0]]).astype(np.uint8)
88

89
        self.assertReferenceChecks(gc, op, [tensor, mask], ref)
90

91
    @given(**hu.gcs)
92
    def test_index_put(self, gc, dc):
93
        op = core.CreateOperator(
94
            "ATen",
95
            ['self', 'indices', 'values'],
96
            ["Z"],
97
            operator="index_put")
98

99
        def ref(self, indices, values):
100
            self[indices] = values
101
            return (self,)
102

103
        tensor = np.random.randn(3, 3).astype(np.float32)
104
        mask = np.array([[True, True, True], [True, False, False], [True, True, False]])
105
        values = np.random.randn(6).astype(np.float32)
106

107
        self.assertReferenceChecks(gc, op, [tensor, mask, values], ref)
108

109
    @given(**hu.gcs)
110
    def test_unique(self, gc, dc):
111
        op = core.CreateOperator(
112
            "ATen",
113
            ['self'],
114
            ["output"],
115
            sorted=True,
116
            return_inverse=True,
117
            # return_counts=False,
118
            operator="_unique")
119

120
        def ref(self):
121
            index, _ = np.unique(self, return_index=False, return_inverse=True, return_counts=False)
122
            return (index,)
123

124
        tensor = np.array([1, 2, 6, 4, 2, 3, 2])
125
        print(ref(tensor))
126
        self.assertReferenceChecks(gc, op, [tensor], ref)
127

128

129
if __name__ == "__main__":
130
    import unittest
131
    unittest.main()
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.