pytorch

Форк
0
/
leaky_relu_test.py 
173 строки · 5.5 Кб
1

2

3

4

5
import numpy as np
6
from hypothesis import given, assume
7
import hypothesis.strategies as st
8

9
from caffe2.python import core, model_helper, utils
10
import caffe2.python.hypothesis_test_util as hu
11

12

13
class TestLeakyRelu(hu.HypothesisTestCase):
14

15
    def _get_inputs(self, N, C, H, W, order):
16
        input_data = np.random.rand(N, C, H, W).astype(np.float32) - 0.5
17

18
        # default step size is 0.05
19
        input_data[np.logical_and(
20
            input_data >= 0, input_data <= 0.051)] = 0.051
21
        input_data[np.logical_and(
22
            input_data <= 0, input_data >= -0.051)] = -0.051
23

24
        if order == 'NHWC':
25
            input_data = utils.NCHW2NHWC(input_data)
26

27
        return input_data,
28

29
    def _get_op(self, device_option, alpha, order, inplace=False):
30
        outputs = ['output' if not inplace else "input"]
31
        op = core.CreateOperator(
32
            'LeakyRelu',
33
            ['input'],
34
            outputs,
35
            alpha=alpha,
36
            device_option=device_option)
37
        return op
38

39
    def _feed_inputs(self, input_blobs, device_option):
40
        names = ['input', 'scale', 'bias']
41
        for name, blob in zip(names, input_blobs):
42
            self.ws.create_blob(name).feed(blob, device_option=device_option)
43

44
    @given(gc=hu.gcs['gc'],
45
           dc=hu.gcs['dc'],
46
           N=st.integers(2, 3),
47
           C=st.integers(2, 3),
48
           H=st.integers(2, 3),
49
           W=st.integers(2, 3),
50
           alpha=st.floats(0, 1),
51
           order=st.sampled_from(['NCHW', 'NHWC']),
52
           seed=st.integers(0, 1000))
53
    def test_leaky_relu_gradients(self, gc, dc, N, C, H, W, order, alpha, seed):
54
        np.random.seed(seed)
55

56
        op = self._get_op(
57
            device_option=gc,
58
            alpha=alpha,
59
            order=order)
60
        input_blobs = self._get_inputs(N, C, H, W, order)
61

62
        self.assertDeviceChecks(dc, op, input_blobs, [0])
63
        self.assertGradientChecks(gc, op, input_blobs, 0, [0])
64

65
    @given(gc=hu.gcs['gc'],
66
           dc=hu.gcs['dc'],
67
           N=st.integers(2, 10),
68
           C=st.integers(3, 10),
69
           H=st.integers(5, 10),
70
           W=st.integers(7, 10),
71
           alpha=st.floats(0, 1),
72
           seed=st.integers(0, 1000))
73
    def test_leaky_relu_layout(self, gc, dc, N, C, H, W, alpha, seed):
74
        outputs = {}
75
        for order in ('NCHW', 'NHWC'):
76
            np.random.seed(seed)
77
            input_blobs = self._get_inputs(N, C, H, W, order)
78
            self._feed_inputs(input_blobs, device_option=gc)
79
            op = self._get_op(
80
                device_option=gc,
81
                alpha=alpha,
82
                order=order)
83
            self.ws.run(op)
84
            outputs[order] = self.ws.blobs['output'].fetch()
85
        np.testing.assert_allclose(
86
            outputs['NCHW'],
87
            utils.NHWC2NCHW(outputs["NHWC"]),
88
            atol=1e-4,
89
            rtol=1e-4)
90

91
    @given(gc=hu.gcs['gc'],
92
           dc=hu.gcs['dc'],
93
           N=st.integers(2, 10),
94
           C=st.integers(3, 10),
95
           H=st.integers(5, 10),
96
           W=st.integers(7, 10),
97
           order=st.sampled_from(['NCHW', 'NHWC']),
98
           alpha=st.floats(0, 1),
99
           seed=st.integers(0, 1000),
100
           inplace=st.booleans())
101
    def test_leaky_relu_reference_check(self, gc, dc, N, C, H, W, order, alpha,
102
                                        seed, inplace):
103
        np.random.seed(seed)
104

105
        if order != "NCHW":
106
            assume(not inplace)
107

108
        inputs = self._get_inputs(N, C, H, W, order)
109
        op = self._get_op(
110
            device_option=gc,
111
            alpha=alpha,
112
            order=order,
113
            inplace=inplace)
114

115
        def ref(input_blob):
116
            result = input_blob.copy()
117
            result[result < 0] *= alpha
118
            return result,
119

120
        self.assertReferenceChecks(gc, op, inputs, ref)
121

122
    @given(gc=hu.gcs['gc'],
123
           dc=hu.gcs['dc'],
124
           N=st.integers(2, 10),
125
           C=st.integers(3, 10),
126
           H=st.integers(5, 10),
127
           W=st.integers(7, 10),
128
           order=st.sampled_from(['NCHW', 'NHWC']),
129
           alpha=st.floats(0, 1),
130
           seed=st.integers(0, 1000))
131
    def test_leaky_relu_device_check(self, gc, dc, N, C, H, W, order, alpha,
132
                                     seed):
133
        np.random.seed(seed)
134

135
        inputs = self._get_inputs(N, C, H, W, order)
136
        op = self._get_op(
137
            device_option=gc,
138
            alpha=alpha,
139
            order=order)
140

141
        self.assertDeviceChecks(dc, op, inputs, [0])
142

143
    @given(N=st.integers(2, 10),
144
           C=st.integers(3, 10),
145
           H=st.integers(5, 10),
146
           W=st.integers(7, 10),
147
           order=st.sampled_from(['NCHW', 'NHWC']),
148
           alpha=st.floats(0, 1),
149
           seed=st.integers(0, 1000))
150
    def test_leaky_relu_model_helper_helper(self, N, C, H, W, order, alpha, seed):
151
        np.random.seed(seed)
152
        arg_scope = {'order': order}
153
        model = model_helper.ModelHelper(name="test_model", arg_scope=arg_scope)
154
        model.LeakyRelu(
155
            'input',
156
            'output',
157
            alpha=alpha)
158

159
        input_blob = np.random.rand(N, C, H, W).astype(np.float32)
160
        if order == 'NHWC':
161
            input_blob = utils.NCHW2NHWC(input_blob)
162

163
        self.ws.create_blob('input').feed(input_blob)
164

165
        self.ws.create_net(model.param_init_net).run()
166
        self.ws.create_net(model.net).run()
167

168
        output_blob = self.ws.blobs['output'].fetch()
169
        if order == 'NHWC':
170
            output_blob = utils.NHWC2NCHW(output_blob)
171

172
        assert output_blob.shape == (N, C, H, W)
173

174

175
if __name__ == '__main__':
176
    import unittest
177
    unittest.main()
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.