pytorch
79 строк · 3.3 Кб
1
2
3
4
5
import os
import tempfile

from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
10
class TestCounterOps(TestCase):
    """Tests for the caffe2 counter operators.

    Covers CreateCounter, CountDown, CountUp, RetrieveCount and ResetCounter,
    the interaction of their boolean outputs with ConstantFill/And, and
    Save/Load round-tripping of a counter blob.
    """

    @staticmethod
    def _run_op(op_type, inputs, outputs, **kwargs):
        """Build one caffe2 operator and run it once; return the run status."""
        return workspace.RunOperatorOnce(
            core.CreateOperator(op_type, inputs, outputs, **kwargs))

    def test_counter_ops(self):
        """Check counter semantics, boolean interop and serialization."""
        # unittest assertions (rather than bare ``assert``) stay active under
        # ``python -O`` and produce informative failure messages.
        self._run_op('CreateCounter', [], ['c'], init_count=1)

        # CountDown's output is falsy when the count was positive before
        # decrementing (1 -> 0) and truthy when it was not (0 -> -1).
        self._run_op('CountDown', ['c'], ['t1'])  # 1 -> 0
        self.assertFalse(workspace.FetchBlob('t1'))

        self._run_op('CountDown', ['c'], ['t2'])  # 0 -> -1
        self.assertTrue(workspace.FetchBlob('t2'))

        # CountUp outputs the count as it was before incrementing.
        self._run_op('CountUp', ['c'], ['t21'])  # -1 -> 0
        self.assertEqual(workspace.FetchBlob('t21'), -1)
        self._run_op('RetrieveCount', ['c'], ['t22'])
        self.assertEqual(workspace.FetchBlob('t22'), 0)

        self._run_op('ResetCounter', ['c'], [], init_count=1)  # -> 1
        self._run_op('CountDown', ['c'], ['t3'])  # 1 -> 0
        self.assertFalse(workspace.FetchBlob('t3'))

        # When given an output, ResetCounter reports the count it replaced.
        self._run_op('ResetCounter', ['c'], ['t31'], init_count=5)  # 0 -> 5
        self.assertEqual(workspace.FetchBlob('t31'), 0)
        self._run_op('ResetCounter', ['c'], ['t32'])  # 5 -> 0
        self.assertEqual(workspace.FetchBlob('t32'), 5)

        # CountDown outputs compare equal to ordinary BOOL scalars...
        self._run_op('ConstantFill', [], ['t4'], value=False, shape=[],
                     dtype=core.DataType.BOOL)
        self.assertEqual(workspace.FetchBlob('t4'), workspace.FetchBlob('t1'))

        self._run_op('ConstantFill', [], ['t5'], value=True, shape=[],
                     dtype=core.DataType.BOOL)
        self.assertEqual(workspace.FetchBlob('t5'), workspace.FetchBlob('t2'))

        # ...and can be fed to boolean operators.
        self.assertTrue(self._run_op('And', ['t1', 't2'], ['t6']))
        self.assertFalse(workspace.FetchBlob('t6'))  # False && True

        self.assertTrue(self._run_op('And', ['t2', 't5'], ['t7']))
        self.assertTrue(workspace.FetchBlob('t7'))  # True && True

        # Loading a saved counter restores the saved value, discarding any
        # counting done after the Save.
        self._run_op('CreateCounter', [], ['serialized_c'], init_count=22)
        # Use a path inside a temporary directory instead of an open
        # NamedTemporaryFile: reopening an open NamedTemporaryFile by name
        # (which the Save/Load ops do) is not possible on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            db_path = os.path.join(tmp_dir, 'serialized_c.minidb')
            self._run_op('Save', ['serialized_c'], [], absolute_path=1,
                         db_type='minidb', db=db_path)
            for _ in range(10):
                self._run_op('CountDown', ['serialized_c'], ['t8'])
            self._run_op('RetrieveCount', ['serialized_c'], ['t8'])
            self.assertEqual(workspace.FetchBlob('t8'), 12)
            self._run_op('Load', [], ['serialized_c'], absolute_path=1,
                         db_type='minidb', db=db_path)
            self._run_op('RetrieveCount', ['serialized_c'], ['t8'])
            self.assertEqual(workspace.FetchBlob('t8'), 22)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    from unittest import main as run_tests
    run_tests()