pytorch

Форк
0
/
predictor_exporter_test.py 
235 строк · 8.5 Кб
1

2

3

4

5

6
import tempfile
7
import unittest
8
import numpy as np
9
from caffe2.python import cnn, workspace, core
10

11
from caffe2.python.predictor_constants import predictor_constants as pc
12
import caffe2.python.predictor.predictor_exporter as pe
13
import caffe2.python.predictor.predictor_py_utils as pred_utils
14
from caffe2.proto import caffe2_pb2, metanet_pb2
15

16

17
class MetaNetDefTest(unittest.TestCase):
    '''
    Tests for building MetaNetDef protos and manipulating their blob lists.
    '''

    def test_minimal(self):
        '''
        Tests that a NetsMap message can be created with a NetDef message
        '''
        # This calls the constructor for a metanet_pb2.NetsMap
        nets_map = metanet_pb2.NetsMap(
            key="test_key", value=caffe2_pb2.NetDef())
        # The keyword argument must actually populate the field.
        self.assertEqual("test_key", nets_map.key)

    def test_adding_net(self):
        '''
        Tests that NetDefs can be added to MetaNetDefs
        '''
        meta_net_def = metanet_pb2.MetaNetDef()
        net_def = caffe2_pb2.NetDef()
        meta_net_def.nets.add(key="test_key", value=net_def)
        # Verify the entry was stored rather than only checking that add()
        # did not raise.
        self.assertEqual(1, len(meta_net_def.nets))
        self.assertEqual("test_key", meta_net_def.nets[0].key)

    def test_replace_blobs(self):
        '''
        Tests that blobs can be added to, appended to, and replaced in a
        MetaNetDef via pred_utils.AddBlobs / GetBlobs / ReplaceBlobs
        '''
        meta_net_def = metanet_pb2.MetaNetDef()
        blob_name = "Test"
        blob_def = ["AA"]
        blob_def2 = ["BB"]
        replaced_blob_def = ["CC"]

        # First AddBlobs creates the entry.
        pred_utils.AddBlobs(meta_net_def, blob_name, blob_def)
        self.assertEqual(
            blob_def, pred_utils.GetBlobs(meta_net_def, blob_name))

        # A second AddBlobs under the same name appends, not overwrites.
        pred_utils.AddBlobs(meta_net_def, blob_name, blob_def2)
        self.assertEqual(
            blob_def + blob_def2,
            pred_utils.GetBlobs(meta_net_def, blob_name))

        # ReplaceBlobs discards everything previously added.
        pred_utils.ReplaceBlobs(meta_net_def, blob_name, replaced_blob_def)
        self.assertEqual(
            replaced_blob_def, pred_utils.GetBlobs(meta_net_def, blob_name))
49

50

51
class PredictorExporterTest(unittest.TestCase):
52
    def _create_model(self):
53
        m = cnn.CNNModelHelper()
54
        m.FC("data", "y",
55
             dim_in=5, dim_out=10,
56
             weight_init=m.XavierInit,
57
             bias_init=m.XavierInit)
58
        return m
59

60
    def setUp(self):
61
        np.random.seed(1)
62
        m = self._create_model()
63

64
        self.predictor_export_meta = pe.PredictorExportMeta(
65
            predict_net=m.net.Proto(),
66
            parameters=[str(b) for b in m.params],
67
            inputs=["data"],
68
            outputs=["y"],
69
            shapes={"y": (1, 10), "data": (1, 5)},
70
        )
71
        workspace.RunNetOnce(m.param_init_net)
72

73
        self.params = {
74
            param: workspace.FetchBlob(param)
75
            for param in self.predictor_export_meta.parameters}
76
        # Reset the workspace, to ensure net creation proceeds as expected.
77
        workspace.ResetWorkspace()
78

79
    def test_meta_constructor(self):
80
        '''
81
        Test that passing net itself instead of proto works
82
        '''
83
        m = self._create_model()
84
        pe.PredictorExportMeta(
85
            predict_net=m.net,
86
            parameters=m.params,
87
            inputs=["data"],
88
            outputs=["y"],
89
            shapes={"y": (1, 10), "data": (1, 5)},
90
        )
91

92
    def test_param_intersection(self):
93
        '''
94
        Test that passes intersecting parameters and input/output blobs
95
        '''
96
        m = self._create_model()
97
        with self.assertRaises(Exception):
98
            pe.PredictorExportMeta(
99
                predict_net=m.net,
100
                parameters=m.params,
101
                inputs=["data"] + m.params,
102
                outputs=["y"],
103
                shapes={"y": (1, 10), "data": (1, 5)},
104
            )
105
        with self.assertRaises(Exception):
106
            pe.PredictorExportMeta(
107
                predict_net=m.net,
108
                parameters=m.params,
109
                inputs=["data"],
110
                outputs=["y"] + m.params,
111
                shapes={"y": (1, 10), "data": (1, 5)},
112
            )
113

114
    def test_meta_net_def_net_runs(self):
115
        for param, value in self.params.items():
116
            workspace.FeedBlob(param, value)
117

118
        extra_init_net = core.Net('extra_init')
