pytorch

Форк
0
/
elementwise_logical_ops_test.py 
132 строки · 4.5 Кб
1

2

3

4

5

6
from caffe2.python import core
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
from hypothesis import given, settings
10
import hypothesis.strategies as st
11
import numpy as np
12
import unittest
13

14

15
def mux(select, left, right):
16
    return [np.vectorize(lambda c, x, y: x if c else y)(select, left, right)]
17

18

19
def rowmux(select_vec, left, right):
20
    select = [[s] * len(left) for s in select_vec]
21
    return mux(select, left, right)
22

23

24
class TestWhere(serial.SerializedTestCase):
25

26
    def test_reference(self):
27
        self.assertTrue((
28
            np.array([1, 4]) == mux([True, False],
29
                                    [1, 2],
30
                                    [3, 4])[0]
31
        ).all())
32
        self.assertTrue((
33
            np.array([[1], [4]]) == mux([[True], [False]],
34
                                        [[1], [2]],
35
                                        [[3], [4]])[0]
36
        ).all())
37

38
    @given(N=st.integers(min_value=1, max_value=10),
39
           engine=st.sampled_from(["", "CUDNN"]),
40
           **hu.gcs_cpu_only)
41
    @settings(deadline=10000)
42
    def test_where(self, N, gc, dc, engine):
43
        C = np.random.rand(N).astype(bool)
44
        X = np.random.rand(N).astype(np.float32)
45
        Y = np.random.rand(N).astype(np.float32)
46
        op = core.CreateOperator("Where", ["C", "X", "Y"], ["Z"], engine=engine)
47
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
48
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)
49

50
    @given(N=st.integers(min_value=1, max_value=10),
51
           engine=st.sampled_from(["", "CUDNN"]),
52
           **hu.gcs_cpu_only)
53
    @settings(deadline=10000)
54
    def test_where_dim2(self, N, gc, dc, engine):
55
        C = np.random.rand(N, N).astype(bool)
56
        X = np.random.rand(N, N).astype(np.float32)
57
        Y = np.random.rand(N, N).astype(np.float32)
58
        op = core.CreateOperator("Where", ["C", "X", "Y"], ["Z"], engine=engine)
59
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
60
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)
61

62

63
class TestRowWhere(hu.HypothesisTestCase):
64

65
    def test_reference(self):
66
        self.assertTrue((
67
            np.array([1, 2]) == rowmux([True],
68
                                       [1, 2],
69
                                       [3, 4])[0]
70
        ).all())
71
        self.assertTrue((
72
            np.array([[1, 2], [7, 8]]) == rowmux([True, False],
73
                                                 [[1, 2], [3, 4]],
74
                                                 [[5, 6], [7, 8]])[0]
75
        ).all())
76

77
    @given(N=st.integers(min_value=1, max_value=10),
78
           engine=st.sampled_from(["", "CUDNN"]),
79
           **hu.gcs_cpu_only)
80
    def test_rowwhere(self, N, gc, dc, engine):
81
        C = np.random.rand(N).astype(bool)
82
        X = np.random.rand(N).astype(np.float32)
83
        Y = np.random.rand(N).astype(np.float32)
84
        op = core.CreateOperator(
85
            "Where",
86
            ["C", "X", "Y"],
87
            ["Z"],
88
            broadcast_on_rows=True,
89
            engine=engine,
90
        )
91
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
92
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)
93

94
    @given(N=st.integers(min_value=1, max_value=10),
95
           engine=st.sampled_from(["", "CUDNN"]),
96
           **hu.gcs_cpu_only)
97
    def test_rowwhere_dim2(self, N, gc, dc, engine):
98
        C = np.random.rand(N).astype(bool)
99
        X = np.random.rand(N, N).astype(np.float32)
100
        Y = np.random.rand(N, N).astype(np.float32)
101
        op = core.CreateOperator(
102
            "Where",
103
            ["C", "X", "Y"],
104
            ["Z"],
105
            broadcast_on_rows=True,
106
            engine=engine,
107
        )
108
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
109
        self.assertReferenceChecks(gc, op, [C, X, Y], rowmux)
110

111

112
class TestIsMemberOf(serial.SerializedTestCase):
113

114
    @given(N=st.integers(min_value=1, max_value=10),
115
           engine=st.sampled_from(["", "CUDNN"]),
116
           **hu.gcs_cpu_only)
117
    @settings(deadline=10000)
118
    def test_is_member_of(self, N, gc, dc, engine):
119
        X = np.random.randint(10, size=N).astype(np.int64)
120
        values = [0, 3, 4, 6, 8]
121
        op = core.CreateOperator(
122
            "IsMemberOf",
123
            ["X"],
124
            ["Y"],
125
            value=np.array(values),
126
            engine=engine,
127
        )
128
        self.assertDeviceChecks(dc, op, [X], [0])
129
        values = set(values)
130

131
        def test(x):
132
            return [np.vectorize(lambda x: x in values)(x)]
133
        self.assertReferenceChecks(gc, op, [X], test)
134

135

136
if __name__ == "__main__":
137
    unittest.main()
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.