pytorch

Форк
0
/
flatten_op_test.py 
33 строки · 922.0 Байт
1

2

3

4

5

6
from hypothesis import given
7
import numpy as np
8

9
from caffe2.python import core
10
import caffe2.python.hypothesis_test_util as hu
11

12

13
class TestFlatten(hu.HypothesisTestCase):
14
    @given(X=hu.tensor(min_dim=2, max_dim=4),
15
           **hu.gcs)
16
    def test_flatten(self, X, gc, dc):
17
        for axis in range(X.ndim + 1):
18
            op = core.CreateOperator(
19
                "Flatten",
20
                ["X"],
21
                ["Y"],
22
                axis=axis)
23

24
            def flatten_ref(X):
25
                shape = X.shape
26
                outer = np.prod(shape[:axis]).astype(int)
27
                inner = np.prod(shape[axis:]).astype(int)
28
                return np.copy(X).reshape(outer, inner),
29

30
            self.assertReferenceChecks(gc, op, [X], flatten_ref)
31

32
        # Check over multiple devices
33
        self.assertDeviceChecks(dc, op, [X], [0])
34

35

36
if __name__ == "__main__":
37
    import unittest
38
    unittest.main()
39

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.