119
        extra_init_net.ConstantFill('data', 'data', value=1.0)
120

121
        global_init_net = core.Net('global_init')
122
        global_init_net.ConstantFill(
123
            [],
124
            'global_init_blob',
125
            value=1.0,
126
            shape=[1, 5],
127
            dtype=core.DataType.FLOAT
128
        )
129
        pem = pe.PredictorExportMeta(
130
            predict_net=self.predictor_export_meta.predict_net,
131
            parameters=self.predictor_export_meta.parameters,
132
            inputs=self.predictor_export_meta.inputs,
133
            outputs=self.predictor_export_meta.outputs,
134
            shapes=self.predictor_export_meta.shapes,
135
            extra_init_net=extra_init_net,
136
            global_init_net=global_init_net,
137
            net_type='dag',
138
        )
139

140
        db_type = 'minidb'
141
        db_file = tempfile.NamedTemporaryFile(
142
            delete=False, suffix=".{}".format(db_type))
143
        pe.save_to_db(
144
            db_type=db_type,
145
            db_destination=db_file.name,
146
            predictor_export_meta=pem)
147

148
        workspace.ResetWorkspace()
149

150
        meta_net_def = pe.load_from_db(
151
            db_type=db_type,
152
            filename=db_file.name,
153
        )
154

155
        self.assertTrue("data" not in workspace.Blobs())
156
        self.assertTrue("y" not in workspace.Blobs())
157

158
        init_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_INIT_NET_TYPE)
159

160
        # 0-fills externalblobs blobs and runs extra_init_net
161
        workspace.RunNetOnce(init_net)
162

163
        self.assertTrue("data" in workspace.Blobs())
164
        self.assertTrue("y" in workspace.Blobs())
165

166
        print(workspace.FetchBlob("data"))
167
        np.testing.assert_array_equal(
168
            workspace.FetchBlob("data"), np.ones(shape=(1, 5)))
169
        np.testing.assert_array_equal(
170
            workspace.FetchBlob("y"), np.zeros(shape=(1, 10)))
171

172
        self.assertTrue("global_init_blob" not in workspace.Blobs())
173
        # Load parameters from DB
174
        global_init_net = pred_utils.GetNet(meta_net_def,
175
                                            pc.GLOBAL_INIT_NET_TYPE)
176
        workspace.RunNetOnce(global_init_net)
177

178
        # make sure the extra global_init_net is running
179
        self.assertTrue(workspace.HasBlob('global_init_blob'))
180
        np.testing.assert_array_equal(
181
            workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5)))
182

183
        # Run the net with a reshaped input and verify we are
184
        # producing good numbers (with our custom implementation)
185
        workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))
186
        predict_net = pred_utils.GetNet(meta_net_def, pc.PREDICT_NET_TYPE)
187
        self.assertEqual(predict_net.type, 'dag')
188
        workspace.RunNetOnce(predict_net)
189
        np.testing.assert_array_almost_equal(
190
            workspace.FetchBlob("y"),
191
            workspace.FetchBlob("data").dot(self.params["y_w"].T) +
192
            self.params["y_b"])
193

194
    def test_load_device_scope(self):
195
        for param, value in self.params.items():
196
            workspace.FeedBlob(param, value)
197

198
        pem = pe.PredictorExportMeta(
199
            predict_net=self.predictor_export_meta.predict_net,
200
            parameters=self.predictor_export_meta.parameters,
201
            inputs=self.predictor_export_meta.inputs,
202
            outputs=self.predictor_export_meta.outputs,
203
            shapes=self.predictor_export_meta.shapes,
204
            net_type='dag',
205
        )
206

207
        db_type = 'minidb'
208
        db_file = tempfile.NamedTemporaryFile(
209
            delete=False, suffix=".{}".format(db_type))
210
        pe.save_to_db(
211
            db_type=db_type,
212
            db_destination=db_file.name,
213
            predictor_export_meta=pem)
214

215
        workspace.ResetWorkspace()
216
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 1)):
217
            meta_net_def = pe.load_from_db(
218
                db_type=db_type,
219
                filename=db_file.name,
220
            )
221

222
        init_net = core.Net(pred_utils.GetNet(meta_net_def,
223
                            pc.GLOBAL_INIT_NET_TYPE))
224
        predict_init_net = core.Net(pred_utils.GetNet(
225
            meta_net_def, pc.PREDICT_INIT_NET_TYPE))
226

227
        # check device options
228
        for op in list(init_net.Proto().op) + list(predict_init_net.Proto().op):
229
            self.assertEqual(1, op.device_option.device_id)
230
            self.assertEqual(caffe2_pb2.CPU, op.device_option.device_type)
231

232
    def test_db_fails_without_params(self):
233
        with self.assertRaises(Exception):
234
            for db_type in ["minidb"]:
235
                db_file = tempfile.NamedTemporaryFile(
236
                    delete=False, suffix=".{}".format(db_type))
237
                pe.save_to_db(
238
                    db_type=db_type,
239
                    db_destination=db_file.name,
240
                    predictor_export_meta=self.predictor_export_meta)
241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.