pytorch

Форк
0
/
_deploy.py 
105 строк · 3.4 Кб
1
import io
2

3
import torch
4
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
5
from torch.package._package_pickler import create_pickler
6
from torch.package._package_unpickler import PackageUnpickler
7
from torch.serialization import _maybe_decode_ascii
8

9

10
def _save_storages(importer, obj):
11
    serialized_storages = []
12
    serialized_dtypes = []
13

14
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
15
    importers: Importer
16
    if importer is not None:
17
        importers = OrderedImporter(importer, sys_importer)
18
    else:
19
        importers = sys_importer
20

21
    def persistent_id(obj):
22
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
23
            if isinstance(obj, torch.storage.TypedStorage):
24
                # TODO: Once we decide to break serialization FC, we can
25
                # remove this case
26
                storage = obj._untyped_storage
27
                dtype = obj.dtype
28
            else:
29
                storage = obj
30
                dtype = torch.uint8
31

32
            serialized_storages.append(obj)
33
            serialized_dtypes.append(dtype)
34
            return ("storage", len(serialized_storages) - 1)
35

36
        if hasattr(obj, "__reduce_deploy__"):
37
            if _serialized_reduces.get(id(obj)) is None:
38
                _serialized_reduces[id(obj)] = (
39
                    "reduce_deploy",
40
                    id(obj),
41
                    *obj.__reduce_deploy__(importers),
42
                )
43
            return _serialized_reduces[id(obj)]
44

45
        return None
46

47
    # Write the pickle data for `obj`
48
    data_buf = io.BytesIO()
49
    pickler = create_pickler(data_buf, importers)
50
    pickler.persistent_id = persistent_id
51
    pickler.dump(obj)
52
    data_value = data_buf.getvalue()
53
    return (
54
        data_value,
55
        serialized_storages,
56
        serialized_dtypes,
57
        importer.zip_reader if importer else None,
58
    )
59

60

61
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
62
    def persistent_load(saved_id):
63
        assert isinstance(saved_id, tuple)
64
        typename = _maybe_decode_ascii(saved_id[0])
65
        data = saved_id[1:]
66

67
        if typename == "storage":
68
            # TODO: Once we decide to break serialization FC, we can
69
            # stop wrapping with TypedStorage
70
            storage = serialized_storages[data[0]]
71
            dtype = serialized_dtypes[data[0]]
72
            return torch.storage.TypedStorage(
73
                wrap_storage=storage.untyped(), dtype=dtype
74
            )
75

76
        if typename == "reduce_deploy":
77
            reduce_id, func, args = data
78
            if reduce_id not in _loaded_reduces:
79
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
80
            return _loaded_reduces[reduce_id]
81

82
        return None
83

84
    importer: Importer
85
    if zip_reader is not None:
86
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
87
    else:
88
        importer = sys_importer
89

90
    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
91
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
92
    result = _deploy_objects[id] = unpickler.load()
93
    return result
94

95

96
def _get_package(zip_reader):
97
    if zip_reader not in _raw_packages:
98
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
99
    return _raw_packages[zip_reader]
100

101

102
_raw_packages: dict = {}
103
_deploy_objects: dict = {}
104
_serialized_reduces: dict = {}
105
_loaded_reduces: dict = {}
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.