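"""Tests for the Caffe2 Save and Load operators.

Covers device round-trips, blob name overrides, loading from multiple db
files, chunked serialization, bfloat16 float formats, and blob size
estimation.
"""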
import hypothesis.strategies as st
from hypothesis import given, assume, settings
import io
import math
import numpy as np
import os
import struct
import unittest
from pathlib import Path
from typing import Dict, Generator, List, NamedTuple, Optional, Tuple, Type
from caffe2.proto import caffe2_pb2
from caffe2.proto.caffe2_pb2 import BlobSerializationOptions
from caffe2.python import core, test_util, workspace

if workspace.has_gpu_support:
    DEVICES = [caffe2_pb2.CPU, workspace.GpuDeviceType]
    max_gpuid = workspace.NumGpuDevices() - 1
else:
    DEVICES = [caffe2_pb2.CPU]
    max_gpuid = 0


class MiniDBEntry(NamedTuple):
    key: str
    value_size: int


# Utility class for other loading tests; don't add test functions here.
# Inherit from this class instead. If you add a test here,
# each derived class will inherit it as well and cause test duplication.
class TestLoadSaveBase(test_util.TestCase):

    def __init__(self, methodName, db_type='minidb'):
        super().__init__(methodName)
        self._db_type = db_type

    @settings(deadline=None)
    @given(src_device_type=st.sampled_from(DEVICES),
           src_gpu_id=st.integers(min_value=0, max_value=max_gpuid),
           dst_device_type=st.sampled_from(DEVICES),
           dst_gpu_id=st.integers(min_value=0, max_value=max_gpuid))
    def load_save(self, src_device_type, src_gpu_id,
                  dst_device_type, dst_gpu_id):
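        """Round-trip tensors of assorted dtypes through Save and Load.

        Feeds blobs on the source device, saves them to a temporary db, then
        loads them back with different keep_device/load_all combinations and
        checks the values and the recorded device details.
        """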
        workspace.ResetWorkspace()
        dtypes = [np.float16, np.float32, np.float64, bool, np.int8,
                  np.int16, np.int32, np.int64, np.uint8, np.uint16]
        arrays = [np.random.permutation(6).reshape(2, 3).astype(T)
                  for T in dtypes]
        assume(core.IsGPUDeviceType(src_device_type) or src_gpu_id == 0)
        assume(core.IsGPUDeviceType(dst_device_type) or dst_gpu_id == 0)
        src_device_option = core.DeviceOption(
            src_device_type, src_gpu_id)
        dst_device_option = core.DeviceOption(
            dst_device_type, dst_gpu_id)

        for i, arr in enumerate(arrays):
            self.assertTrue(workspace.FeedBlob(str(i), arr, src_device_option))
            self.assertTrue(workspace.HasBlob(str(i)))

        # Saves the blobs to a local db.
        tmp_folder = self.make_tempdir()
        op = core.CreateOperator(
            "Save",
            [str(i) for i in range(len(arrays))], [],
            absolute_path=1,
            db=str(tmp_folder / "db"), db_type=self._db_type)
        self.assertTrue(workspace.RunOperatorOnce(op))

        # Reset the workspace so that anything we load is surely loaded
        # from the serialized proto.
        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)

        def _LoadTest(keep_device, device_type, gpu_id, blobs, loadAll):
            """A helper subfunction to test keep and not keep."""
            op = core.CreateOperator(
                "Load",
                [], blobs,
                absolute_path=1,
                db=str(tmp_folder / "db"), db_type=self._db_type,
                device_option=dst_device_option,
                keep_device=keep_device,
                load_all=loadAll)
            self.assertTrue(workspace.RunOperatorOnce(op))
            for i, arr in enumerate(arrays):
                self.assertTrue(workspace.HasBlob(str(i)))
                fetched = workspace.FetchBlob(str(i))
                self.assertEqual(fetched.dtype, arr.dtype)
                np.testing.assert_array_equal(
                    workspace.FetchBlob(str(i)), arr)
                proto = caffe2_pb2.BlobProto()
                proto.ParseFromString(workspace.SerializeBlob(str(i)))
                self.assertTrue(proto.HasField('tensor'))
                self.assertEqual(proto.tensor.device_detail.device_type,
                                 device_type)
                if core.IsGPUDeviceType(device_type):
                    self.assertEqual(proto.tensor.device_detail.device_id,
                                     gpu_id)

        blobs = [str(i) for i in range(len(arrays))]
        # Load using device option stored in the proto, i.e.
        # src_device_option
        _LoadTest(1, src_device_type, src_gpu_id, blobs, 0)
        # Load again, but this time load into dst_device_option.
        _LoadTest(0, dst_device_type, dst_gpu_id, blobs, 0)
        # Load back to the src_device_option to see if both paths are able
        # to reallocate memory.
        _LoadTest(1, src_device_type, src_gpu_id, blobs, 0)
        # Reset the workspace, and load directly into the dst_device_option.
        workspace.ResetWorkspace()
        _LoadTest(0, dst_device_type, dst_gpu_id, blobs, 0)

        # Test load all which loads all blobs in the db into the workspace.
        workspace.ResetWorkspace()
        _LoadTest(1, src_device_type, src_gpu_id, [], 1)
        # Load again making sure that overwrite functionality works.
        _LoadTest(1, src_device_type, src_gpu_id, [], 1)
        # Load again with different device.
        _LoadTest(0, dst_device_type, dst_gpu_id, [], 1)
        workspace.ResetWorkspace()
        _LoadTest(0, dst_device_type, dst_gpu_id, [], 1)
        workspace.ResetWorkspace()
        _LoadTest(1, src_device_type, src_gpu_id, blobs, 1)
        workspace.ResetWorkspace()
        _LoadTest(0, dst_device_type, dst_gpu_id, blobs, 1)

    def saveFile(
        self, tmp_folder: Path, db_name: str, db_type: str, start_blob_id: int
    ) -> Tuple[str, List[np.ndarray]]:
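        """Feed one random test array per dtype and save them all to a db.

        Blob names are consecutive integers starting at start_blob_id.
        Returns the db path and the list of saved arrays.
        """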
        dtypes = [np.float16, np.float32, np.float64, bool, np.int8,
                  np.int16, np.int32, np.int64, np.uint8, np.uint16]
        arrays = [np.random.permutation(6).reshape(2, 3).astype(T)
                  for T in dtypes]

        for i, arr in enumerate(arrays):
            self.assertTrue(workspace.FeedBlob(str(i + start_blob_id), arr))
            self.assertTrue(workspace.HasBlob(str(i + start_blob_id)))

        # Saves the blobs to a local db.
        tmp_file = str(tmp_folder / db_name)
        op = core.CreateOperator(
            "Save",
            [str(i + start_blob_id) for i in range(len(arrays))], [],
            absolute_path=1,
            db=tmp_file, db_type=db_type)
        workspace.RunOperatorOnce(op)
        return tmp_file, arrays


class TestLoadSave(TestLoadSaveBase):

    def testLoadSave(self):
        self.load_save()

    def testRepeatedArgs(self):
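        # Listing the same blob twice in a single Save op should be rejected.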
        dtypes = [np.float16, np.float32, np.float64, bool, np.int8,
                  np.int16, np.int32, np.int64, np.uint8, np.uint16]
        arrays = [np.random.permutation(6).reshape(2, 3).astype(T)
                  for T in dtypes]

        for i, arr in enumerate(arrays):
            self.assertTrue(workspace.FeedBlob(str(i), arr))
            self.assertTrue(workspace.HasBlob(str(i)))

        # Saves the blobs to a local db.
        tmp_folder = self.make_tempdir()
        op = core.CreateOperator(
            "Save",
            [str(i) for i in range(len(arrays))] * 2, [],
            absolute_path=1,
            db=str(tmp_folder / "db"), db_type=self._db_type)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

    def testLoadExcessblobs(self):
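        # Load should fail when the requested output blobs do not match what
        # is actually stored in the db.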
        tmp_folder = self.make_tempdir()
        tmp_file, arrays = self.saveFile(tmp_folder, "db", self._db_type, 0)

        op = core.CreateOperator(
            "Load",
            [], [str(i) for i in range(len(arrays))] * 2,
            absolute_path=1,
            db=tmp_file, db_type=self._db_type,
            load_all=False)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

        op = core.CreateOperator(
            "Load",
            [], [str(len(arrays) + i) for i in [-1, 0]],
            absolute_path=1,
            db=tmp_file, db_type=self._db_type,
            load_all=True)
        with self.assertRaises(RuntimeError):
            workspace.ResetWorkspace()
            workspace.RunOperatorOnce(op)

        op = core.CreateOperator(
            "Load",
            [], [str(len(arrays) + i) for i in range(2)],
            absolute_path=1,
            db=tmp_file, db_type=self._db_type,
            load_all=True)
        with self.assertRaises(RuntimeError):
            workspace.ResetWorkspace()
            workspace.RunOperatorOnce(op)

    def testTruncatedFile(self):
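        # Corrupt the saved db file, then verify that Load fails both with an
        # explicit blob list and with load_all=True.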
        tmp_folder = self.make_tempdir()
        tmp_file, arrays = self.saveFile(tmp_folder, "db", self._db_type, 0)

        with open(tmp_file, 'wb+') as fdest:
            fdest.seek(20, os.SEEK_END)
            fdest.truncate()

        op = core.CreateOperator(
            "Load",
            [], [str(i) for i in range(len(arrays))],
            absolute_path=1,
            db=tmp_file, db_type=self._db_type,
            load_all=False)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

        op = core.CreateOperator(
            "Load",
            [], [],
            absolute_path=1,
            db=tmp_file, db_type=self._db_type,
            load_all=True)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

    def testBlobNameOverrides(self):
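        # blob_name_overrides saves blobs under new names; combining it with
        # strip_prefix is expected to fail.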
        original_names = ['blob_a', 'blob_b', 'blob_c']
        new_names = ['x', 'y', 'z']
        blobs = [np.random.permutation(6) for i in range(3)]
        for i, blob in enumerate(blobs):
            self.assertTrue(workspace.FeedBlob(original_names[i], blob))
            self.assertTrue(workspace.HasBlob(original_names[i]))
        self.assertEqual(len(workspace.Blobs()), 3)

        # Saves the blobs to a local db.
        tmp_folder = self.make_tempdir()
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Save", original_names, [],
                    absolute_path=1,
                    strip_prefix='.temp',
                    blob_name_overrides=new_names,
                    db=str(tmp_folder / "db"),
                    db_type=self._db_type
                )
            )
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Save", original_names, [],
                    absolute_path=1,
                    blob_name_overrides=new_names,
                    db=str(tmp_folder / "db"),
                    db_type=self._db_type
                )
            )
        )
        self.assertTrue(workspace.ResetWorkspace())
        self.assertEqual(len(workspace.Blobs()), 0)
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Load", [], [],
                    absolute_path=1,
                    db=str(tmp_folder / "db"),
                    db_type=self._db_type,
                    load_all=1
                )
            )
        )
        self.assertEqual(len(workspace.Blobs()), 3)
        for i, name in enumerate(new_names):
            self.assertTrue(workspace.HasBlob(name))
            self.assertTrue((workspace.FetchBlob(name) == blobs[i]).all())
        # moved here per @cxj's suggestion
        load_new_names = ['blob_x', 'blob_y', 'blob_z']
        # load 'x' into 'blob_x'
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Load", [], load_new_names[0:1],
                    absolute_path=1,
                    db=str(tmp_folder / "db"),
                    db_type=self._db_type,
                    source_blob_names=new_names[0:1]
                )
            )
        )
        # we should have 'blob_a/b/c/' and 'blob_x' now
        self.assertEqual(len(workspace.Blobs()), 4)
        for i, name in enumerate(load_new_names[0:1]):
            self.assertTrue(workspace.HasBlob(name))
            self.assertTrue((workspace.FetchBlob(name) == blobs[i]).all())
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Load", [], load_new_names[0:3],
                    absolute_path=1,
                    db=str(tmp_folder / "db"),
                    db_type=self._db_type,
                    source_blob_names=new_names[0:3]
                )
            )
        )
        # we should have 'blob_a/b/c/' and 'blob_x/y/z' now
        self.assertEqual(len(workspace.Blobs()), 6)
        for i, name in enumerate(load_new_names[0:3]):
            self.assertTrue(workspace.HasBlob(name))
            self.assertTrue((workspace.FetchBlob(name) == blobs[i]).all())

    def testMissingFile(self):
        tmp_folder = self.make_tempdir()
        tmp_file = tmp_folder / "missing_db"

        op = core.CreateOperator(
            "Load",
            [], [],
            absolute_path=1,
            db=str(tmp_file), db_type=self._db_type,
            load_all=True)
        with self.assertRaises(RuntimeError):
            try:
                workspace.RunOperatorOnce(op)
            except RuntimeError as e:
                print(e)
                raise

    def testLoadMultipleFilesGivenSourceBlobNames(self):
        tmp_folder = self.make_tempdir()
        db_file_1, arrays_1 = self.saveFile(tmp_folder, "db1", self._db_type, 0)
        db_file_2, arrays_2 = self.saveFile(
            tmp_folder, "db2", self._db_type, len(arrays_1)
        )
        db_files = [db_file_1, db_file_2]
        blobs_names = [str(i) for i in range(len(arrays_1) + len(arrays_2))]

        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Load",
                    [], blobs_names,
                    absolute_path=1,
                    dbs=db_files, db_type=self._db_type,
                    source_blob_names=blobs_names
                )
            )
        )
        self.assertEqual(len(workspace.Blobs()), len(blobs_names))
        for i in range(len(arrays_1)):
            np.testing.assert_array_equal(
                workspace.FetchBlob(str(i)), arrays_1[i]
            )
        for i in range(len(arrays_2)):
            np.testing.assert_array_equal(
                workspace.FetchBlob(str(i + len(arrays_1))), arrays_2[i]
            )

    def testLoadAllMultipleFiles(self):
        tmp_folder = self.make_tempdir()
        db_file_1, arrays_1 = self.saveFile(tmp_folder, "db1", self._db_type, 0)
        db_file_2, arrays_2 = self.saveFile(
            tmp_folder, "db2", self._db_type, len(arrays_1)
        )
        db_files = [db_file_1, db_file_2]

        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)
        self.assertTrue(
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "Load",
                    [], [],
                    absolute_path=1,
                    dbs=db_files, db_type=self._db_type,
                    load_all=True
                )
            )
        )
        self.assertEqual(len(workspace.Blobs()), len(arrays_1) + len(arrays_2))
        for i in range(len(arrays_1)):
            np.testing.assert_array_equal(
                workspace.FetchBlob(str(i)), arrays_1[i]
            )
        for i in range(len(arrays_2)):
            np.testing.assert_array_equal(
                workspace.FetchBlob(str(i + len(arrays_1))), arrays_2[i]
            )

    def testLoadAllMultipleFilesWithSameKey(self):
        tmp_folder = self.make_tempdir()
        db_file_1, arrays_1 = self.saveFile(tmp_folder, "db1", self._db_type, 0)
        db_file_2, arrays_2 = self.saveFile(tmp_folder, "db2", self._db_type, 0)

        db_files = [db_file_1, db_file_2]
        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)
        op = core.CreateOperator(
            "Load",
            [], [],
            absolute_path=1,
            dbs=db_files, db_type=self._db_type,
            load_all=True)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

    def testLoadRepeatedFiles(self):
        tmp_folder = self.make_tempdir()
        tmp_file, arrays = self.saveFile(tmp_folder, "db", self._db_type, 0)

        db_files = [tmp_file, tmp_file]
        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)
        op = core.CreateOperator(
            "Load",
            [], [str(i) for i in range(len(arrays))],
            absolute_path=1,
            dbs=db_files, db_type=self._db_type,
            load_all=False)
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

    def testLoadWithDBOptions(self) -> None:
        tmp_folder = self.make_tempdir()
        tmp_file, arrays = self.saveFile(tmp_folder, "db", self._db_type, 0)

        db_files = [tmp_file, tmp_file]
        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)

        db_options = b"test_db_options"
        op = core.CreateOperator(
            "Load",
            [], [str(i) for i in range(len(arrays))],
            absolute_path=1,
            dbs=db_files, db_type=self._db_type,
            load_all=False,
            db_options=db_options,
        )
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(op)

    def create_test_blobs(
        self, size: int = 1234, feed: bool = True
    ) -> List[Tuple[str, np.ndarray]]:
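        """Create one random test array per supported dtype.

        Returns (name, array) pairs; if feed is True, each array is also fed
        into the workspace under its name.
        """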
        def int_array(dtype: Type[np.integer], size: int) -> np.ndarray:
            info = np.iinfo(dtype)
            return np.random.randint(info.min, info.max, size, dtype=dtype)

        def float_array(dtype: Type[np.floating], size: int) -> np.ndarray:
            return np.random.random_sample(size).astype(dtype)

        blobs = [
            ("int8_data", int_array(np.int8, size)),
            ("int16_data", int_array(np.int16, size)),
            ("int32_data", int_array(np.int32, size)),
            ("int64_data", int_array(np.int64, size)),
            ("uint8_data", int_array(np.uint8, size)),
            ("uint16_data", int_array(np.uint16, size)),
            ("float16_data", float_array(np.float16, size)),
            ("float32_data", float_array(np.float32, size)),
            ("float64_data", float_array(np.float64, size)),
        ]

        if feed:
            for name, data in blobs:
                workspace.FeedBlob(name, data)

        return blobs

    def load_blobs(
        self,
        blob_names: List[str],
        dbs: List[str],
        db_type: Optional[str] = None
    ) -> None:
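        """Reset the workspace and load the named blobs from the given dbs."""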
        workspace.ResetWorkspace()
        self.assertEqual(len(workspace.Blobs()), 0)
        load_op = core.CreateOperator(
            "Load",
            [],
            blob_names,
            absolute_path=1,
            dbs=dbs,
            db_type=db_type or self._db_type,
        )
        self.assertTrue(workspace.RunOperatorOnce(load_op))
        self.assertEqual(len(workspace.Blobs()), len(blob_names))

    def load_and_check_blobs(
        self,
        blobs: List[Tuple[str, np.ndarray]],
        dbs: List[str],
        db_type: Optional[str] = None
    ) -> None:
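        """Load the given blobs and check they match the original arrays."""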
        self.load_blobs([name for name, data in blobs], dbs, db_type)
        for name, data in blobs:
            np.testing.assert_array_equal(workspace.FetchBlob(name), data)

    def _read_minidb_entries(
        self, path: Path
    ) -> Generator[MiniDBEntry, None, None]:
        """Read the entry information out of a minidb file.
        """
        header = struct.Struct("=ii")
        with path.open("rb") as f:
            while True:
                buf = f.read(header.size)
                if not buf:
                    break
                if len(buf) < header.size:
                    raise Exception("early EOF in minidb header")
                (key_len, value_len) = header.unpack(buf)
                if key_len < 0 or value_len < 0:
                    raise Exception(
                        f"invalid minidb header: ({key_len}, {value_len})"
                    )
                key = f.read(key_len)
                if len(key) < key_len:
                    raise Exception("early EOF in minidb key")
                f.seek(value_len, io.SEEK_CUR)
                yield MiniDBEntry(key=key.decode("utf-8"), value_size=value_len)

    def _read_chunk_info(self, path: Path) -> Dict[str, List[MiniDBEntry]]:
        """Read a minidb file and return the names of each blob and how many
        chunks are stored for that blob.
        """
        chunk_id_separator = "#%"
        results: Dict[str, List[MiniDBEntry]] = {}
        for entry in self._read_minidb_entries(path):
            parts = entry.key.rsplit(chunk_id_separator, 1)
            if len(parts) == 1:
                assert entry.key not in results
                results[entry.key] = [entry]
            else:
                blob_name = parts[0]
                results.setdefault(blob_name, [])
                results[blob_name].append(entry)

        return results

    def _test_save_with_chunk_size(
        self, num_elems: int, chunk_size: int, expected_num_chunks: int,
    ) -> None:
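        """Save test blobs with the given chunk_size and verify that each
        blob is split into the expected number of chunks."""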
        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")

        blobs = self.create_test_blobs(num_elems)

        # Saves the blobs to a local db.
        save_op = core.CreateOperator(
            "Save",
            [name for name, data in blobs],
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            chunk_size=chunk_size,
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        self.load_and_check_blobs(blobs, [tmp_file])

        blob_chunks = self._read_chunk_info(Path(tmp_file))
        for blob_name, chunks in blob_chunks.items():
            self.assertEqual(len(chunks), expected_num_chunks)

    def testSaveWithChunkSize(self) -> None:
        num_elems = 1234
        chunk_size = 32
        expected_num_chunks = math.ceil(num_elems / chunk_size)
        self._test_save_with_chunk_size(
            num_elems=num_elems,
            chunk_size=chunk_size,
            expected_num_chunks=expected_num_chunks,
        )

    def testSaveWithDefaultChunkSize(self) -> None:
        # This is the default value of the --caffe2_tensor_chunk_size flag from
        # core/blob_serialization.cc
        #
        # Test with just slightly more than this to ensure that 2 chunks are
        # used.
        default_chunk_size = 1000000
        self._test_save_with_chunk_size(
            num_elems=default_chunk_size + 10,
            chunk_size=-1,
            expected_num_chunks=2,
        )

    def testSaveWithNoChunking(self) -> None:
        default_chunk_size = 1000000
        self._test_save_with_chunk_size(
            num_elems=default_chunk_size + 10,
            chunk_size=0,
            expected_num_chunks=1,
        )

    def testSaveWithOptions(self) -> None:
        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")

        num_elems = 1234
        blobs = self.create_test_blobs(num_elems)

        # Saves the blobs to a local db.
        save_op = core.CreateOperator(
            "Save",
            [name for name, data in blobs],
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            chunk_size=40,
            options=caffe2_pb2.SerializationOptions(
                options=[
                    BlobSerializationOptions(
                        blob_name_regex="int16_data", chunk_size=10
                    ),
                    BlobSerializationOptions(
                        blob_name_regex=".*16_data", chunk_size=20
                    ),
                    BlobSerializationOptions(
                        blob_name_regex="float16_data", chunk_size=30
                    ),
                ],
            ),
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        self.load_and_check_blobs(blobs, [tmp_file])

        blob_chunks = self._read_chunk_info(Path(tmp_file))
        # We explicitly set a chunk_size of 10 for int16_data
        self.assertEqual(
            len(blob_chunks["int16_data"]), math.ceil(num_elems / 10)
        )
        # uint16_data should match the .*16_data pattern, and get a size of 20
        self.assertEqual(
            len(blob_chunks["uint16_data"]), math.ceil(num_elems / 20)
        )
        # float16_data should also match the .*16_data pattern, and get a size
        # of 20.  The explicit float16_data rule came after the .*16_data
        # pattern, so it has lower precedence and will be ignored.
        self.assertEqual(
            len(blob_chunks["float16_data"]), math.ceil(num_elems / 20)
        )
        # int64_data will get the default chunk_size of 40
        self.assertEqual(
            len(blob_chunks["int64_data"]), math.ceil(num_elems / 40)
        )

    def testSaveWithDBOptions(self) -> None:
        num_elems = 1234
        chunk_size = 32
        expected_num_chunks = math.ceil(num_elems / chunk_size)

        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")

        blobs = self.create_test_blobs(num_elems)

        db_options = b"test_db_options"
        # Saves the blobs to a local db.
        save_op = core.CreateOperator(
            "Save",
            [name for name, data in blobs],
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            chunk_size=chunk_size,
            db_options=db_options,
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        self.load_and_check_blobs(blobs, [tmp_file])

        blob_chunks = self._read_chunk_info(Path(tmp_file))
        for blob_name, chunks in blob_chunks.items():
            self.assertEqual(len(chunks), expected_num_chunks)

    def testSaveFloatToBfloat16(self) -> None:
        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")

        # Create 2 blobs with the same float data
        float_data = np.random.random_sample(4000).astype(np.float32)
        workspace.FeedBlob("float1", float_data)
        workspace.FeedBlob("float2", float_data)
        blob_names = ["float1", "float2"]

        # Serialize the data, using bfloat16 serialization for one of the blobs
        save_op = core.CreateOperator(
            "Save",
            blob_names,
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            options=caffe2_pb2.SerializationOptions(
                options=[
                    BlobSerializationOptions(
                        blob_name_regex="float1",
                        float_format=BlobSerializationOptions.FLOAT_BFLOAT16,
                    ),
                ],
            ),
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        # As long as fbgemm is available for us to perform bfloat16 conversion,
        # the serialized data for float1 should be almost half the size of float2
        if workspace.has_fbgemm:
            blob_chunks = self._read_chunk_info(Path(tmp_file))
            self.assertEqual(len(blob_chunks["float1"]), 1, blob_chunks["float1"])
            self.assertEqual(len(blob_chunks["float2"]), 1, blob_chunks["float2"])
            self.assertLess(
                blob_chunks["float1"][0].value_size,
                0.6 * blob_chunks["float2"][0].value_size
            )

        self.load_blobs(blob_names, [tmp_file])

        # float2 should be exactly the same as the input data
        np.testing.assert_array_equal(workspace.FetchBlob("float2"), float_data)
        # float1 should be close-ish to the input data
        np.testing.assert_array_almost_equal(
            workspace.FetchBlob("float1"), float_data, decimal=2
        )

    def testEstimateBlobSizes(self) -> None:
        # Create some blobs to test with
        float_data = np.random.random_sample(4000).astype(np.float32)
        workspace.FeedBlob("float1", float_data)
        workspace.FeedBlob("float2", float_data)
        workspace.FeedBlob(
            "float3", np.random.random_sample(2).astype(np.float32)
        )
        workspace.FeedBlob(
            "ui16", np.random.randint(0, 0xffff, size=1024, dtype=np.uint16)
        )

        # Estimate the serialized size of the data.
        # Request bfloat16 serialization for one of the float blobs, just to
        # exercise size estimation when using this option.
        options = caffe2_pb2.SerializationOptions(
            options=[
                BlobSerializationOptions(
                    blob_name_regex="float1",
                    float_format=BlobSerializationOptions.FLOAT_BFLOAT16,
                    chunk_size=500,
                ),
            ],
        )
        get_blobs_op = core.CreateOperator(
            "EstimateAllBlobSizes",
            [],
            ["blob_names", "blob_sizes"],
            options=options,
        )
        self.assertTrue(workspace.RunOperatorOnce(get_blobs_op))
        blob_names = workspace.FetchBlob("blob_names")
        blob_sizes = workspace.FetchBlob("blob_sizes")

        sizes_by_name: Dict[str, int] = {}
        for idx, name in enumerate(blob_names):
            sizes_by_name[name.decode("utf-8")] = blob_sizes[idx]

        # Note that the output blob list will include our output blob names.
        expected_blobs = [
            "float1", "float2", "float3", "ui16",
            "blob_names", "blob_sizes"
        ]
        self.assertEqual(set(sizes_by_name.keys()), set(expected_blobs))

        def check_expected_blob_size(
            name: str, num_elems: int, elem_size: int, num_chunks: int = 1
        ) -> None:
            # The estimation code applies a fixed 50 byte per-chunk overhead to
            # account for the extra space required for other fixed TensorProto
            # message fields.
            per_chunk_overhead = 50
            expected_size = (
                (num_chunks * (len(name) + per_chunk_overhead))
                + (num_elems * elem_size)
            )
            self.assertEqual(
                sizes_by_name[name],
                expected_size,
                f"expected size mismatch for {name}"
            )

        check_expected_blob_size("ui16", 1024, 3)
        check_expected_blob_size("float2", 4000, 4)
        check_expected_blob_size("float3", 2, 4)

        # Our serialization options request to split float1 into 500-element
        # chunks when saving it.  If fbgemm is available then the float1 blob
        # will be serialized using 2 bytes per element instead of 4 bytes.
        float1_num_chunks = 4000 // 500
        if workspace.has_fbgemm:
            check_expected_blob_size("float1", 4000, 2, float1_num_chunks)
        else:
            check_expected_blob_size("float1", 4000, 4, float1_num_chunks)

        check_expected_blob_size("blob_names", len(expected_blobs), 50)
        check_expected_blob_size("blob_sizes", len(expected_blobs), 8)

        # Now actually save the blobs so we can compare our estimates
        # to how big the serialized data actually is.
        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")
        save_op = core.CreateOperator(
            "Save",
            list(sizes_by_name.keys()),
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            options=options,
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        blob_chunks = self._read_chunk_info(Path(tmp_file))
        saved_sizes: Dict[str, int] = {}
        for blob_name, chunks in blob_chunks.items():
            total_size = sum(chunk.value_size for chunk in chunks)
            saved_sizes[blob_name] = total_size

        # For sanity checking, ensure that our estimates aren't
        # extremely far off
        for name in expected_blobs:
            estimated_size = sizes_by_name[name]
            saved_size = saved_sizes[name]
            difference = abs(estimated_size - saved_size)
            error_pct = 100.0 * (difference / saved_size)
            print(
                f"{name}: estimated={estimated_size} actual={saved_size} "
                f"error={error_pct:.2f}%"
            )
            # Don't check the blob_names blob.  It is a string tensor, and we
            # can't estimate string tensor sizes very well without knowing the
            # individual string lengths.  (Currently it requires 102 bytes to
            # save, but we estimate 360).
            if name == "blob_names":
                continue
            # Check that we are within 100 bytes, or within 25%
            # We are generally quite close for tensors with fixed-width fields
            # (like float), but a little farther off for tensors that use varint
            # encoding.
            if difference > 100:
                self.assertLess(error_pct, 25.0)


if __name__ == '__main__':
    unittest.main()