pytorch

Форк
0
40 строк · 1.2 Кб
1

2

3

4

5

6
from caffe2.python import core
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
from hypothesis import given, settings
10
import hypothesis.strategies as st
11
import numpy as np
12

13
import unittest
14

15

16
@st.composite
17
def _glu_old_input(draw):
18
    dims = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=3))
19
    axis = draw(st.integers(min_value=0, max_value=len(dims)))
20
    # The axis dimension must be divisible by two
21
    axis_dim = 2 * draw(st.integers(min_value=1, max_value=2))
22
    dims.insert(axis, axis_dim)
23
    X = draw(hu.arrays(dims, np.float32, None))
24
    return (X, axis)
25

26

27
class TestGlu(serial.SerializedTestCase):
28
    @given(
29
        X_axis=_glu_old_input(),
30
        **hu.gcs
31
    )
32
    @settings(deadline=10000)
33
    def test_glu_old(self, X_axis, gc, dc):
34
        X, axis = X_axis
35

36
        def glu_ref(X):
37
            x1, x2 = np.split(X, [X.shape[axis] // 2], axis=axis)
38
            Y = x1 * (1. / (1. + np.exp(-x2)))
39
            return [Y]
40

41
        op = core.CreateOperator("Glu", ["X"], ["Y"], dim=axis)
42
        self.assertReferenceChecks(gc, op, [X], glu_ref)
43

44
if __name__ == "__main__":
45
    unittest.main()
46

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.