pytorch

Форк
0
/
lazy_dyndep_test.py 
133 строки · 3.8 Кб
1
#!/usr/bin/env python3
2

3

4

5

6

7

8
from hypothesis import given, settings
9
import hypothesis.strategies as st
10
from multiprocessing import Process
11

12
import numpy as np
13
import tempfile
14
import shutil
15

16
import caffe2.python.hypothesis_test_util as hu
17
import unittest
18

19
op_engine = 'GLOO'
20

21
class TemporaryDirectory:
22
    def __enter__(self):
23
        self.tmpdir = tempfile.mkdtemp()
24
        return self.tmpdir
25

26
    def __exit__(self, type, value, traceback):
27
        shutil.rmtree(self.tmpdir)
28

29

30
def allcompare_process(filestore_dir, process_id, data, num_procs):
31
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
32
    from caffe2.python.model_helper import ModelHelper
33
    from caffe2.proto import caffe2_pb2
34
    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
35

36
    workspace.RunOperatorOnce(
37
        core.CreateOperator(
38
            "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
39
        )
40
    )
41
    rendezvous = dict(
42
        kv_handler="store_handler",
43
        shard_id=process_id,
44
        num_shards=num_procs,
45
        engine=op_engine,
46
        exit_nets=None
47
    )
48

49
    model = ModelHelper()
50
    model._rendezvous = rendezvous
51

52
    workspace.FeedBlob("test_data", data)
53

54
    data_parallel_model._RunComparison(
55
        model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
56
    )
57

58

59
class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
60
    @given(
61
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
62
    )
63
    @settings(deadline=None)
64
    def test_allcompare(self, d, n, num_procs):
65
        dims = []
66
        for _ in range(d):
67
            dims.append(np.random.randint(1, high=n))
68
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
69

70
        with TemporaryDirectory() as tempdir:
71
            processes = []
72
            for idx in range(num_procs):
73
                process = Process(
74
                    target=allcompare_process,
75
                    args=(tempdir, idx, test_data, num_procs)
76
                )
77
                processes.append(process)
78
                process.start()
79

80
            while len(processes) > 0:
81
                process = processes.pop()
82
                process.join()
83

84
class TestLazyDynDepError(unittest.TestCase):
85
    def test_errorhandler(self):
86
        from caffe2.python import core, lazy_dyndep
87
        import tempfile
88

89
        with tempfile.NamedTemporaryFile() as f:
90
            lazy_dyndep.RegisterOpsLibrary(f.name)
91

92
            def handler(e):
93
                raise ValueError("test")
94
            lazy_dyndep.SetErrorHandler(handler)
95
            with self.assertRaises(ValueError, msg="test"):
96
                core.RefreshRegisteredOperators()
97

98
    def test_importaftererror(self):
99
        from caffe2.python import core, lazy_dyndep
100
        import tempfile
101

102
        with tempfile.NamedTemporaryFile() as f:
103
            lazy_dyndep.RegisterOpsLibrary(f.name)
104

105
            def handler(e):
106
                raise ValueError("test")
107
            lazy_dyndep.SetErrorHandler(handler)
108
            with self.assertRaises(ValueError):
109
                core.RefreshRegisteredOperators()
110

111
            def handlernoop(e):
112
                raise
113
            lazy_dyndep.SetErrorHandler(handlernoop)
114
            lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
115
            core.RefreshRegisteredOperators()
116

117
    def test_workspacecreatenet(self):
118
        from caffe2.python import workspace, lazy_dyndep
119
        import tempfile
120

121
        with tempfile.NamedTemporaryFile() as f:
122
            lazy_dyndep.RegisterOpsLibrary(f.name)
123
            called = False
124

125
            def handler(e):
126
                raise ValueError("test")
127
            lazy_dyndep.SetErrorHandler(handler)
128
            with self.assertRaises(ValueError, msg="test"):
129
                workspace.CreateNet("fake")
130

131

132
if __name__ == "__main__":
133
    unittest.main()
134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.