pytorch

Форк
0
/
quantile_test.py 
84 строки · 3.2 Кб
1

2

3
import unittest
4

5
import caffe2.python.hypothesis_test_util as hu
6
import numpy as np
7
from caffe2.python import core, workspace
8

9

10
class TestQuantile(hu.HypothesisTestCase):
11
    def _test_quantile(self, inputs, quantile, abs, tol):
12
        net = core.Net("test_net")
13
        net.Proto().type = "dag"
14
        input_tensors = []
15
        for i, input in enumerate(inputs):
16
            workspace.FeedBlob("t_{}".format(i), input)
17
            input_tensors.append("t_{}".format(i))
18
        net.Quantile(
19
            input_tensors, ["quantile_value"], quantile=quantile, abs=abs, tol=tol
20
        )
21
        workspace.RunNetOnce(net)
22
        quantile_value_blob = workspace.FetchBlob("quantile_value")
23
        assert np.size(quantile_value_blob) == 1
24
        quantile_value = quantile_value_blob[0]
25

26
        input_cat = np.concatenate([input.flatten() for input in inputs])
27
        input_cat = np.abs(input_cat) if abs == 1 else input_cat
28
        target_cnt = np.ceil(np.size(input_cat) * quantile)
29
        actual_cnt = np.sum(input_cat <= quantile_value)
30
        # prune with return value will remove no less than
31
        # "quantile" portion of  elements
32
        assert actual_cnt >= target_cnt
33
        # Expect that (hi-lo) < tol * (|lo| + |hi|)
34
        # if tol < 1.0 -> hi * lo > 0, then we are expecting
35
        # 1. if hi >0,
36
        #           |hi|-|lo| < tol * (|lo| + |hi|)
37
        #          hi - lo  < (2 tol) /(1 + tol)  |hi| < 2 tol |hi|
38
        # 2. if hi < 0,
39
        #           |lo|- |hi| < tol * (|lo| + |hi|)
40
        #          hi - lo  < (2 tol) /(1 - tol)  |hi| < 2.5 tol |hi| if tol < 0.2
41
        quantile_value_lo = quantile_value - 2.5 * tol * np.abs(quantile_value)
42
        lo_cnt = np.sum(input_cat <= quantile_value_lo)
43
        # prune with a slightly smaller value will remove
44
        # less than "quantile" portion of elements
45
        assert lo_cnt <= target_cnt
46

47
    def test_quantile_1(self):
48
        inputs = []
49
        num_tensors = 5
50
        for i in range(num_tensors):
51
            dim = np.random.randint(5, 100)
52
            inputs.append(np.random.rand(dim))
53
        self._test_quantile(inputs=inputs, quantile=0.2, abs=1, tol=1e-4)
54

55
    def test_quantile_2(self):
56
        inputs = []
57
        num_tensors = 5
58
        for i in range(num_tensors):
59
            dim = np.random.randint(5, 100)
60
            inputs.append(np.random.rand(dim))
61
        self._test_quantile(inputs=inputs, quantile=1e-6, abs=0, tol=1e-3)
62

63
    def test_quantile_3(self):
64
        inputs = []
65
        num_tensors = 5
66
        for i in range(num_tensors):
67
            dim1 = np.random.randint(5, 100)
68
            dim2 = np.random.randint(5, 100)
69
            inputs.append(np.random.rand(dim1, dim2))
70
        self._test_quantile(inputs=inputs, quantile=1 - 1e-6, abs=1, tol=1e-5)
71

72
    def test_quantile_4(self):
73
        inputs = []
74
        num_tensors = 5
75
        for i in range(num_tensors):
76
            dim1 = np.random.randint(5, 100)
77
            dim2 = np.random.randint(5, 100)
78
            inputs.append(np.random.rand(dim1, dim2))
79
            inputs.append(np.random.rand(dim1))
80
        self._test_quantile(inputs=inputs, quantile=0.168, abs=1, tol=1e-4)
81

82

83
if __name__ == "__main__":
84
    global_options = ["caffe2"]
85
    core.GlobalInit(global_options)
86
    unittest.main()
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.