8
from hypothesis import given, settings
9
import hypothesis.strategies as st
10
from multiprocessing import Process
16
import caffe2.python.hypothesis_test_util as hu
# Context manager around a scratch directory: tempfile.mkdtemp() creates it
# and __exit__ deletes it with shutil.rmtree().
# NOTE(review): this excerpt is missing lines of the original class (the
# __enter__ header and whatever it returns are not visible); the lone
# numbers are original-line-number residue, not code.
21
class TemporaryDirectory:
# Presumably inside __enter__ (header not visible): create the directory.
23
self.tmpdir = tempfile.mkdtemp()
# __exit__ runs on both normal and exceptional exit from the `with`; the
# visible body removes the directory tree unconditionally. The `type`
# parameter shadows the builtin but is not used in the visible code.
26
def __exit__(self, type, value, traceback):
27
shutil.rmtree(self.tmpdir)
# Worker run in each child process of TestLazyDynDepAllCompare.test_allcompare.
# Visible steps:
#   - register the file_store_handler_ops library with lazy_dyndep;
#   - create a "store_handler" blob via the FileStoreHandlerCreate operator
#     rooted at `filestore_dir`;
#   - attach a rendezvous config to `model`;
#   - feed `data` as blob "test_data";
#   - run data_parallel_model._RunComparison on it with a CPU device option.
# NOTE(review): the original lines that build `model` and the `rendezvous`
# config are not visible in this excerpt. `process_id` and `num_procs` are
# not referenced by any visible line; presumably they feed that rendezvous
# config (shard id / shard count) — confirm.
30
def allcompare_process(filestore_dir, process_id, data, num_procs):
# caffe2 is imported inside the worker rather than at module top, presumably
# so each spawned process performs its own import — confirm.
31
from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
32
from caffe2.python.model_helper import ModelHelper
33
from caffe2.proto import caffe2_pb2
# The TestLazyDynDepError cases expect load failures of a registered library
# to surface later (from core.RefreshRegisteredOperators or
# workspace.CreateNet), not from RegisterOpsLibrary itself.
34
lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
# Create the file-backed store handler blob "store_handler"; the rendezvous
# config's kv_handler below refers to this same blob name.
36
workspace.RunOperatorOnce(
38
"FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
42
kv_handler="store_handler",
# Attach the rendezvous config (built in lines not visible here) to the model.
50
model._rendezvous = rendezvous
52
workspace.FeedBlob("test_data", data)
# Presumably verifies that "test_data" matches across all participating
# processes (the test is named "allcompare") — confirm in data_parallel_model.
54
data_parallel_model._RunComparison(
55
model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
# Hypothesis-driven test. It builds one random float32 array and runs
# allcompare_process on it in `num_procs` worker processes. The workers
# share a TemporaryDirectory as their file store.
# NOTE(review): several original lines are not visible in this excerpt:
# the @given opening, the dims loop header, the Process(...) call and its
# start(), and the body after pop().
59
class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
# Arguments of the hypothesis @given decorator (its opening line is not
# visible). st.integers bounds are inclusive: d in [1, 5], n in [2, 11],
# num_procs in [1, 8].
61
d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
# deadline=None disables hypothesis' per-example time limit; spawning
# processes can easily exceed the default deadline.
63
@settings(deadline=None)
64
def test_allcompare(self, d, n, num_procs):
# randint's `high` is exclusive, so each dimension size is in [1, n - 1];
# n >= 2 keeps that range non-empty. Presumably this runs once per each of
# the `d` dimensions — the loop header is not visible.
67
dims.append(np.random.randint(1, high=n))
# Uniform samples in [0, 1), cast to float32.
68
test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
70
with TemporaryDirectory() as tempdir:
# One worker per index; idx is passed as allcompare_process's process_id.
72
for idx in range(num_procs):
74
target=allcompare_process,
75
args=(tempdir, idx, test_data, num_procs)
77
processes.append(process)
# NOTE(review): confirm that the (not visible) code after pop() joins each
# worker and checks its exitcode. An exception raised inside a child process
# does not propagate to the parent. Without an exitcode check, a failing
# worker would not fail this test.
80
while len(processes) > 0:
81
process = processes.pop()
# Tests for lazy_dyndep error handling. Each test does the following:
#   - registers the path of a freshly created tempfile.NamedTemporaryFile as
#     an ops library (nothing is written to the file in the visible lines);
#   - installs an error handler via lazy_dyndep.SetErrorHandler;
#   - checks where the resulting failure surfaces.
# NOTE(review): the `def handler(...)` / `def handlernoop(...)` lines are not
# visible in this excerpt; of the handlers, only `raise ValueError("test")`
# lines are shown.
84
class TestLazyDynDepError(unittest.TestCase):
# Expects the handler's ValueError to surface from
# core.RefreshRegisteredOperators().
85
def test_errorhandler(self):
86
from caffe2.python import core, lazy_dyndep
89
with tempfile.NamedTemporaryFile() as f:
90
lazy_dyndep.RegisterOpsLibrary(f.name)
93
raise ValueError("test")
94
lazy_dyndep.SetErrorHandler(handler)
# NOTE(review): assertRaises' `msg` is only the failure message used when no
# ValueError is raised; it does not check the exception's text. Use
# self.assertRaisesRegex(ValueError, "test") if the text matters.
95
with self.assertRaises(ValueError, msg="test"):
96
core.RefreshRegisteredOperators()
# Steps:
#   - a bad library plus a raising handler makes
#     RefreshRegisteredOperators raise ValueError;
#   - the handler is then swapped for handlernoop (presumably non-raising;
#     its body is not visible);
#   - the real file_store_handler_ops library is registered and the
#     operators are refreshed again.
# Presumably this checks that registration still works after an earlier
# load error — confirm.
98
def test_importaftererror(self):
99
from caffe2.python import core, lazy_dyndep
102
with tempfile.NamedTemporaryFile() as f:
103
lazy_dyndep.RegisterOpsLibrary(f.name)
106
raise ValueError("test")
107
lazy_dyndep.SetErrorHandler(handler)
108
with self.assertRaises(ValueError):
109
core.RefreshRegisteredOperators()
113
lazy_dyndep.SetErrorHandler(handlernoop)
114
lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
115
core.RefreshRegisteredOperators()
# Expects the handler's ValueError to surface from workspace.CreateNet("fake").
# That is, the test expects CreateNet to trigger loading of pending
# registered libraries — confirm.
117
def test_workspacecreatenet(self):
118
from caffe2.python import workspace, lazy_dyndep
121
with tempfile.NamedTemporaryFile() as f:
122
lazy_dyndep.RegisterOpsLibrary(f.name)
126
raise ValueError("test")
127
lazy_dyndep.SetErrorHandler(handler)
# NOTE(review): `msg="test"` does not check the exception text; it is only
# the failure message when no ValueError is raised.
128
with self.assertRaises(ValueError, msg="test"):
129
workspace.CreateNet("fake")
# Script entry point. Its body is not visible in this excerpt (presumably it
# runs the unittest runner — confirm).
132
if __name__ == "__main__":