pytorch
46 строк · 1.7 Кб
1
2
3
4
5
6from caffe2.python import core, workspace
7from caffe2.python.test_util import TestCase
8import numpy as np
9
10
class TestCounterOps(TestCase):
    """Tests for the StatRegistryCreate / Update / Export operators."""

    def test_stats_ops(self):
        """Feed two counters into the registry targeted when no registry blob
        is given, copy its export into a newly created registry, and check
        what that new registry exports.
        """

        def run(op_type, inputs, outputs):
            # Build one operator and execute it in the current workspace.
            workspace.RunOperatorOnce(
                core.CreateOperator(op_type, inputs, outputs))

        # The global StatRegistry isn't reset when the workspace is reset,
        # so stats from earlier tests may already be present; record how
        # many keys exist before this test adds its own.
        run('StatRegistryExport', [], ['prev_k', 'prev_v', 'prev_ts'])
        baseline = len(workspace.FetchBlob('prev_k'))

        # Namespace the keys by module/class/test so they cannot collide
        # with stats registered elsewhere.
        prefix = '/'.join([__name__, 'TestCounterOps', 'test_stats_ops'])
        our_keys = [
            '{}/{}'.format(prefix, suffix).encode('ascii')
            for suffix in ('key1', 'key2')
        ]
        workspace.FeedBlob('k', np.array(our_keys, dtype=str))
        workspace.FeedBlob('v', np.array([34, 45], dtype=np.int64))

        # No registry blob is passed here, so these updates go to the same
        # registry that the input-less exports read from. The same
        # key/value pairs are applied twice.
        run('StatRegistryUpdate', ['k', 'v'], [])
        run('StatRegistryUpdate', ['k', 'v'], [])
        run('StatRegistryExport', [], ['k2', 'v2', 't2'])

        # Copy everything just exported into a freshly created registry,
        # then export from that registry alone.
        run('StatRegistryCreate', [], ['reg'])
        run('StatRegistryUpdate', ['k2', 'v2', 'reg'], [])
        run('StatRegistryExport', ['reg'], ['k3', 'v3', 't3'])

        exported_keys, exported_values, exported_ts = (
            workspace.FetchBlob(blob) for blob in ('k3', 'v3', 't3'))

        # Exactly our two keys were added on top of the pre-existing ones,
        # and the three export outputs are parallel arrays.
        self.assertEqual(len(exported_keys) - baseline, 2)
        self.assertEqual(len(exported_values), len(exported_keys))
        self.assertEqual(len(exported_ts), len(exported_keys))
        for key in our_keys:
            self.assertIn(key, exported_keys)
52