pytorch

Форк
0
/
onnx_while_test.py 
94 строки · 3.0 Кб
1

2

3

4

5
from caffe2.proto import caffe2_pb2
6
from caffe2.python import core
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
from hypothesis import given, settings
10
import hypothesis.strategies as st
11
import numpy as np
12
import unittest
13

14

15
class TestONNXWhile(serial.SerializedTestCase):
16
    @given(
17
        condition=st.booleans(),
18
        max_trip_count=st.integers(0, 100),
19
        save_scopes=st.booleans(),
20
        disable_scopes=st.booleans(),
21
        seed=st.integers(0, 65535),
22
        **hu.gcs_cpu_only)
23
    @settings(deadline=10000)
24
    def test_onnx_while_fibb(
25
            self, condition, max_trip_count, save_scopes, disable_scopes, seed, gc, dc):
26
        np.random.seed(seed)
27
        if disable_scopes:
28
            save_scopes = False
29

30
        # Create body net
31
        body_net = caffe2_pb2.NetDef()
32
        # Two loop carried dependencies: first and second
33
        body_net.external_input.extend(['i', 'cond', 'first', 'second'])
34
        body_net.external_output.extend(['cond_new', 'second', 'third', 'third'])
35
        add_op = core.CreateOperator(
36
            'Add',
37
            ['first', 'second'],
38
            ['third'],
39
        )
40
        print3 = core.CreateOperator(
41
            'Print',
42
            ['third'],
43
            [],
44
        )
45
        limit_const = core.CreateOperator(
46
            'ConstantFill',
47
            [],
48
            ['limit_const'],
49
            shape=[1],
50
            dtype=caffe2_pb2.TensorProto.FLOAT,
51
            value=100.0,
52
        )
53
        cond = core.CreateOperator(
54
            'LT',
55
            ['third', 'limit_const'],
56
            ['cond_new'],
57
        )
58
        body_net.op.extend([add_op, print3, limit_const, cond])
59

60
        while_op = core.CreateOperator(
61
            'ONNXWhile',
62
            ['max_trip_count', 'condition', 'first_init', 'second_init'],
63
            ['first_a', 'second_a', 'third_a'],
64
            body=body_net,
65
            has_cond=True,
66
            has_trip_count=True,
67
            save_scopes=save_scopes,
68
            disable_scopes=disable_scopes,
69
        )
70

71
        condition_arr = np.array(condition).astype(bool)
72
        max_trip_count_arr = np.array(max_trip_count).astype(np.int64)
73
        first_init = np.array([1]).astype(np.float32)
74
        second_init = np.array([1]).astype(np.float32)
75

76
        def ref(max_trip_count, condition, first_init, second_init):
77
            first = 1
78
            second = 1
79
            results = []
80
            if condition:
81
                for _ in range(max_trip_count):
82
                    third = first + second
83
                    first = second
84
                    second = third
85
                    results.append(third)
86
                    if third > 100:
87
                        break
88
            return (first, second, np.array(results).astype(np.float32))
89

90
        self.assertReferenceChecks(
91
            gc,
92
            while_op,
93
            [max_trip_count_arr, condition_arr, first_init, second_init],
94
            ref,
95
        )
96

97
if __name__ == "__main__":
98
    unittest.main()
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.