# pytorch — 94 lines · 3.0 KB
import unittest

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from caffe2.proto import caffe2_pb2
from caffe2.python import core


class TestONNXWhile(serial.SerializedTestCase):
    """Tests for the caffe2 ``ONNXWhile`` control-flow operator."""

    @given(
        condition=st.booleans(),
        max_trip_count=st.integers(0, 100),
        save_scopes=st.booleans(),
        disable_scopes=st.booleans(),
        seed=st.integers(0, 65535),
        **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_onnx_while_fibb(
            self, condition, max_trip_count, save_scopes, disable_scopes, seed, gc, dc):
        """Run a Fibonacci loop through ``ONNXWhile`` and compare it with a
        pure-Python reference.

        The loop body adds the two loop-carried values (``first`` and
        ``second``) into ``third``. It emits ``third`` as a scan output and
        produces a new condition ``third < limit``.

        Args:
            condition: initial loop condition; when False, the reference
                performs no iterations.
            max_trip_count: upper bound on the number of iterations.
            save_scopes: forwarded to the operator's ``save_scopes`` argument;
                forced to False when ``disable_scopes`` is set.
            disable_scopes: forwarded to the operator's ``disable_scopes``
                argument.
            seed: used to seed NumPy's global RNG.
            gc, dc: device options supplied via ``hu.gcs_cpu_only``; only
                ``gc`` is used here.
        """
        np.random.seed(seed)
        if disable_scopes:
            save_scopes = False

        # Threshold for the loop condition. It is shared by the body net's
        # ConstantFill and the Python reference so the two cannot drift apart.
        limit = 100.0

        # Create body net
        body_net = caffe2_pb2.NetDef()
        # Two loop carried dependencies: first and second
        body_net.external_input.extend(['i', 'cond', 'first', 'second'])
        # New condition, updated carried values (second -> first,
        # third -> second), then 'third' once more for the scan output.
        body_net.external_output.extend(['cond_new', 'second', 'third', 'third'])
        add_op = core.CreateOperator(
            'Add',
            ['first', 'second'],
            ['third'],
        )
        # Print op with no outputs; presumably a debug aid for 'third'.
        print3 = core.CreateOperator(
            'Print',
            ['third'],
            [],
        )
        limit_const = core.CreateOperator(
            'ConstantFill',
            [],
            ['limit_const'],
            shape=[1],
            dtype=caffe2_pb2.TensorProto.FLOAT,
            value=limit,
        )
        cond = core.CreateOperator(
            'LT',
            ['third', 'limit_const'],
            ['cond_new'],
        )
        body_net.op.extend([add_op, print3, limit_const, cond])

        while_op = core.CreateOperator(
            'ONNXWhile',
            ['max_trip_count', 'condition', 'first_init', 'second_init'],
            ['first_a', 'second_a', 'third_a'],
            body=body_net,
            has_cond=True,
            has_trip_count=True,
            save_scopes=save_scopes,
            disable_scopes=disable_scopes,
        )

        condition_arr = np.array(condition).astype(bool)
        max_trip_count_arr = np.array(max_trip_count).astype(np.int64)
        first_init = np.array([1]).astype(np.float32)
        second_init = np.array([1]).astype(np.float32)

        def ref(max_trip_count, condition, first_init, second_init):
            """Python model of the loop.

            Returns (first, second, scan of every ``third``).
            """
            # Start from the values actually fed to the operator instead of
            # hard-coding them, so the reference follows any input change.
            first = first_init.item()
            second = second_init.item()
            results = []
            if condition:
                for _ in range(int(max_trip_count)):
                    third = first + second
                    first, second = second, third
                    results.append(third)
                    # Mirror the body's LT(third, limit_const) exactly: the
                    # loop ends once third is no longer strictly below limit.
                    if not third < limit:
                        break
            return (first, second, np.array(results).astype(np.float32))

        self.assertReferenceChecks(
            gc,
            while_op,
            [max_trip_count_arr, condition_arr, first_init, second_init],
            ref,
        )
if __name__ == "__main__":
    # Only run the test suite when this file is executed directly,
    # not when it is imported by an external test runner.
    unittest.main()