pytorch

Форк
0
/
im2col_col2im_test.py 
133 строки · 4.2 Кб
1

2

3

4

5

6
from caffe2.python import core
7
from hypothesis import assume, given, settings
8

9
import caffe2.python.hypothesis_test_util as hu
10
import hypothesis.strategies as st
11
import numpy as np
12

13

14
class TestReduceFrontSum(hu.HypothesisTestCase):
15
    @given(batch_size=st.integers(1, 3),
16
           stride=st.integers(1, 3),
17
           pad=st.integers(0, 3),
18
           kernel=st.integers(1, 5),
19
           dilation=st.integers(1, 3),
20
           size=st.integers(7, 10),
21
           channels=st.integers(1, 8),
22
           **hu.gcs)
23
    def test_im2col_layout(self, batch_size, stride, pad, kernel, dilation,
24
                           size, channels, gc, dc):
25

26
        dkernel = (dilation * (kernel - 1) + 1)
27
        assume(size >= dkernel)
28

29
        NCHW_TO_NHWC = (0, 2, 3, 1)
30
        NHWC_TO_NCHW = (0, 3, 1, 2)
31
        COL_NHWC_TO_NCHW = (4, 2, 3, 0, 1)
32

33
        N = batch_size
34
        C = channels
35
        H = size
36
        W = size
37

38
        out_h = int((H + (2 * pad) - dkernel) / stride + 1)
39
        out_w = int((W + (2 * pad) - dkernel) / stride + 1)
40

41
        im_nchw = np.random.rand(N, C, H, W).astype(np.float32) - 0.5
42
        im_nhwc = im_nchw.transpose(NCHW_TO_NHWC)
43

44
        op_im2col_nchw = core.CreateOperator(
45
            "Im2Col",
46
            ["im_nchw"], ["col_nchw"],
47
            stride=stride,
48
            kernel=kernel,
49
            dilation=dilation,
50
            pad=pad,
51
            order="NCHW",
52
            device_option=gc)
53

54
        op_im2col_nhwc = core.CreateOperator(
55
            "Im2Col",
56
            ["im_nhwc"], ["col_nhwc"],
57
            stride=stride,
58
            kernel=kernel,
59
            dilation=dilation,
60
            pad=pad,
61
            order="NHWC",
62
            device_option=gc)
63

64
        self.ws.create_blob("im_nchw").feed(im_nchw, device_option=gc)
65
        self.ws.create_blob("im_nhwc").feed(im_nhwc, device_option=gc)
66
        self.ws.run(op_im2col_nchw)
67
        self.ws.run(op_im2col_nhwc)
68

69
        # there is probably a clever way to spell this in np
70
        col_nchw = self.ws.blobs["col_nchw"].fetch()
71
        col_nhwc = self.ws.blobs["col_nhwc"].fetch()
72
        col_nchw_ = col_nchw.reshape(N, C, kernel, kernel, out_h, out_w)
73
        col_nhwc_ = col_nhwc.reshape(N, out_h, out_w, kernel, kernel, C)
74
        for i in range(0, N):
75
            np.testing.assert_allclose(
76
                col_nchw_[i],
77
                col_nhwc_[i].transpose(COL_NHWC_TO_NCHW),
78
                atol=1e-4,
79
                rtol=1e-4)
80

81
        op_col2im_nchw = core.CreateOperator(
82
            "Col2Im",
83
            ["col_nchw", "im_nchw"],
84
            ["out_nchw"],
85
            stride=stride,
86
            kernel=kernel,
87
            dilation=dilation,
88
            pad=pad,
89
            order="NCHW",
90
            device_option=gc)
91

92
        op_col2im_nhwc = core.CreateOperator(
93
            "Col2Im",
94
            ["col_nhwc", "im_nhwc"],
95
            ["out_nhwc"],
96
            stride=stride,
97
            kernel=kernel,
98
            dilation=dilation,
99
            pad=pad,
100
            order="NHWC",
101
            device_option=gc)
102

103
        self.ws.run(op_col2im_nchw)
104
        self.ws.run(op_col2im_nhwc)
105

106
        out_nchw = self.ws.blobs["out_nchw"].fetch()
107
        out_nhwc = self.ws.blobs["out_nhwc"].fetch()
108
        np.testing.assert_allclose(
109
            out_nchw,
110
            out_nhwc.transpose(NHWC_TO_NCHW),
111
            atol=1e-4,
112
            rtol=1e-4)
113

114
    @given(batch_size=st.integers(1, 3),
115
           stride=st.integers(1, 3),
116
           pad=st.integers(0, 3),
117
           kernel=st.integers(1, 5),
118
           dilation=st.integers(1, 3),
119
           size=st.integers(7, 10),
120
           channels=st.integers(1, 8),
121
           order=st.sampled_from(["NCHW"]),
122
           **hu.gcs)
123
    @settings(deadline=10000)
124
    def test_col2im_gradients(self, batch_size, stride, pad, kernel,
125
                              dilation, size, channels, order, gc, dc):
126
        assume(size >= dilation * (kernel - 1) + 1)
127
        op = core.CreateOperator(
128
            "Im2Col",
129
            ["X"], ["Y"],
130
            stride=stride,
131
            kernel=kernel,
132
            dilation=dilation,
133
            pad=pad,
134
            order=order,
135
            device_option=gc)
136
        X = np.random.rand(batch_size, channels, size, size).astype(np.float32)
137
        self.assertGradientChecks(gc, op, [X], 0, [0])
138
        return
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.