9
from caffe2.python import cnn, workspace, core
11
from caffe2.python.predictor_constants import predictor_constants as pc
12
import caffe2.python.predictor.predictor_exporter as pe
13
import caffe2.python.predictor.predictor_py_utils as pred_utils
14
from caffe2.proto import caffe2_pb2, metanet_pb2
17
# NOTE(review): this excerpt is a lossy extraction of the test file: the bare
# integers on alternating lines appear to be original line numbers,
# indentation is gone, and some original lines are absent (e.g. the quote
# lines around the docstring text). It will not parse as-is; restore the
# real source from version control before making code changes.
#
# Tests for building and editing MetaNetDef protobuf messages via
# metanet_pb2 and the pred_utils blob helpers.
class MetaNetDefTest(unittest.TestCase):
18
# Smoke test: constructing a NetsMap entry from a key and an empty NetDef
# must not raise (no assertion is made on the result).
def test_minimal(self):
20
Tests that a NetsMap message can be created with a NetDef message
23
metanet_pb2.NetsMap(key="test_key", value=caffe2_pb2.NetDef())
25
# Smoke test: appending a keyed NetDef entry to MetaNetDef.nets must not
# raise (no assertion is made on the resulting message).
def test_adding_net(self):
27
Tests that NetDefs can be added to MetaNetDefs
29
meta_net_def = metanet_pb2.MetaNetDef()
30
net_def = caffe2_pb2.NetDef()
31
meta_net_def.nets.add(key="test_key", value=net_def)
33
# Exercises the pred_utils blob-list helpers on a fresh MetaNetDef. The
# test expects that:
#   - GetBlobs returns exactly what the first AddBlobs stored,
#   - a second AddBlobs appends (GetBlobs == blob_def + blob_def2),
#   - ReplaceBlobs overwrites the stored list with replaced_blob_def.
# NOTE(review): the docstring text below duplicates test_adding_net's and
# does not describe this test; consider rewording it.
def test_replace_blobs(self):
35
Tests that NetDefs can be added to MetaNetDefs
37
meta_net_def = metanet_pb2.MetaNetDef()
# NOTE(review): blob_name, blob_def and blob_def2 are used below but their
# definitions (original lines 38-40, judging by the numbering) are missing
# from this excerpt.
41
replaced_blob_def = ["CC"]
42
pred_utils.AddBlobs(meta_net_def, blob_name, blob_def)
43
self.assertEqual(blob_def, pred_utils.GetBlobs(meta_net_def, blob_name))
44
pred_utils.AddBlobs(meta_net_def, blob_name, blob_def2)
45
self.assertEqual(blob_def + blob_def2, pred_utils.GetBlobs(meta_net_def, blob_name))
47
pred_utils.ReplaceBlobs(meta_net_def, blob_name, replaced_blob_def)
48
self.assertEqual(replaced_blob_def, pred_utils.GetBlobs(meta_net_def, blob_name))
51
# NOTE(review): lossy extraction -- alternating lines are bare original line
# numbers, indentation is gone and some lines are missing, so this will not
# parse as-is; restore from version control before editing code.
#
# Tests for predictor_exporter: building PredictorExportMeta, saving it to
# a db, loading it back and running the stored nets.
class PredictorExporterTest(unittest.TestCase):
52
# Builds a small CNNModelHelper model; callers bind the result to `m` and
# read m.net, m.params and m.param_init_net.
# NOTE(review): original lines 54-55 (the start of the call whose
# weight_init/bias_init keyword arguments appear next) and lines 58-61
# (presumably `return m` and the next method's header) are missing.
def _create_model(self):
53
m = cnn.CNNModelHelper()
56
weight_init=m.XavierInit,
57
bias_init=m.XavierInit)
# NOTE(review): the fragment below reads like per-test setup (presumably
# setUp -- its header is missing); later tests rely on the
# self.predictor_export_meta and self.params it builds.
62
m = self._create_model()
64
# Export metadata for the model: its predict net proto, its parameter
# names as strings, and fixed shapes for "y" (1, 10) and "data" (1, 5).
# NOTE(review): original lines 67-68 and 70 are missing here.
self.predictor_export_meta = pe.PredictorExportMeta(
65
predict_net=m.net.Proto(),
66
parameters=[str(b) for b in m.params],
69
shapes={"y": (1, 10), "data": (1, 5)},
71
# Run the parameter initializers and snapshot each parameter's value into
# a dict keyed by parameter name (the dict's opening, original lines
# 72-73, is missing; presumably `self.params = {`), then clear the
# workspace so each test starts from an empty one.
workspace.RunNetOnce(m.param_init_net)
74
param: workspace.FetchBlob(param)
75
for param in self.predictor_export_meta.parameters}
77
workspace.ResetWorkspace()
79
# Checks that PredictorExportMeta accepts the net object itself rather than
# its proto (per the docstring text below); construction must not raise.
# NOTE(review): original lines 85-88 (the other PredictorExportMeta
# arguments) and the closing lines are missing; only the shapes= argument
# survives in this excerpt.
def test_meta_constructor(self):
81
Test that passing net itself instead of proto works
83
m = self._create_model()
84
pe.PredictorExportMeta(
89
shapes={"y": (1, 10), "data": (1, 5)},
92
# Checks that PredictorExportMeta raises when the model parameters are also
# listed as inputs, and again when they are also listed as outputs.
# NOTE(review): assertRaises(Exception) accepts any exception, so an
# unrelated error inside the constructor would also pass; consider asserting
# the specific exception type the exporter raises. Several argument lines of
# both calls (e.g. original 99-100, 107-109) are missing from this excerpt.
def test_param_intersection(self):
94
Test that passes intersecting parameters and input/output blobs
96
m = self._create_model()
97
# Parameters overlapping the declared inputs.
with self.assertRaises(Exception):
98
pe.PredictorExportMeta(
101
inputs=["data"] + m.params,
103
shapes={"y": (1, 10), "data": (1, 5)},
105
# Parameters overlapping the declared outputs.
with self.assertRaises(Exception):
106
pe.PredictorExportMeta(
110
outputs=["y"] + m.params,
111
shapes={"y": (1, 10), "data": (1, 5)},
114
# Round trip: export the model (with extra and global init nets) to a db,
# load it back into a cleared workspace, then run each stored net and check
# the blobs it produces.
def test_meta_net_def_net_runs(self):
115
# Feed the parameter values snapshotted during setup into the workspace
# before exporting.
for param, value in self.params.items():
116
workspace.FeedBlob(param, value)
118
# Extra init net: fills "data" with ones, shaped like the existing "data".
extra_init_net = core.Net('extra_init')
119
extra_init_net.ConstantFill('data', 'data', value=1.0)
121
# Global init net: a ConstantFill whose remaining arguments (original lines
# 123-126) are missing; the checks below expect it to create
# "global_init_blob" equal to np.ones((1, 5)).
global_init_net = core.Net('global_init')
122
global_init_net.ConstantFill(
127
dtype=core.DataType.FLOAT
129
# Export metadata mirroring the setup's, plus the two init nets above.
pem = pe.PredictorExportMeta(
130
predict_net=self.predictor_export_meta.predict_net,
131
parameters=self.predictor_export_meta.parameters,
132
inputs=self.predictor_export_meta.inputs,
133
outputs=self.predictor_export_meta.outputs,
134
shapes=self.predictor_export_meta.shapes,
135
extra_init_net=extra_init_net,
136
global_init_net=global_init_net,
# NOTE(review): original lines 137-140 are missing. db_type is used below
# but is not defined in the visible lines (presumably a
# `for db_type in [...]` loop header lived there -- confirm). The argument
# that makes the predict net's type 'dag' (asserted below) is not visible
# either.
# NOTE(review): the delete=False temp file is never closed or removed in
# the visible lines; consider cleaning it up (e.g. via addCleanup).
141
db_file = tempfile.NamedTemporaryFile(
142
delete=False, suffix=".{}".format(db_type))
# NOTE(review): the save call's header (original lines 143-144) is missing.
145
db_destination=db_file.name,
146
predictor_export_meta=pem)
148
# Clear the workspace so everything observed below comes from the db.
workspace.ResetWorkspace()
150
meta_net_def = pe.load_from_db(
152
filename=db_file.name,
155
# Loading alone must not create the input/output blobs.
self.assertTrue("data" not in workspace.Blobs())
156
self.assertTrue("y" not in workspace.Blobs())
158
init_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_INIT_NET_TYPE)
161
workspace.RunNetOnce(init_net)
163
# After the predict init net: "data" is expected to be ones (1, 5) and "y"
# zeros (1, 10). Presumably this means the init net creates blobs from the
# declared shapes and then applies extra_init_net's fill. That is inferred
# from these assertions and is not shown here.
self.assertTrue("data" in workspace.Blobs())
164
self.assertTrue("y" in workspace.Blobs())
166
# NOTE(review): leftover debug print; consider removing.
print(workspace.FetchBlob("data"))
167
np.testing.assert_array_equal(
168
workspace.FetchBlob("data"), np.ones(shape=(1, 5)))
169
np.testing.assert_array_equal(
170
workspace.FetchBlob("y"), np.zeros(shape=(1, 10)))
172
# The global init blob must only appear once the GLOBAL_INIT net is run.
self.assertTrue("global_init_blob" not in workspace.Blobs())
174
global_init_net = pred_utils.GetNet(meta_net_def,
175
pc.GLOBAL_INIT_NET_TYPE)
176
workspace.RunNetOnce(global_init_net)
179
self.assertTrue(workspace.HasBlob('global_init_blob'))
180
np.testing.assert_array_equal(
181
workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5)))
185
# Run the predict net on a random float32 batch of 2 rows, which differs
# from the (1, 5) export shape. Then compare "y" against
# data . y_w^T + <missing term>.
# NOTE(review): the end of the expected-value expression (original lines
# 192-193; presumably the bias parameter) is missing from this excerpt.
workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))
186
predict_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_NET_TYPE)
187
self.assertEqual(predict_net.type, 'dag')
188
workspace.RunNetOnce(predict_net)
189
np.testing.assert_array_almost_equal(
190
workspace.FetchBlob("y"),
191
workspace.FetchBlob("data").dot(self.params["y_w"].T) +
194
# Checks that nets loaded inside a DeviceScope carry that device. Every op
# of the loaded GLOBAL_INIT and PREDICT_INIT nets must have
# device_type == caffe2_pb2.CPU and device_id == 1.
def test_load_device_scope(self):
195
# Feed the parameter values snapshotted during setup before exporting.
for param, value in self.params.items():
196
workspace.FeedBlob(param, value)
198
pem = pe.PredictorExportMeta(
199
predict_net=self.predictor_export_meta.predict_net,
200
parameters=self.predictor_export_meta.parameters,
201
inputs=self.predictor_export_meta.inputs,
202
outputs=self.predictor_export_meta.outputs,
203
shapes=self.predictor_export_meta.shapes,
# NOTE(review): original lines 204-207 are missing. db_type is used below
# but not defined in the visible lines (presumably a loop header -- confirm),
# and the save call's header (original 210-211) is missing too. The
# delete=False temp file is never removed in the visible lines.
208
db_file = tempfile.NamedTemporaryFile(
209
delete=False, suffix=".{}".format(db_type))
212
db_destination=db_file.name,
213
predictor_export_meta=pem)
215
# Clear the workspace, then load the db while a CPU device scope with
# device_id 1 is active.
workspace.ResetWorkspace()
216
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 1)):
217
meta_net_def = pe.load_from_db(
219
filename=db_file.name,
222
init_net = core.Net(pred_utils.GetNet(meta_net_def,
223
pc.GLOBAL_INIT_NET_TYPE))
224
predict_init_net = core.Net(pred_utils.GetNet(
225
meta_net_def, pc.PREDICT_INIT_NET_TYPE))
228
for op in list(init_net.Proto().op) + list(predict_init_net.Proto().op):
229
self.assertEqual(1, op.device_option.device_id)
230
self.assertEqual(caffe2_pb2.CPU, op.device_option.device_type)
232
# Checks that saving to a db fails when no parameter values have been fed.
# This test never FeedBlobs self.params before exporting self.predictor_export_meta.
# NOTE(review): the save call's header (original lines 237-238) is missing;
# presumably pe.save_to_db(db_type=db_type, ...) -- confirm.
# NOTE(review): assertRaises(Exception) also wraps the temp-file creation,
# so any exception there would pass the test too; consider narrowing the
# guarded region and the exception type. The delete=False temp file is never
# removed in the visible lines.
def test_db_fails_without_params(self):
233
with self.assertRaises(Exception):
234
for db_type in ["minidb"]:
235
db_file = tempfile.NamedTemporaryFile(
236
delete=False, suffix=".{}".format(db_type))
239
db_destination=db_file.name,
240
predictor_export_meta=self.predictor_export_meta)