pytorch

Форк
0
/
index_hash_ops_test.py 
69 строк · 2.8 Кб
1

2

3

4

5

6
from caffe2.python import core, workspace
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
import hypothesis.strategies as st
10
from hypothesis import given, settings
11
import numpy as np
12

13

14
class TestIndexHashOps(serial.SerializedTestCase):
15
    @given(
16
        indices=st.sampled_from([
17
            np.int32, np.int64
18
        ]).flatmap(lambda dtype: hu.tensor(min_dim=1, max_dim=1, dtype=dtype)),
19
        seed=st.integers(min_value=0, max_value=10),
20
        modulo=st.integers(min_value=100000, max_value=200000),
21
        **hu.gcs_cpu_only
22
    )
23
    @settings(deadline=10000)
24
    def test_index_hash_ops(self, indices, seed, modulo, gc, dc):
25
        def index_hash(indices):
26
            dtype = np.array(indices).dtype
27
            assert dtype == np.int32 or dtype == np.int64
28
            hashed_indices = []
29
            for index in indices:
30
                hashed = dtype.type(0xDEADBEEF * seed)
31
                indices_bytes = np.array([index], dtype).view(np.int8)
32
                for b in indices_bytes:
33
                    hashed = dtype.type(hashed * 65537 + b)
34
                hashed = (modulo + hashed % modulo) % modulo
35
                hashed_indices.append(hashed)
36
            return [hashed_indices]
37

38
        op = core.CreateOperator("IndexHash",
39
                                 ["indices"], ["hashed_indices"],
40
                                 seed=seed, modulo=modulo)
41

42
        self.assertDeviceChecks(dc, op, [indices], [0])
43
        self.assertReferenceChecks(gc, op, [indices], index_hash)
44

45
        # In-place update
46
        op = core.CreateOperator("IndexHash",
47
                                 ["indices"], ["indices"],
48
                                 seed=seed, modulo=modulo)
49

50
        self.assertDeviceChecks(dc, op, [indices], [0])
51
        self.assertReferenceChecks(gc, op, [indices], index_hash)
52

53
    def test_shape_and_type_inference(self):
54
        with hu.temp_workspace("shape_type_inf_int64"):
55
            net = core.Net('test_net')
56
            net.ConstantFill(
57
                [], "values", shape=[64], dtype=core.DataType.INT64,
58
            )
59
            net.IndexHash(['values'], ['values_output'])
60
            (shapes, types) = workspace.InferShapesAndTypes([net], {})
61

62
            self.assertEqual(shapes["values_output"], [64])
63
            self.assertEqual(types["values_output"], core.DataType.INT64)
64

65
        with hu.temp_workspace("shape_type_inf_int32"):
66
            net = core.Net('test_net')
67
            net.ConstantFill(
68
                [], "values", shape=[2, 32], dtype=core.DataType.INT32,
69
            )
70
            net.IndexHash(['values'], ['values_output'])
71
            (shapes, types) = workspace.InferShapesAndTypes([net], {})
72

73
            self.assertEqual(shapes["values_output"], [2, 32])
74
            self.assertEqual(types["values_output"], core.DataType.INT32)
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.