pytorch

Форк
0
/
upsample_op_test.py 
197 строк · 7.1 Кб
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

import unittest

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from caffe2.python import core


class TestUpSample(serial.SerializedTestCase):
    """Tests for the UpsampleBilinear and UpsampleBilinearGradient operators.

    Every test runs two operator variants: one that receives the scale
    factors as the ``height_scale`` / ``width_scale`` arguments, and one that
    receives them as an extra float32 ``scales`` input tensor holding
    ``[height_scale, width_scale]``.
    """

    @staticmethod
    def _output_size(size, scale):
        """Return the upsampled extent ``int(size * scale)``, computed in float32.

        The ``scales`` input tensor below is float32, so the operator works
        with float32 scale values. A float64 product can land on the other
        side of an integer boundary and yield an off-by-one expected shape.
        NOTE(review): assumes the op truncates a float32 ``size * scale`` --
        confirm against the UpsampleBilinear kernel.
        """
        return int(np.float32(size) * np.float32(scale))

    @staticmethod
    def _bilinear_coeffs(in_size, out_size):
        """Return the interpolation taps along one axis.

        Returns a list with one ``(lo, hi, w_lo, w_hi)`` tuple per output
        index in ``range(out_size)``: the two source indices that are blended
        and their weights (``w_lo + w_hi == 1``). Source positions are spaced
        by ``(in_size - 1) / (out_size - 1)``, so the first and last output
        samples map onto the first and last input samples. ``hi`` is clamped
        to ``lo`` at the last source index.
        """
        ratio = (in_size - 1) / float(out_size - 1) if out_size > 1 else 0.0
        taps = []
        for o in range(out_size):
            src = ratio * o
            lo = int(src)
            hi = lo + 1 if lo < in_size - 1 else lo
            w_hi = src - lo
            taps.append((lo, hi, 1.0 - w_hi, w_hi))
        return taps

    @given(height_scale=st.floats(1.0, 4.0) | st.just(2.0),
           width_scale=st.floats(1.0, 4.0) | st.just(2.0),
           height=st.integers(4, 32),
           width=st.integers(4, 32),
           num_channels=st.integers(1, 4),
           batch_size=st.integers(1, 4),
           seed=st.integers(0, 65535),
           **hu.gcs)
    @settings(max_examples=50, deadline=None)
    def test_upsample(self, height_scale, width_scale, height, width,
                      num_channels, batch_size, seed,
                      gc, dc):
        """Check UpsampleBilinear against a NumPy reference, across devices,
        and check its gradient numerically."""
        np.random.seed(seed)

        X = np.random.rand(
            batch_size, num_channels, height, width).astype(np.float32)
        scales = np.array([height_scale, width_scale]).astype(np.float32)

        ops = [
            (
                core.CreateOperator(
                    "UpsampleBilinear",
                    ["X"],
                    ["Y"],
                    width_scale=width_scale,
                    height_scale=height_scale,
                ),
                [X],
            ),
            (
                core.CreateOperator(
                    "UpsampleBilinear",
                    ["X", "scales"],
                    ["Y"],
                ),
                [X, scales],
            ),
        ]

        def ref(X, scales=None):
            # `scales` is accepted only to match the op's input list; the
            # scale values come from the test parameters captured above.
            in_h, in_w = X.shape[2], X.shape[3]
            out_h = self._output_size(in_h, height_scale)
            out_w = self._output_size(in_w, width_scale)
            rows = self._bilinear_coeffs(in_h, out_h)
            cols = self._bilinear_coeffs(in_w, out_w)

            # Every element is written below; zeros (rather than random data)
            # keep the reference deterministic and leave the RNG untouched.
            Y = np.zeros(X.shape[:2] + (out_h, out_w), dtype=np.float32)
            for i, (top, bottom, w_top, w_bottom) in enumerate(rows):
                for j, (left, right, w_left, w_right) in enumerate(cols):
                    Y[:, :, i, j] = (
                        w_top * (w_left * X[:, :, top, left] +
                                 w_right * X[:, :, top, right]) +
                        w_bottom * (w_left * X[:, :, bottom, left] +
                                    w_right * X[:, :, bottom, right]))
            return (Y,)

        for op, inputs in ops:
            self.assertReferenceChecks(gc, op, inputs, ref)
            self.assertDeviceChecks(dc, op, inputs, [0])
            self.assertGradientChecks(gc, op, inputs, 0, [0], stepsize=0.1,
                                      threshold=1e-2)

    @given(height_scale=st.floats(1.0, 4.0) | st.just(2.0),
           width_scale=st.floats(1.0, 4.0) | st.just(2.0),
           height=st.integers(4, 32),
           width=st.integers(4, 32),
           num_channels=st.integers(1, 4),
           batch_size=st.integers(1, 4),
           seed=st.integers(0, 65535),
           **hu.gcs)
    @settings(deadline=10000)
    def test_upsample_grad(self, height_scale, width_scale, height, width,
                           num_channels, batch_size, seed, gc, dc):
        """Check UpsampleBilinearGradient against a NumPy reference and
        across devices."""
        np.random.seed(seed)

        output_height = self._output_size(height, height_scale)
        output_width = self._output_size(width, width_scale)
        X = np.random.rand(batch_size,
                           num_channels,
                           height,
                           width).astype(np.float32)
        dY = np.random.rand(batch_size,
                            num_channels,
                            output_height,
                            output_width).astype(np.float32)
        scales = np.array([height_scale, width_scale]).astype(np.float32)

        ops = [
            (
                core.CreateOperator(
                    "UpsampleBilinearGradient",
                    ["dY", "X"],
                    ["dX"],
                    width_scale=width_scale,
                    height_scale=height_scale,
                ),
                [dY, X],
            ),
            (
                core.CreateOperator(
                    "UpsampleBilinearGradient",
                    ["dY", "X", "scales"],
                    ["dX"],
                ),
                [dY, X, scales],
            ),
        ]

        def ref(dY, X, scales=None):
            # Scatter each output gradient back onto the (up to) four source
            # pixels it was blended from, with the forward-pass weights.
            # When a tap is clamped (hi == lo) the same pixel is hit twice.
            rows = self._bilinear_coeffs(X.shape[2], dY.shape[2])
            cols = self._bilinear_coeffs(X.shape[3], dY.shape[3])

            dX = np.zeros_like(X)
            for i, (top, bottom, w_top, w_bottom) in enumerate(rows):
                for j, (left, right, w_left, w_right) in enumerate(cols):
                    g = dY[:, :, i, j]
                    dX[:, :, top, left] += w_top * w_left * g
                    dX[:, :, top, right] += w_top * w_right * g
                    dX[:, :, bottom, left] += w_bottom * w_left * g
                    dX[:, :, bottom, right] += w_bottom * w_right * g
            return (dX,)

        for op, inputs in ops:
            self.assertDeviceChecks(dc, op, inputs, [0])
            self.assertReferenceChecks(gc, op, inputs, ref)


if __name__ == "__main__":
    unittest.main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.