pytorch
79 строк · 2.7 Кб
1
2
3
4
5
6from caffe2.python import core
7from hypothesis import given
8import caffe2.python.hypothesis_test_util as hu
9import hypothesis.strategies as st
10import numpy as np
11
12
def calculate_ap(predictions, labels):
    """Compute the per-column average precision (AP) with NumPy.

    Each column of ``predictions`` is treated as an independent ranking:
    rows are ordered by descending score using a stable mergesort (so tied
    scores keep their original row order), and the precision at every rank
    that holds a positive label is averaged over the number of positives.

    Args:
        predictions: 2-D array of scores, shape ``(N, D)``.
        labels: 2-D array with at least ``N`` rows and ``D`` columns; only
            the first ``N`` rows are used. Entries are presumably 0/1 --
            nonzero entries select which ranks are averaged, while their raw
            sum is the divisor.

    Returns:
        1-D float array of length ``D``. A column with no positive labels
        gets an AP of 0.
    """
    num_rows, num_cols = predictions.shape
    # 1-based rank of every position in a sorted column.
    ranks = 1 + np.arange(num_rows, dtype=np.float32)
    result = np.zeros(num_cols)
    for col in range(num_cols):
        order = np.argsort(-predictions[:, col], kind='mergesort')
        hits = labels[:num_rows, col][order]
        precision_at_rank = np.cumsum(hits) / ranks
        num_positives = hits.sum()
        hit_mask = hits.astype(bool)
        # max(1, ...) avoids a division by zero when a column has no positives.
        result[col] = precision_at_rank[hit_mask].sum() / max(1, num_positives)
    return result
25
26
class TestAPMeterOps(hu.HypothesisTestCase):
    """Check the APMeter operator against the NumPy reference ``calculate_ap``."""

    # Shared input strategies: a (10, 3) batch of finite scores in [0.1, 1]
    # and a matching (10, 3) batch of 0/1 int32 labels. Defined once so both
    # tests always draw from identical input distributions.
    _predictions_strategy = hu.arrays(
        dims=[10, 3],
        elements=hu.floats(allow_nan=False,
                           allow_infinity=False,
                           min_value=0.1,
                           max_value=1))
    _labels_strategy = hu.arrays(
        dims=[10, 3],
        dtype=np.int32,
        elements=st.integers(min_value=0, max_value=1))

    @given(predictions=_predictions_strategy,
           labels=_labels_strategy,
           **hu.gcs_cpu_only)
    def test_average_precision(self, predictions, labels, gc, dc):
        """With ``buffer_size`` equal to the batch size (10), the reference
        AP is computed over every row of the batch."""
        op = core.CreateOperator(
            "APMeter",
            ["predictions", "labels"],
            ["AP"],
            buffer_size=10,
        )

        def op_ref(predictions, labels):
            ap = calculate_ap(predictions, labels)
            return (ap, )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[predictions, labels],
            reference=op_ref)

    @given(predictions=_predictions_strategy,
           labels=_labels_strategy,
           **hu.gcs_cpu_only)
    def test_average_precision_small_buffer(self, predictions, labels, gc, dc):
        """With a buffer smaller than the batch, only the most recent rows
        contribute to the AP."""
        buffer_size = 5
        op_small_buffer = core.CreateOperator(
            "APMeter",
            ["predictions", "labels"],
            ["AP"],
            buffer_size=buffer_size,
        )

        def op_ref(predictions, labels):
            # The op can only hold the last `buffer_size` rows in its buffer,
            # so the reference must only see that tail of the batch. Slicing
            # from the end keeps this correct if the batch size changes.
            ap = calculate_ap(predictions[-buffer_size:],
                              labels[-buffer_size:])
            return (ap, )

        self.assertReferenceChecks(
            device_option=gc,
            op=op_small_buffer,
            inputs=[predictions, labels],
            reference=op_ref,
        )
85