pytorch
31 строка · 945.0 Байт
1#!/usr/bin/env python3
2
3import caffe2.python.hypothesis_test_util as hu4import hypothesis.strategies as st5import numpy as np6from caffe2.python import core7from hypothesis import given8
9
class TestAsyncNetBarrierOp(hu.HypothesisTestCase):
    """Hypothesis-driven checks for the ``AsyncNetBarrier`` operator."""

    @given(
        n=st.integers(1, 5),
        shape=st.lists(st.integers(0, 5), min_size=1, max_size=3),
        **hu.gcs
    )
    def test_async_net_barrier_op(self, n, shape, dc, gc):
        """Run AsyncNetBarrier over ``n`` random float32 blobs of ``shape``.

        The reference below asserts it was handed exactly ``n`` inputs and
        echoes them back unchanged; ``assertReferenceChecks`` presumably
        compares the operator's outputs against that reference.
        ``dc`` is accepted from ``hu.gcs`` but not used by this test.
        """
        # Draw the n arrays one after another (values in [0, 100)), then
        # narrow each to float32. A zero in ``shape`` yields an empty array.
        inputs = []
        for _ in range(n):
            sample = np.random.random(shape) * 100
            inputs.append(sample.astype(np.float32))

        blob_names = list(map("x_{}".format, range(n)))

        # The same blob names are passed as inputs and outputs, so the
        # operator writes its results back into the input blobs.
        op = core.CreateOperator(
            "AsyncNetBarrier",
            blob_names,
            blob_names,
            device_option=gc,
        )

        def identity_reference(*tensors):
            # Check arity, then return the inputs untouched.
            self.assertEqual(len(tensors), n)
            return tensors

        self.assertReferenceChecks(gc, op, inputs, identity_reference)