pytorch

Форк
0
/
depthwise_3x3_conv_test.py 
51 строка · 1.8 Кб
1

2

3

4

5

6
import numpy as np
7
import caffe2.python.hypothesis_test_util as hu
8
from caffe2.python import core, utils
9
from hypothesis import given, settings
10
import hypothesis.strategies as st
11

12

13
class Depthwise3x3ConvOpsTest(hu.HypothesisTestCase):
14
    @given(pad=st.integers(0, 1),
15
           kernel=st.integers(3, 3),
16
           size=st.integers(4, 8),
17
           channels=st.integers(2, 4),
18
           batch_size=st.integers(1, 1),
19
           order=st.sampled_from(["NCHW"]),
20
           engine=st.sampled_from(["DEPTHWISE_3x3"]),
21
           use_bias=st.booleans(),
22
           **hu.gcs)
23
    @settings(deadline=10000)
24
    def test_convolution_gradients(self, pad, kernel, size,
25
                                   channels, batch_size,
26
                                   order, engine, use_bias, gc, dc):
27
        op = core.CreateOperator(
28
            "Conv",
29
            ["X", "w", "b"] if use_bias else ["X", "w"],
30
            ["Y"],
31
            kernel=kernel,
32
            pad=pad,
33
            group=channels,
34
            order=order,
35
            engine=engine,
36
        )
37
        X = np.random.rand(
38
            batch_size, size, size, channels).astype(np.float32) - 0.5
39
        w = np.random.rand(
40
            channels, kernel, kernel, 1).astype(np.float32)\
41
            - 0.5
42
        b = np.random.rand(channels).astype(np.float32) - 0.5
43
        if order == "NCHW":
44
            X = utils.NHWC2NCHW(X)
45
            w = utils.NHWC2NCHW(w)
46

47
        inputs = [X, w, b] if use_bias else [X, w]
48
        # Error handling path.
49
        if size + pad + pad < kernel or size + pad + pad < kernel:
50
            with self.assertRaises(RuntimeError):
51
                self.assertDeviceChecks(dc, op, inputs, [0])
52
            return
53

54
        self.assertDeviceChecks(dc, op, inputs, [0])
55
        for i in range(len(inputs)):
56
            self.assertGradientChecks(gc, op, inputs, i, [0])
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.