5
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
6
from torch.package._package_pickler import create_pickler
7
from torch.package._package_unpickler import PackageUnpickler
8
from torch.serialization import _maybe_decode_ascii
11
def _save_storages(importer, obj):
# Pickle `obj` into an in-memory buffer while keeping storages out of the
# pickle stream. Each storage is appended to `serialized_storages`, with its
# dtype in `serialized_dtypes`, and is pickled as a ("storage", index)
# persistent id. Objects defining `__reduce_deploy__` are pickled as a tuple
# built from that method; the tuple is cached per id(obj) in the
# module-level `_serialized_reduces`.
# NOTE(review): this text has lost its indentation. The bare numbers
# interleaved with the code are original line numbers, and they skip in
# several places (18 -> 20, 24 -> 31, 37 -> 40, 42 -> 47, 49 -> 51, 51 -> 56),
# so statements are missing. Restore from the upstream source before
# changing code.
12
# Storage objects and their dtypes, in the order persistent ids are issued.
serialized_storages = []
13
serialized_dtypes = []
15
# Keep `importer` only if it is a PackageImporter; anything else becomes None.
importer = importer if isinstance(importer, torch.package.PackageImporter) else None
17
# Resolve names via the package importer first, falling back to sys_importer.
if importer is not None:
18
importers = OrderedImporter(importer, sys_importer)
20
# NOTE(review): original line 19 is missing; presumably an `else:` guards the
# next line. As shown, this assignment would overwrite the one above.
importers = sys_importer
22
# Pickler hook (assuming create_pickler returns a standard pickle Pickler):
# a non-None return value is written to the pickle in place of `obj`.
def persistent_id(obj):
23
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
24
if isinstance(obj, torch.storage.TypedStorage):
31
# NOTE(review): original lines 25-30 are missing; `dtype` (used below) is not
# assigned in the visible text. Confirm how it is derived for both branches.
serialized_storages.append(obj)
32
serialized_dtypes.append(dtype)
33
# Refer to the storage by its index in `serialized_storages`.
return ("storage", len(serialized_storages) - 1)
35
if hasattr(obj, "__reduce_deploy__"):
36
# Build the reduce tuple once per object id and reuse it afterwards.
# NOTE(review): ids are unique only while an object is alive, and this cache
# does not visibly keep `obj` alive. Confirm a stale entry cannot be returned
# for a later object that reuses the same id.
if _serialized_reduces.get(id(obj)) is None:
37
_serialized_reduces[id(obj)] = (
40
# NOTE(review): original lines 38-39 (leading tuple items) and 41 (closing
# paren) are missing. Presumably they add the "reduce_deploy" tag and the
# reduce id that _load_storages unpacks; confirm.
*obj.__reduce_deploy__(importers),
42
return _serialized_reduces[id(obj)]
47
# In-memory pickler whose persistent ids come from persistent_id above.
data_buf = io.BytesIO()
48
pickler = create_pickler(data_buf, importers)
49
pickler.persistent_id = persistent_id
51
# NOTE(review): original line 50 is missing. No call that writes `obj` into
# `data_buf` (e.g. `pickler.dump(obj)`) is visible before this read.
data_value = data_buf.getvalue()
56
# Trailing item of a return tuple whose opening lines (52-55) are missing:
# the package's zip_reader, or None when no PackageImporter was supplied.
importer.zip_reader if importer else None,
60
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
# Unpickle `obj_bytes` and store the result in the module-level
# `_deploy_objects[id]`. Persistent ids in the pickle are resolved by
# `persistent_load`:
# - ("storage", i) records become TypedStorage objects built from
#   `serialized_storages[i]` and `serialized_dtypes[i]`;
# - "reduce_deploy" records call the recorded function, with results cached
#   in `_loaded_reduces`.
# `id` is used only as the `_deploy_objects` key; it shadows the builtin id()
# inside this function. When `zip_reader` is not None, names resolve through
# the PackageImporter returned by _get_package before sys_importer.
# NOTE(review): indentation is lost, and the interleaved original line
# numbers skip (e.g. 63 -> 66, 79 -> 84, 85 -> 87, 91 -> 95), so statements
# are missing. Restore from the upstream source before changing code.
61
# Unpickler hook: maps each persistent id back to an object.
def persistent_load(saved_id):
62
# NOTE(review): `assert` is removed under `python -O`; raise explicitly if
# malformed ids must always be rejected.
assert isinstance(saved_id, tuple)
63
# The first element tags the record; _maybe_decode_ascii presumably
# normalizes a bytes tag to str.
typename = _maybe_decode_ascii(saved_id[0])
66
# NOTE(review): original lines 64-65 are missing, so `data` (used below) is
# not assigned in the visible text. Presumably it is the payload after the
# tag, e.g. `saved_id[1:]`; confirm.
if typename == "storage":
69
# Rewrap the saved storage with its recorded dtype.
storage = serialized_storages[data[0]]
70
dtype = serialized_dtypes[data[0]]
71
return torch.storage.TypedStorage(
72
wrap_storage=storage.untyped(), dtype=dtype
75
if typename == "reduce_deploy":
76
reduce_id, func, args = data
77
# Call the reduce function at most once per reduce_id; later records reuse
# the cached object.
if reduce_id not in _loaded_reduces:
78
# NOTE(review): this indexes `_raw_packages` directly. In the visible code,
# entries are added only by _get_package, which is called below only when
# `zip_reader` is not None. With `zip_reader=None` this lookup would raise
# KeyError unless the entry was added elsewhere; confirm.
_loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
79
return _loaded_reduces[reduce_id]
84
# Choose the importer used to resolve names while unpickling.
if zip_reader is not None:
85
importer = OrderedImporter(_get_package(zip_reader), sys_importer)
87
# NOTE(review): original line 86 is missing; presumably an `else:` guards the
# next line. As shown, it would overwrite the assignment above.
importer = sys_importer
89
unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
90
unpickler.persistent_load = persistent_load
91
# Unpickle and cache the result under `id`.
# NOTE(review): no `return` is visible after this line (original lines 92-94
# are missing); confirm whether `result` is returned.
result = _deploy_objects[id] = unpickler.load()
95
def _get_package(zip_reader):
    """Return the ``PackageImporter`` cached for *zip_reader*.

    The importer is constructed on the first request for a given
    ``zip_reader`` and stored in the module-level ``_raw_packages`` dict.
    Later calls with the same key return that same object.
    """
    if zip_reader in _raw_packages:
        return _raw_packages[zip_reader]
    package = PackageImporter(zip_reader)
    _raw_packages[zip_reader] = package
    return package
101
# Module-level caches shared by the save/load helpers.
# zip_reader -> PackageImporter, filled lazily by _get_package.
_raw_packages: dict = dict()
# `id` argument of _load_storages -> the object it unpickled.
_deploy_objects: dict = dict()
# id(obj) -> persistent-id tuple built from obj.__reduce_deploy__.
_serialized_reduces: dict = dict()
# reduce_id -> object produced by the recorded reduce function on load.
_loaded_reduces: dict = dict()