pytorch

Форк
0
/
allcompare_test.py 
87 строк · 2.2 Кб
1
#!/usr/bin/env python3
2

3

4

5

6

7

8
from hypothesis import given, settings
9
import hypothesis.strategies as st
10
from multiprocessing import Process
11

12
import numpy as np
13
import tempfile
14
import shutil
15

16
import caffe2.python.hypothesis_test_util as hu
17

18
op_engine = 'GLOO'
19

20

21
class TemporaryDirectory:
22
    def __enter__(self):
23
        self.tmpdir = tempfile.mkdtemp()
24
        return self.tmpdir
25

26
    def __exit__(self, type, value, traceback):
27
        shutil.rmtree(self.tmpdir)
28

29

30
def allcompare_process(filestore_dir, process_id, data, num_procs):
31
    from caffe2.python import core, data_parallel_model, workspace, dyndep
32
    from caffe2.python.model_helper import ModelHelper
33
    from caffe2.proto import caffe2_pb2
34
    dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
35

36
    workspace.RunOperatorOnce(
37
        core.CreateOperator(
38
            "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
39
        )
40
    )
41
    rendezvous = dict(
42
        kv_handler="store_handler",
43
        shard_id=process_id,
44
        num_shards=num_procs,
45
        engine=op_engine,
46
        exit_nets=None
47
    )
48

49
    model = ModelHelper()
50
    model._rendezvous = rendezvous
51

52
    workspace.FeedBlob("test_data", data)
53

54
    data_parallel_model._RunComparison(
55
        model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
56
    )
57

58

59
class TestAllCompare(hu.HypothesisTestCase):
60
    @given(
61
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
62
    )
63
    @settings(deadline=10000)
64
    def test_allcompare(self, d, n, num_procs):
65
        dims = []
66
        for _ in range(d):
67
            dims.append(np.random.randint(1, high=n))
68
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
69

70
        with TemporaryDirectory() as tempdir:
71
            processes = []
72
            for idx in range(num_procs):
73
                process = Process(
74
                    target=allcompare_process,
75
                    args=(tempdir, idx, test_data, num_procs)
76
                )
77
                processes.append(process)
78
                process.start()
79

80
            while len(processes) > 0:
81
                process = processes.pop()
82
                process.join()
83

84

85
if __name__ == "__main__":
86
    import unittest
87
    unittest.main()
88

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.