pytorch

Форк
0
/
batch_bucketize_op_test.py 
86 строк · 3.6 Кб
1

2

3

4

5

6
import numpy as np
7

8
from caffe2.python import core
9
import caffe2.python.hypothesis_test_util as hu
10
import caffe2.python.serialized_test.serialized_test_util as serial
11
from hypothesis import given
12
import hypothesis.strategies as st
13

14

15
class TestBatchBucketize(serial.SerializedTestCase):
16
    @serial.given(**hu.gcs_cpu_only)
17
    def test_batch_bucketize_example(self, gc, dc):
18
        op = core.CreateOperator('BatchBucketize',
19
                                 ["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"],
20
                                 ["O"])
21
        float_feature = np.array([[1.42, 2.07, 3.19, 0.55, 4.32],
22
                                  [4.57, 2.30, 0.84, 4.48, 3.09],
23
                                  [0.89, 0.26, 2.41, 0.47, 1.05],
24
                                  [0.03, 2.97, 2.43, 4.36, 3.11],
25
                                  [2.74, 5.77, 0.90, 2.63, 0.38]], dtype=np.float32)
26
        indices = np.array([0, 1, 4], dtype=np.int32)
27
        lengths = np.array([2, 3, 1], dtype=np.int32)
28
        boundaries = np.array([0.5, 1.0, 1.5, 2.5, 3.5, 2.5], dtype=np.float32)
29

30
        def ref(float_feature, indices, boundaries, lengths):
31
            output = np.array([[2, 1, 1],
32
                               [2, 1, 1],
33
                               [1, 0, 0],
34
                               [0, 2, 1],
35
                               [2, 3, 0]], dtype=np.int32)
36
            return (output,)
37

38
        self.assertReferenceChecks(gc, op,
39
                                   [float_feature, indices, boundaries, lengths],
40
                                   ref)
41

42
    @given(
43
        x=hu.tensor(
44
            min_dim=2, max_dim=2, dtype=np.float32,
45
            elements=hu.floats(min_value=0, max_value=5),
46
            min_value=5),
47
        seed=st.integers(min_value=2, max_value=1000),
48
        **hu.gcs_cpu_only)
49
    def test_batch_bucketize(self, x, seed, gc, dc):
50
        op = core.CreateOperator('BatchBucketize',
51
                                 ["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"],
52
                                 ['O'])
53
        np.random.seed(seed)
54
        d = x.shape[1]
55
        lens = np.random.randint(low=1, high=3, size=d - 3)
56
        indices = np.random.choice(range(d), d - 3, replace=False)
57
        indices.sort()
58
        boundaries = []
59
        for i in range(d - 3):
60
            # add [0, 0] as duplicated boundary for duplicated bucketization
61
            if lens[i] > 2:
62
                cur_boundary = np.append(
63
                    np.random.randn(lens[i] - 2) * 5, [0, 0])
64
            else:
65
                cur_boundary = np.random.randn(lens[i]) * 5
66
            cur_boundary.sort()
67
            boundaries += cur_boundary.tolist()
68

69
        lens = np.array(lens, dtype=np.int32)
70
        boundaries = np.array(boundaries, dtype=np.float32)
71
        indices = np.array(indices, dtype=np.int32)
72

73
        def ref(x, indices, boundaries, lens):
74
            output_dim = indices.shape[0]
75
            ret = np.zeros((x.shape[0], output_dim)).astype(np.int32)
76
            boundary_offset = 0
77
            for i, l in enumerate(indices):
78
                temp_bound = boundaries[boundary_offset : lens[i] + boundary_offset]
79
                for j in range(x.shape[0]):
80
                    for k, bound_val in enumerate(temp_bound):
81
                        if k == len(temp_bound) - 1 and x[j, l] > bound_val:
82
                            ret[j, i] = k + 1
83
                        elif x[j, l] > bound_val:
84
                            continue
85
                        else:
86
                            ret[j, i] = k
87
                            break
88
                boundary_offset += lens[i]
89
            return (ret,)
90

91
        self.assertReferenceChecks(gc, op, [x, indices, boundaries, lens], ref)
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.