pytorch
69 строк · 2.8 Кб
1
2
3
4
5
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from caffe2.python import core, workspace
13
class TestIndexHashOps(serial.SerializedTestCase):
    """Tests for the ``IndexHash`` operator."""

    @given(
        indices=st.sampled_from([
            np.int32, np.int64
        ]).flatmap(lambda dtype: hu.tensor(min_dim=1, max_dim=1, dtype=dtype)),
        seed=st.integers(min_value=0, max_value=10),
        modulo=st.integers(min_value=100000, max_value=200000),
        **hu.gcs_cpu_only
    )
    @settings(deadline=10000)
    def test_index_hash_ops(self, indices, seed, modulo, gc, dc):
        """Check IndexHash against a pure-Python reference, out-of-place and in-place."""

        def index_hash(indices):
            """Reference implementation: map each index to a bucket in ``[0, modulo)``.

            All arithmetic is done on Python ints and explicitly wrapped to the
            bit width of the index dtype, emulating fixed-width two's-complement
            overflow. This avoids relying on NumPy scalar overflow, which warns
            and, for out-of-range scalar construction such as
            ``np.int32(0xDEADBEEF)``, raises ``OverflowError`` in NumPy >= 2.

            Returns a one-element list (one entry per operator output) holding
            the hashed values as scalars of the input dtype.
            """
            dtype = np.array(indices).dtype
            assert dtype in (np.int32, np.int64)
            bits = np.iinfo(dtype).bits
            mask = (1 << bits) - 1
            sign_bit = 1 << (bits - 1)

            def wrap(value):
                # Truncate to `bits` bits and reinterpret as a signed integer.
                value &= mask
                return value - (1 << bits) if value & sign_bit else value

            hashed_indices = []
            for index in indices:
                hashed = wrap(0xDEADBEEF * seed)
                # Mix in each raw byte of the index (memory order, signed int8).
                for b in np.array([index], dtype).view(np.int8):
                    hashed = wrap(hashed * 65537 + int(b))
                # Python's % is already non-negative for a positive modulus.
                hashed_indices.append(dtype.type(hashed % modulo))
            return [hashed_indices]

        # The second output name aliases the input, exercising the in-place update.
        for output_name in ("hashed_indices", "indices"):
            op = core.CreateOperator("IndexHash",
                                     ["indices"], [output_name],
                                     seed=seed, modulo=modulo)

            self.assertDeviceChecks(dc, op, [indices], [0])
            self.assertReferenceChecks(gc, op, [indices], index_hash)

    def test_shape_and_type_inference(self):
        """Shape/type inference for IndexHash keeps the input's shape and type."""
        cases = (
            ("shape_type_inf_int64", [64], core.DataType.INT64),
            ("shape_type_inf_int32", [2, 32], core.DataType.INT32),
        )
        for workspace_name, shape, data_type in cases:
            with self.subTest(workspace=workspace_name), \
                    hu.temp_workspace(workspace_name):
                net = core.Net('test_net')
                net.ConstantFill(
                    [], "values", shape=shape, dtype=data_type,
                )
                net.IndexHash(['values'], ['values_output'])
                (shapes, types) = workspace.InferShapesAndTypes([net], {})

                self.assertEqual(shapes["values_output"], shape)
                self.assertEqual(types["values_output"], data_type)