pytorch

Форк
0
79 строк · 2.7 Кб
1

2

3

4

5

6
from caffe2.python import core
7
from hypothesis import given
8
import caffe2.python.hypothesis_test_util as hu
9
import hypothesis.strategies as st
10
import numpy as np
11

12

13
def calculate_ap(predictions, labels):
14
    N, D = predictions.shape
15
    ap = np.zeros(D)
16
    num_range = np.arange((N), dtype=np.float32) + 1
17
    for k in range(D):
18
        scores = predictions[:N, k]
19
        label = labels[:N, k]
20
        sortind = np.argsort(-scores, kind='mergesort')
21
        truth = label[sortind]
22
        precision = np.cumsum(truth) / num_range
23
        ap[k] = precision[truth.astype(bool)].sum() / max(1, truth.sum())
24
    return ap
25

26

27
class TestAPMeterOps(hu.HypothesisTestCase):
28
    @given(predictions=hu.arrays(dims=[10, 3],
29
           elements=hu.floats(allow_nan=False,
30
                              allow_infinity=False,
31
                              min_value=0.1,
32
                              max_value=1)),
33
           labels=hu.arrays(dims=[10, 3],
34
                            dtype=np.int32,
35
                            elements=st.integers(min_value=0,
36
                                                 max_value=1)),
37
           **hu.gcs_cpu_only)
38
    def test_average_precision(self, predictions, labels, gc, dc):
39
        op = core.CreateOperator(
40
            "APMeter",
41
            ["predictions", "labels"],
42
            ["AP"],
43
            buffer_size=10,
44
        )
45

46
        def op_ref(predictions, labels):
47
            ap = calculate_ap(predictions, labels)
48
            return (ap, )
49

50
        self.assertReferenceChecks(
51
            device_option=gc,
52
            op=op,
53
            inputs=[predictions, labels],
54
            reference=op_ref)
55

56
    @given(predictions=hu.arrays(dims=[10, 3],
57
           elements=hu.floats(allow_nan=False,
58
                              allow_infinity=False,
59
                              min_value=0.1,
60
                              max_value=1)),
61
           labels=hu.arrays(dims=[10, 3],
62
                            dtype=np.int32,
63
                            elements=st.integers(min_value=0,
64
                                                 max_value=1)),
65
           **hu.gcs_cpu_only)
66
    def test_average_precision_small_buffer(self, predictions, labels, gc, dc):
67
        op_small_buffer = core.CreateOperator(
68
            "APMeter",
69
            ["predictions", "labels"],
70
            ["AP"],
71
            buffer_size=5,
72
        )
73

74
        def op_ref(predictions, labels):
75
            # We can only hold the last 5 in the buffer
76
            ap = calculate_ap(predictions[5:], labels[5:])
77
            return (ap, )
78

79
        self.assertReferenceChecks(
80
            device_option=gc,
81
            op=op_small_buffer,
82
            inputs=[predictions, labels],
83
            reference=op_ref
84
        )
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.