pytorch
39 строк · 1.5 Кб
1
2
3
4
5
6from caffe2.python import core, workspace, test_util
7import os
8import shutil
9import tempfile
10import unittest
11
12
class CheckpointTest(test_util.TestCase):
    """A simple test case to make sure that the checkpoint behavior is correct.

    Builds a net containing an Iter op, a ConstantFill op and a Checkpoint op
    configured to save the "iter" and "value" blobs into LevelDB databases
    with ``every=10``, runs the net 100 times, and asserts that checkpoint
    databases named for iterations 10 through 90 exist on disk.
    """

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def testCheckpoint(self):
        temp_root = tempfile.mkdtemp()
        # Always remove the scratch directory, even when an assertion below
        # fails; cleaning up only on the success path leaked temp_root.
        try:
            # printf-style name template; the Checkpoint op receives the
            # unformatted template via ``db`` and fills in the step itself.
            db_pattern = "test_checkpoint_at_%05d"
            net = core.Net("test_checkpoint")
            # Note(jiayq): I am being a bit lazy here and am using the old iter
            # convention that does not have an input. Optionally change it to
            # the new style if needed.
            net.Iter([], "iter")
            net.ConstantFill([], "value", shape=[1, 2, 3])
            net.Checkpoint(["iter", "value"], [],
                           db=os.path.join(temp_root, db_pattern),
                           db_type="leveldb", every=10, absolute_path=True)
            self.assertTrue(workspace.CreateNet(net))
            for _ in range(100):
                self.assertTrue(workspace.RunNet("test_checkpoint"))
            # Expect one checkpoint per multiple of 10 (10, 20, ..., 90).
            for step in range(10, 100, 10):
                # Format only the basename so a '%' in temp_root is harmless.
                checkpoint_path = os.path.join(temp_root, db_pattern % step)
                self.assertTrue(os.path.exists(checkpoint_path),
                                "missing checkpoint: %s" % checkpoint_path)
        finally:
            # Finally, clean up.
            shutil.rmtree(temp_root)
41
42
if __name__ == "__main__":
    # Hand control to unittest's command-line runner, collecting the test
    # cases defined in this module.
    unittest.main(module="__main__")
45