# pytorch
# 42 lines · 1.6 KB
from hypothesis import given
import numpy as np

from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
13
class TestCastOp(hu.HypothesisTestCase):
    """Hypothesis-driven tests for the caffe2 ``Cast`` operator."""

    def _check_int_to_float(self, data, gc, dc):
        """Run device and gradient checks for an INT32 -> FLOAT ``Cast``.

        Shared by the int-to-float tests so the operator definition and the
        checks performed live in one place.

        Args:
            data: int32 numpy array fed as the ``data`` input blob.
            gc: device option passed to ``assertGradientChecks``.
            dc: device options passed to ``assertDeviceChecks``.
        """
        # Named enum values instead of the bare 1 / 2, matching the
        # core.DataType.STRING usage in test_cast_int_to_string.
        op = core.CreateOperator(
            'Cast', 'data', 'data_cast',
            to=core.DataType.FLOAT, from_type=core.DataType.INT32)
        self.assertDeviceChecks(dc, op, [data], [0])
        # The checked gradient here is actually 0.
        self.assertGradientChecks(gc, op, [data], 0, [0])

    @given(**hu.gcs)
    def test_cast_int_float(self, gc, dc):
        """Cast a 5x5 int32 tensor to float."""
        # NOTE(review): np.random.rand draws from [0, 1), so astype(np.int32)
        # truncates every element to 0 -- this input is always all zeros.
        # Confirm how the gradient check treats integer inputs before
        # widening the value range.
        data = np.random.rand(5, 5).astype(np.int32)
        self._check_int_to_float(data, gc, dc)

    @given(**hu.gcs)
    def test_cast_int_float_empty(self, gc, dc):
        """Cast an empty (zero-element) int32 tensor to float."""
        data = np.random.rand(0).astype(np.int32)
        self._check_int_to_float(data, gc, dc)

    @given(data=hu.tensor(dtype=np.int32), **hu.gcs_cpu_only)
    def test_cast_int_to_string(self, data, gc, dc):
        """Cast an int32 tensor to STRING and compare against numpy."""
        op = core.CreateOperator(
            'Cast', 'data', 'data_cast', to=core.DataType.STRING)

        def ref(data):
            # Expected values: numpy's element-wise conversion to str.
            ret = data.astype(dtype=str)
            # The string blob will be fetched as an object array, so feed the
            # expected values into a scratch workspace and re-fetch them to
            # get the same representation as the operator's output.
            with hu.temp_workspace('tmp_ref_int_to_string'):
                workspace.FeedBlob('tmp_blob', ret)
                fetched_ret = workspace.FetchBlob('tmp_blob')
            return (fetched_ret,)

        self.assertReferenceChecks(gc, op, inputs=[data], reference=ref)