11
from contextlib import closing, contextmanager
13
from ._utils import _import_dotted_name
14
from torch._sources import get_source_lines_and_file
15
from torch.types import Storage
16
from torch.storage import _get_dtype_from_pickle_storage_type
17
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
18
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
21
import torch._weights_only_unpickler as _weights_only_unpickler
25
LONG_SIZE = struct.Struct('=l').size
26
INT_SIZE = struct.Struct('=i').size
27
SHORT_SIZE = struct.Struct('=h').size
29
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
30
PROTOCOL_VERSION = 1001
31
STORAGE_KEY_SEPARATOR = ','
33
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
34
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
35
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
38
'SourceChangeWarning',
41
'check_module_version_greater_or_equal',
42
'validate_cuda_device',
43
'validate_hpu_device',
45
'default_restore_location',
46
'normalize_storage_type',
47
'storage_to_tensor_type',
52
'get_default_load_endianness',
53
'set_default_load_endianness',
57
class SourceChangeWarning(Warning):
63
path = tempfile.mkdtemp()
70
_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
72
class LoadEndianness(Enum):
77
_default_load_endian: Optional[LoadEndianness] = None
79
def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
            The current module-level setting. It is ``None`` until
            :func:`set_default_load_endianness` stores something else.
    '''
    return _default_load_endian
92
def set_default_load_endianness(endianness):
    '''
    Set fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order; either a ``LoadEndianness``
            member or ``None``

    Raises:
        TypeError: if ``endianness`` is neither a ``LoadEndianness`` nor ``None``
    '''
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness
108
def _is_zipfile(f) -> bool:
109
# This is a stricter implementation than zipfile.is_zipfile().
110
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
111
# binary. Since we expect the files here to be generated by torch.save or
112
# torch.jit.save, it's safe to only check the start bytes and avoid
113
# collisions and assume the zip has only 1 file.
114
# See bugs.python.org/issue28494.
117
# Read the first few bytes and match against the ZIP file signature
118
local_header_magic_number = b'PK\x03\x04'
119
read_bytes = f.read(len(local_header_magic_number))
121
return read_bytes == local_header_magic_number
126
tagger: Callable[[STORAGE], Optional[str]],
127
deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
130
Registers callables for tagging and deserializing storage objects with an associated priority.
131
Tagging associates a device with a storage object at save time while deserializing moves a
132
storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
133
are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
134
value that is not `None`.
136
To override the deserialization behavior for a device in the global registry, one can register a
137
tagger with a higher priority than the existing tagger.
139
This function can also be used to register a tagger and deserializer for new devices.
142
priority: Indicates the priority associated with the tagger and deserializer, where a lower
143
value indicates higher priority.
144
tagger: Callable that takes in a storage object and returns its tagged device as a string
146
deserializer: Callable that takes in storage object and a device string and returns a storage
147
object on the appropriate device or None.
153
>>> def ipu_tag(obj):
154
>>> if obj.device.type == 'ipu':
156
>>> def ipu_deserialize(obj, location):
157
>>> if location.startswith('ipu'):
158
>>> ipu = getattr(torch, "ipu", None)
159
>>> assert ipu is not None, "IPU device module is not loaded"
160
>>> assert torch.ipu.is_available(), "ipu is not available"
161
>>> return obj.ipu(location)
162
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
164
queue_elem = (priority, tagger, deserializer)
165
_package_registry.append(queue_elem)
166
_package_registry.sort()
169
def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    error and exit or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool

    Raises:
        RuntimeError: if the version string cannot be compared with
            ``req_version_tuple`` and ``error_if_malformed`` is true
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            # Best effort: warn and assume the requirement holds.
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met
209
if obj.device.type == 'cpu':
214
if obj.device.type == 'cuda':
215
return 'cuda:' + str(obj.device.index)
218
if obj.device.type == 'hpu':
219
return 'hpu:' + str(obj.device.index)
222
if obj.device.type == 'mps':
227
if obj.device.type == 'meta':
231
def _privateuse1_tag(obj):
232
backend_name = torch._C._get_privateuse1_backend_name()
233
if obj.device.type == backend_name:
234
if obj.device.index is None:
237
return backend_name + ':' + str(obj.device.index)
240
def _cpu_deserialize(obj, location):
241
if location == 'cpu':
245
def validate_cuda_device(location):
    """Resolve ``location`` to a CUDA device index and check that it is usable.

    Args:
        location: device specifier handed to
            ``torch.cuda._utils._get_device_index`` (e.g. ``'cuda:0'``).

    Returns:
        The device index resolved from ``location``.

    Raises:
        RuntimeError: if CUDA is not available, or the index is not smaller
            than ``torch.cuda.device_count()``.
    """
    device = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = torch.cuda.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           f'{device} but torch.cuda.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device
263
def _cuda_deserialize(obj, location):
264
if location.startswith('cuda'):
265
device = validate_cuda_device(location)
266
if getattr(obj, "_torch_load_uninitialized", False):
267
with torch.cuda.device(device):
268
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
270
return obj.cuda(device)
273
def validate_hpu_device(location):
    """Resolve ``location`` to an HPU device index and check that it is usable.

    Args:
        location: device specifier handed to
            ``torch.hpu._utils._get_device_index`` (e.g. ``'hpu:0'``).

    Returns:
        The device index resolved from ``location``.

    Raises:
        AssertionError: if ``torch`` has no ``hpu`` attribute.
        RuntimeError: if HPU is not available, or the index is not smaller
            than ``torch.hpu.device_count()``.
    """
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    device = hpu._utils._get_device_index(location, optional=True)

    if not hpu.is_available():
        raise RuntimeError('Attempting to deserialize object on a HPU '
                           'device but torch.hpu.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = hpu.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on HPU device '
                           f'{device} but torch.hpu.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device
293
def _hpu_deserialize(obj, location):
294
if location.startswith('hpu'):
295
hpu = getattr(torch, "hpu", None)
296
assert hpu is not None, "HPU device module is not loaded"
297
device = validate_hpu_device(location)
298
if getattr(obj, "_torch_load_uninitialized", False):
299
with hpu.device(device):
300
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
302
return obj.hpu(device)
305
def _mps_deserialize(obj, location):
306
if location.startswith('mps'):
310
def _meta_deserialize(obj, location):
311
if location == 'meta':
312
return torch.UntypedStorage(obj.nbytes(), device='meta')
315
def _validate_privateuse1_device(location, backend_name):
317
Check whether the device index of privateuse1 is valid
319
Register a device_module of privateuse1 by torch._register_device_module.
320
Implement the following methods in device_module like cuda:
321
device_module._utils._get_device_index(location, True),
322
device_module.device_count().
325
location: string of device
326
backend_name: the name of privateuse1, which can be renamed
331
if not hasattr(torch, backend_name):
332
raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
333
'If you are running on a CPU-only machine, '
334
'please use torch.load with map_location=torch.device(\'cpu\') '
335
'to map your storages to the CPU.')
336
device_module = getattr(torch, backend_name)
337
if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
338
device_index = device_module._utils._get_device_index(location, True)
340
device = torch.device(location)
341
device_index = device.index if device.index else 0
342
if hasattr(device_module, 'is_available') and not device_module.is_available():
343
raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
344
f'device but torch.{backend_name}.is_available() is False. '
345
'If you are running on a CPU-only machine, '
346
'please use torch.load with map_location=torch.device(\'cpu\') '
347
'to map your storages to the CPU.')
348
if hasattr(device_module, 'device_count'):
349
device_count = device_module.device_count()
350
if device_index >= device_count:
351
raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
352
f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
353
'Please use torch.load with map_location to map your storages '
354
'to an existing device.')
358
def _privateuse1_deserialize(obj, location):
359
backend_name = torch._C._get_privateuse1_backend_name()
360
if location.startswith(backend_name):
361
if not hasattr(obj, backend_name):
362
raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device '
363
f'but torch.storage._StorageBase.{backend_name}() or '
364
f'torch.storage.TypedStorage.{backend_name}() is not generated. '
365
'Please use torch.utils.generate_methods_for_privateuse1_backend '
366
f'to generate storage.{backend_name}() method first.')
367
device_index = _validate_privateuse1_device(location, backend_name)
368
return getattr(obj, backend_name)(device_index)
371
register_package(10, _cpu_tag, _cpu_deserialize)
372
register_package(20, _cuda_tag, _cuda_deserialize)
373
register_package(21, _mps_tag, _mps_deserialize)
374
register_package(22, _meta_tag, _meta_deserialize)
375
register_package(23, _privateuse1_tag, _privateuse1_deserialize)
376
register_package(24, _hpu_tag, _hpu_deserialize)
379
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
380
for _, tagger, _ in _package_registry:
381
location = tagger(storage)
384
raise RuntimeError("don't know how to determine data location of "
385
+ torch.typename(storage))
388
def default_restore_location(storage, location):
389
for _, _, fn in _package_registry:
390
result = fn(storage, location)
391
if result is not None:
393
raise RuntimeError("don't know how to restore data location of "
394
+ torch.typename(storage) + " (tagged with "
398
def normalize_storage_type(storage_type):
    """Return the attribute of ``torch`` whose name matches ``storage_type``.

    Args:
        storage_type: a class; only its ``__name__`` is used.

    Returns:
        ``getattr(torch, storage_type.__name__)``.

    Raises:
        AttributeError: if ``torch`` has no attribute with that name.
    """
    return getattr(torch, storage_type.__name__)
402
def storage_to_tensor_type(storage):
    """Look up the tensor class that is named like ``storage``'s class.

    The storage class name has ``'Storage'`` replaced by ``'Tensor'``, and the
    result is looked up in the module that defines the storage class.

    Args:
        storage: a storage instance.

    Returns:
        The attribute with the derived name on the storage class's module.
    """
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
408
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
409
return isinstance(name_or_buffer, (str, os.PathLike))
413
def __init__(self, file_like):
414
self.file_like = file_like
417
return self.file_like
419
def __exit__(self, *args):
423
class _open_file(_opener):
    """Opener that opens ``name`` with ``mode`` and closes the file on exit."""

    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()
431
class _open_buffer_reader(_opener):
    """Opener for a caller-supplied read buffer.

    The buffer is passed through ``_check_seekable`` at construction time.
    """

    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)
437
class _open_buffer_writer(_opener):
    """Opener for a caller-supplied write buffer; it is flushed on exit."""

    def __exit__(self, *args):
        self.file_like.flush()
442
def _open_file_like(name_or_buffer, mode):
    """Return the opener (context manager) for a path or an existing buffer.

    Args:
        name_or_buffer: a ``str``/``os.PathLike`` path, or a file-like object.
        mode: file mode string; for buffers it must contain ``'w'`` or ``'r'``.

    Returns:
        ``_open_file`` for paths; otherwise ``_open_buffer_writer`` when
        ``'w'`` is in ``mode``, or ``_open_buffer_reader`` when ``'r'`` is.

    Raises:
        RuntimeError: if a buffer is given and ``mode`` has neither ``'r'``
            nor ``'w'``.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return _open_buffer_writer(name_or_buffer)
        elif 'r' in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
454
class _open_zipfile_reader(_opener):
    """Opener that wraps ``name_or_buffer`` in a ``torch._C.PyTorchFileReader``."""

    def __init__(self, name_or_buffer) -> None:
        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
459
class _open_zipfile_writer_file(_opener):
    """Opener that writes a zip archive to a file path.

    For an ASCII-only path, ``torch._C.PyTorchFileWriter`` receives the path
    itself. Otherwise a Python ``io.FileIO`` stream is opened and handed to
    the writer, and this opener closes that stream on exit.
    """

    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        try:
            self.file_like.write_end_of_file()
        finally:
            # Close the Python-side stream even if finalizing the archive
            # raised, so the file handle is not leaked.
            if self.file_stream is not None:
                self.file_stream.close()
480
class _open_zipfile_writer_buffer(_opener):
481
def __init__(self, buffer) -> None:
482
if not callable(getattr(buffer, "write", None)):
483
msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
484
if not hasattr(buffer, "write"):
485
raise AttributeError(msg)
488
super().__init__(torch._C.PyTorchFileWriter(buffer))
490
def __exit__(self, *args) -> None:
491
self.file_like.write_end_of_file()
495
def _open_zipfile_writer(name_or_buffer):
    """Return the zip-writer opener that matches the kind of target.

    Args:
        name_or_buffer: a ``str``/``os.PathLike`` path, or a writable buffer.

    Returns:
        An ``_open_zipfile_writer_file`` for paths, otherwise an
        ``_open_zipfile_writer_buffer``.
    """
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)
504
def _is_compressed_file(f) -> bool:
505
compress_modules = ['gzip']
507
return f.__module__ in compress_modules
508
except AttributeError:
512
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip)

    Returns:
        bool: False for compressed files and for objects whose ``fileno()``
        is missing or unsupported; otherwise whether ``f.fileno() >= 0``.
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        # File-like object that exposes fileno() but is not backed by a descriptor.
        return False
    except AttributeError:
        # Object has no fileno() at all.
        return False
528
def _check_seekable(f) -> bool:
530
def raise_err_msg(patterns, e):
533
msg = (str(e) + ". You can only torch.load from a file that is seekable."
534
+ " Please pre-load the data into a buffer like io.BytesIO and"
535
+ " try to load from it instead.")
542
except (io.UnsupportedOperation, AttributeError) as e:
543
raise_err_msg(["seek", "tell"], e)
547
def _check_dill_version(pickle_module) -> None:
548
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
549
If dill version is lower than 0.3.1, a ValueError is raised.
552
pickle_module: module used for pickling metadata and objects
555
if pickle_module is not None and pickle_module.__name__ == 'dill':
556
required_dill_version = (0, 3, 1)
557
if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
559
"'torch' supports dill >= {}, but you have dill {}."
560
" Please upgrade dill or switch to 'pickle'"
562
'.'.join([str(num) for num in required_dill_version]),
563
pickle_module.__version__
567
def _check_save_filelike(f):
    """Validate that ``f`` is an acceptable save target.

    Args:
        f: a string or path, or an object with a ``write`` attribute.

    Raises:
        AttributeError: if ``f`` is neither a path nor has a ``write`` attribute.
    """
    if not _is_path(f) and not hasattr(f, 'write'):
        raise AttributeError(
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute")
577
pickle_module: Any = pickle,
578
pickle_protocol: int = DEFAULT_PROTOCOL,
579
_use_new_zipfile_serialization: bool = True,
580
_disable_byteorder_record: bool = False
582
# Reference: https://github.com/pytorch/pytorch/issues/54354
583
# The first line of this docstring overrides the one Sphinx generates for the
584
# documentation. We need it so that Sphinx doesn't leak `pickle`s path from
585
# the build environment (e.g. `<module 'pickle' from '/leaked/path').
587
"""save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
589
Saves an object to a disk file.
591
See also: :ref:`saving-loading-tensors`
595
f: a file-like object (has to implement write and flush) or a string or
596
os.PathLike object containing a file name
597
pickle_module: module used for pickling metadata and objects
598
pickle_protocol: can be specified to override the default protocol
601
A common PyTorch convention is to save tensors using .pt file extension.
604
PyTorch preserves storage sharing across serialization. See
605
:ref:`preserve-storage-sharing` for more details.
608
The 1.6 release of PyTorch switched ``torch.save`` to use a new
609
zipfile-based file format. ``torch.load`` still retains the ability to
610
load files in the old format. If for any reason you want ``torch.save``
611
to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
614
>>> # xdoctest: +SKIP("makes cwd dirty")
616
>>> x = torch.tensor([0, 1, 2, 3, 4])
617
>>> torch.save(x, 'tensor.pt')
618
>>> # Save to io.BytesIO buffer
619
>>> buffer = io.BytesIO()
620
>>> torch.save(x, buffer)
622
torch._C._log_api_usage_once("torch.save")
623
_check_dill_version(pickle_module)
624
_check_save_filelike(f)
626
if _use_new_zipfile_serialization:
627
with _open_zipfile_writer(f) as opened_zipfile:
628
_save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
631
with _open_file_like(f, 'wb') as opened_file:
632
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
635
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
636
import torch.nn as nn
637
serialized_container_types = {}
638
serialized_storages = {}
640
# Since loading storages that view the same data with different dtypes is
641
# not supported, we need to keep track of the dtype associated with each
642
# storage data_ptr and throw an error if the dtype is ever different.
643
# TODO: This feature could be added in the future
644
storage_dtypes: Dict[int, torch.dtype] = {}
646
def persistent_id(obj: Any) -> Optional[Tuple]:
647
# FIXME: the docs say that persistent_id should only return a string
648
# but torch store returns tuples. This works only in the binary protocol
650
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
651
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
652
if isinstance(obj, type) and issubclass(obj, nn.Module):
653
if obj in serialized_container_types:
655
serialized_container_types[obj] = True
656
source_file = source = None
658
source_lines, _, source_file = get_source_lines_and_file(obj)
659
source = ''.join(source_lines)
660
except Exception: # saving the source is optional, so we can ignore any errors
661
warnings.warn("Couldn't retrieve source code for container of "
662
"type " + obj.__name__ + ". It won't be checked "
663
"for correctness upon loading.")
664
return ('module', obj, source_file, source)
666
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
667
storage: torch.UntypedStorage
669
if isinstance(obj, torch.storage.TypedStorage):
670
# TODO: Once we decide to break serialization FC, this case
672
storage = obj._untyped_storage
673
storage_dtype = obj.dtype
674
storage_type_str = obj._pickle_storage_type()
675
storage_type = getattr(torch, storage_type_str)
677
storage_numel = obj._size()
679
elif isinstance(obj, torch.UntypedStorage):
681
storage_dtype = torch.uint8
682
storage_type = normalize_storage_type(type(obj))
684
storage_numel = storage.nbytes()
686
raise TypeError(f'type not recognized: {type(obj)}')
688
# If storage is allocated, ensure that any other saved storages
689
# pointing to the same data all have the same dtype. If storage is
690
# not allocated, don't perform this check
691
if storage.data_ptr() != 0:
692
if storage.data_ptr() in storage_dtypes:
693
if storage_dtype != storage_dtypes[storage.data_ptr()]:
695
'Cannot save multiple tensors or storages that '
696
'view the same data as different types')
698
storage_dtypes[storage.data_ptr()] = storage_dtype
700
view_metadata: Optional[Tuple[str, int, int]]
702
# Offset is always 0, but we keep it for backwards compatibility
703
# with the old serialization format (which supported storage views)
705
storage_key = str(storage._cdata)
706
location = location_tag(storage)
708
# TODO: There's an issue here with FC. It might be impossible to
709
# solve, but it's worth noting. Imagine we save a list `[storage,
710
# tensor]`, where `tensor.storage()` is the same as `storage`, and
711
# `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
712
# torch.float`. The storage will be serialized with element size
713
# of 1, since we're choosing to serialize the first occurrence of
714
# a duplicate storage. Since this legacy serialization format saves
715
# the numel of the storage, rather than nbytes directly, we'll be
716
# effectively saving nbytes in this case. We'll be able to load it
717
# and the tensor back up with no problems in _this_ and future
718
# versions of pytorch, but in older versions, here's the problem:
719
# the storage will be loaded up as a UntypedStorage, and then the
720
# FloatTensor will be loaded and the UntypedStorage will be assigned to
721
# it. Since the storage dtype does not match the tensor dtype, this
722
# will cause an error. If we reverse the list, like `[tensor,
723
# storage]`, then we will save the `tensor.storage()` as a faked
724
# `FloatStorage`, and the saved size will be the correct
725
# dtype-specific numel count that old versions expect. `tensor`
726
# will be able to load up properly in old versions, pointing to
727
# a FloatStorage. However, `storage` is still being translated to
728
# a UntypedStorage, and it will try to resolve to the same
729
# FloatStorage that `tensor` contains. This will also cause an
730
# error. It doesn't seem like there's any way around this.
731
# Probably, we just cannot maintain FC for the legacy format if the
732
# saved list contains both a tensor and a storage that point to the
733
# same data. We should still be able to maintain FC for lists of
734
# just tensors, as long as all views share the same dtype as the
735
# tensor they are viewing.
737
if storage_key not in serialized_storages:
738
serialized_storages[storage_key] = (storage, dtype)
739
is_view = storage._cdata != storage._cdata
741
view_metadata = (str(storage._cdata), offset, storage.nbytes())
755
protocol_version=PROTOCOL_VERSION,
756
little_endian=sys.byteorder == 'little',
764
pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
765
pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
766
pickle_module.dump(sys_info, f, protocol=pickle_protocol)
767
pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
768
pickler.persistent_id = persistent_id
771
serialized_storage_keys = sorted(serialized_storages.keys())
772
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
774
for key in serialized_storage_keys:
775
storage, dtype = serialized_storages[key]
776
storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
779
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
780
serialized_storages = {}
781
id_map: Dict[int, str] = {}
783
# Since loading storages that view the same data with different dtypes is
784
# not supported, we need to keep track of the dtype associated with each
785
# storage data_ptr and throw an error if the dtype is ever different.
786
# TODO: This feature could be added in the future
787
storage_dtypes: Dict[int, torch.dtype] = {}
789
def persistent_id(obj):
790
# FIXME: the docs say that persistent_id should only return a string
791
# but torch store returns tuples. This works only in the binary protocol
793
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
794
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
795
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
797
if isinstance(obj, torch.storage.TypedStorage):
798
# TODO: Once we decide to break serialization FC, this case
800
storage = obj._untyped_storage
801
storage_dtype = obj.dtype
802
storage_type_str = obj._pickle_storage_type()
803
storage_type = getattr(torch, storage_type_str)
804
storage_numel = obj._size()
808
storage_dtype = torch.uint8
809
storage_type = normalize_storage_type(type(obj))
810
storage_numel = storage.nbytes()
812
# If storage is allocated, ensure that any other saved storages
813
# pointing to the same data all have the same dtype. If storage is
814
# not allocated, don't perform this check
815
if storage.data_ptr() != 0:
816
if storage.data_ptr() in storage_dtypes:
817
if storage_dtype != storage_dtypes[storage.data_ptr()]:
819
'Cannot save multiple tensors or storages that '
820
'view the same data as different types')
822
storage_dtypes[storage.data_ptr()] = storage_dtype
824
storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
825
location = location_tag(storage)
826
serialized_storages[storage_key] = storage
836
# Write the pickle data for `obj`
837
data_buf = io.BytesIO()
838
pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
839
pickler.persistent_id = persistent_id
841
data_value = data_buf.getvalue()
842
zip_file.write_record('data.pkl', data_value, len(data_value))
844
# Write byte order marker
845
if not _disable_byteorder_record:
846
if sys.byteorder not in ['little', 'big']:
847
raise ValueError('Unknown endianness type: ' + sys.byteorder)
849
zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
851
# Write each tensor to a file named tensor/the_tensor_key in the zip archive
852
for key in sorted(serialized_storages.keys()):
854
storage = serialized_storages[key]
855
# given that we copy things around anyway, we might use storage.cpu()
856
# this means that to get tensors serialized, you need to implement
857
# .cpu() on the underlying Storage
858
if storage.device.type != 'cpu':
859
storage = storage.cpu()
860
# Now that it is on the CPU we can directly copy it into the zip file
861
num_bytes = storage.nbytes()
862
zip_file.write_record(name, storage, num_bytes)
867
map_location: MAP_LOCATION = None,
868
pickle_module: Any = None,
870
weights_only: bool = False,
871
mmap: Optional[bool] = None,
872
**pickle_load_args: Any
874
# Reference: https://github.com/pytorch/pytorch/issues/54354
875
# The first line of this docstring overrides the one Sphinx generates for the
876
# documentation. We need it so that Sphinx doesn't leak `pickle`s path from
877
# the build environment (e.g. `<module 'pickle' from '/leaked/path').
879
"""load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
881
Loads an object saved with :func:`torch.save` from a file.
883
:func:`torch.load` uses Python's unpickling facilities but treats storages,
884
which underlie tensors, specially. They are first deserialized on the
885
CPU and are then moved to the device they were saved from. If this fails
886
(e.g. because the run time system doesn't have certain devices), an exception
887
is raised. However, storages can be dynamically remapped to an alternative
888
set of devices using the :attr:`map_location` argument.
890
If :attr:`map_location` is a callable, it will be called once for each serialized
891
storage with two arguments: storage and location. The storage argument
892
will be the initial deserialization of the storage, residing on the CPU.
893
Each serialized storage has a location tag associated with it which
894
identifies the device it was saved from, and this tag is the second
895
argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
896
for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
897
:attr:`map_location` should return either ``None`` or a storage. If
898
:attr:`map_location` returns a storage, it will be used as the final deserialized
899
object, already moved to the right device. Otherwise, :func:`torch.load` will
900
fall back to the default behavior, as if :attr:`map_location` wasn't specified.
902
If :attr:`map_location` is a :class:`torch.device` object or a string containing
903
a device tag, it indicates the location where all tensors should be loaded.
905
Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
906
appearing in the file (keys), to ones that specify where to put the
909
User extensions can register their own location tags and tagging and
910
deserialization methods using :func:`torch.serialization.register_package`.
913
f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
914
or a string or os.PathLike object containing a file name
915
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
917
pickle_module: module used for unpickling metadata and objects (has to
918
match the :attr:`pickle_module` used to serialize file)
919
weights_only: Indicates whether unpickler should be restricted to
920
loading only tensors, primitive types and dictionaries
921
mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
922
Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
923
are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
924
second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
925
tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
926
pickle_load_args: (Python 3 only) optional keyword arguments passed over to
927
:func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
931
:func:`torch.load()` unless `weights_only` parameter is set to `True`,
932
uses ``pickle`` module implicitly, which is known to be insecure.
933
It is possible to construct malicious pickle data which will execute arbitrary code
934
during unpickling. Never load data that could have come from an untrusted
935
source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
938
When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
939
will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
940
and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
943
By default, we decode byte strings as ``utf-8``. This is to avoid a common error
944
case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
945
when loading files saved by Python 2 in Python 3. If this default
946
is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
947
these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
948
to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
949
as byte arrays which can be decoded later with ``byte_array.decode(...)``.
952
>>> # xdoctest: +SKIP("undefined filepaths")
953
>>> torch.load('tensors.pt', weights_only=True)
954
# Load all tensors onto the CPU
955
>>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
956
# Load all tensors onto the CPU, using a function
957
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
958
# Load all tensors onto GPU 1
959
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
960
# Map tensors from GPU 1 to GPU 0
961
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
962
# Load tensor from io.BytesIO object
963
# Loading from a buffer setting weights_only=False, warning this can be unsafe
964
>>> with open('tensor.pt', 'rb') as f:
965
... buffer = io.BytesIO(f.read())
966
>>> torch.load(buffer, weights_only=False)
967
# Load a module with 'ascii' encoding for unpickling
968
# Loading from a module setting weights_only=False, warning this can be unsafe
969
>>> torch.load('module.pt', encoding='ascii', weights_only=False)
971
torch._C._log_api_usage_once("torch.load")
973
"Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
974
" will likely succeed, but it can result in arbitrary code execution."
975
"Do it only if you get the file from a trusted source. WeightsUnpickler error: "
977
# Add ability to force safe only weight loads via environment variable
978
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
982
if pickle_module is not None:
983
raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
985
if pickle_module is None:
986
pickle_module = pickle
988
# make flipping default BC-compatible
992
_check_dill_version(pickle_module)
994
if 'encoding' not in pickle_load_args.keys():
995
pickle_load_args['encoding'] = 'utf-8'
997
with _open_file_like(f, 'rb') as opened_file:
998
if _is_zipfile(opened_file):
999
# The zipfile reader is going to advance the current file position.
1000
# If we want to actually tail call to torch.jit.load, we need to
1001
# reset back to the original position.
1002
orig_position = opened_file.tell()
1003
overall_storage = None
1004
with _open_zipfile_reader(opened_file) as opened_zipfile:
1005
if _is_torchscript_zip(opened_zipfile):
1006
warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
1007
" dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
1008
" silence this warning)", UserWarning)
1009
opened_file.seek(orig_position)
1010
return torch.jit.load(opened_file, map_location=map_location)
1013
raise ValueError("f must be a file path in order to use the mmap argument")
1014
size = os.path.getsize(f)
1015
overall_storage = torch.UntypedStorage.from_file(os.fspath(f), False, size)
1018
return _load(opened_zipfile,
1020
_weights_only_unpickler,
1021
overall_storage=overall_storage,
1023
except RuntimeError as e:
1024
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
1025
return _load(opened_zipfile,
1028
overall_storage=overall_storage,
1031
raise RuntimeError("mmap can only be used with files saved with "
1032
"`torch.save(_use_new_zipfile_serialization=True), "
1033
"please torch.save your checkpoint with this option in order to use mmap.")
1036
return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
1037
except RuntimeError as e:
1038
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
1039
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
1042
# Register pickling support for layout instances such as
1043
# torch.sparse_coo, etc
1044
def _get_layout(name):
1045
"""Get layout extension object from its string representation.
1047
cache = _get_layout.cache # type: ignore[attr-defined]
1049
for v in torch.__dict__.values():
1050
if isinstance(v, torch.layout):
1054
# There is not yet a good way to type-annotate function attributes; see https://github.com/python/mypy/issues/2087
1055
_get_layout.cache = {} # type: ignore[attr-defined]
1056
# Reduce torch.layout instances to a call of _get_layout with the layout's
# string form, so unpickling resolves the layout object by that string.
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
1059
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
# Load an object from a legacy (non-zipfile) checkpoint stream ``f``.
#
# Two layouts are handled, as far as the visible code shows:
#   * a tar archive with 'storages', 'tensors' and 'pickle' members, attempted
#     only when ``f`` can be read directly and is positioned at offset 0;
#   * a pickle stream that starts with a magic number, a protocol version and
#     a system-info object, followed by the pickled result, the list of
#     storage keys and then the storage data itself.
#
# ``map_location`` is converted to a ``restore_location`` callable through
# ``_get_restore_location``; ``pickle_module`` supplies ``load`` and
# ``Unpickler``; ``pickle_load_args`` are forwarded to every unpickling call.
#
# Raises RuntimeError for a bad magic number, an unexpected protocol version
# or an unknown persistent-id kind.
1060
deserialized_objects: Dict[int, Any] = {}
1062
restore_location = _get_restore_location(map_location)
1064
# Unpickler subclass: class names containing 'Storage' are resolved to
# StorageType(name); other lookups are delegated to pickle_module.Unpickler.
class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
1066
def find_class(self, mod_name, name):
1067
if type(name) is str and 'Storage' in name:
1069
return StorageType(name)
1072
return super().find_class(mod_name, name)
1074
# Compare the source saved with a container against the container type's
# current source and emit a SourceChangeWarning when they differ. When
# ``container_type.dump_patches`` is set, a unified diff (current -> saved)
# is built for the file '<TypeName>.patch'.
def _check_container_source(container_type, source_file, original_source):
1076
current_source = ''.join(get_source_lines_and_file(container_type)[0])
1077
except Exception: # saving the source is optional, so we can ignore any errors
1078
warnings.warn("Couldn't retrieve source code for container of "
1079
"type " + container_type.__name__ + ". It won't be checked "
1080
"for correctness upon loading.")
1082
if original_source != current_source:
1083
if container_type.dump_patches:
1084
file_name = container_type.__name__ + '.patch'
1085
diff = difflib.unified_diff(current_source.split('\n'),
1086
original_source.split('\n'),
1088
source_file, lineterm="")
1089
lines = '\n'.join(diff)
1091
with open(file_name, 'a+') as f:
1092
# whence=2 seeks to the end of the file, so the returned position is the
# file's current size.
file_size = f.seek(0, 2)
1096
elif file_size != len(lines) or f.read() != lines:
1098
msg = ("Saved a reverse patch to " + file_name + ". "
1099
"Run `patch -p0 < " + file_name + "` to revert your "
1102
msg = ("Tried to save a patch, but couldn't create a "
1103
"writable file " + file_name + ". Make sure it "
1104
"doesn't exist and your working directory is "
1107
msg = ("you can retrieve the original source code by "
1108
"accessing the object's source attribute or set "
1109
"`torch.nn.Module.dump_patches = True` and use the "
1110
"patch tool to revert the changes.")
1111
msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
1112
warnings.warn(msg, SourceChangeWarning)
1115
deserialized_objects: Dict[int, Any] = {}
1117
# persistent_load for the tar layout: a tuple id describes a container, whose
# source is checked only when every trailing field is truthy; any other id is
# converted with int() and looked up in deserialized_objects.
def persistent_load(saved_id):
1118
if isinstance(saved_id, tuple):
1119
# Ignore containers that don't have any sources saved
1120
if all(saved_id[1:]):
1121
_check_container_source(*saved_id)
1123
return deserialized_objects[int(saved_id)]
1125
# Tar layout: ``f`` is opened as an uncompressed tar ('r:') and members are
# extracted into a temporary directory before being read.
with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
1126
mkdtemp() as tmpdir:
1128
tar.extract('storages', path=tmpdir)
1129
with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
1130
num_storages = pickle_module.load(f, **pickle_load_args)
1131
# Each storage record is a pickled (key, location, storage_type) tuple
# followed by the storage payload, which _new_with_file reads from ``f``.
for i in range(num_storages):
1132
args = pickle_module.load(f, **pickle_load_args)
1133
key, location, storage_type = args
1134
dtype = storage_type._dtype
1135
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
1136
obj = restore_location(obj, location)
1137
# TODO: Once we decide to break serialization FC, we can
1138
# stop wrapping with TypedStorage
1139
deserialized_objects[key] = torch.storage.TypedStorage(
1144
storage_views = pickle_module.load(f, **pickle_load_args)
1145
# Views are (target, root, offset, numel) tuples; offset and numel are scaled
# by the root's element size to slice the root's untyped storage in bytes.
for target_cdata, root_cdata, offset, numel in storage_views:
1146
root = deserialized_objects[root_cdata]
1147
element_size = torch._utils._element_size(root.dtype)
1148
offset_bytes = offset * element_size
1149
# TODO: Once we decide to break serialization FC, we can
1150
# stop wrapping with TypedStorage
1151
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
1152
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
1156
tar.extract('tensors', path=tmpdir)
1157
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
1158
num_tensors = pickle_module.load(f, **pickle_load_args)
1159
for _ in range(num_tensors):
1160
args = pickle_module.load(f, **pickle_load_args)
1161
key, storage_id, original_tensor_type = args
1162
storage = deserialized_objects[storage_id]
1163
# Tensor metadata is little-endian: an int32 ndim, then ndim int64 values
# (named ``numel`` here and passed to set_ in the size position), ndim int64
# strides and one int64 storage offset.
ndim, = struct.unpack('<i', f.read(4))
1164
# skip next 4 bytes; legacy encoding treated ndim as 8 bytes
1166
numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
1167
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
1168
storage_offset, = struct.unpack('<q', f.read(8))
1169
tensor = torch.empty((0,), dtype=storage.dtype).set_(
1170
storage._untyped_storage, storage_offset, numel, stride)
1171
deserialized_objects[key] = tensor
1173
pickle_file = tar.extractfile('pickle')
1174
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
1175
unpickler.persistent_load = persistent_load
1176
result = unpickler.load()
1179
deserialized_objects = {}
1181
# persistent_load for the pickle-stream layout: ``saved_id`` is a tuple whose
# first element names the record kind ('module' or 'storage'); any other kind
# raises RuntimeError.
def persistent_load(saved_id):
1182
assert isinstance(saved_id, tuple)
1183
typename = _maybe_decode_ascii(saved_id[0])
1186
if typename == 'module':
1187
# Ignore containers that don't have any sources saved
1189
_check_container_source(*data)
1191
elif typename == 'storage':
1192
storage_type, root_key, location, numel, view_metadata = data
1193
location = _maybe_decode_ascii(location)
1194
dtype = storage_type.dtype
1196
nbytes = numel * torch._utils._element_size(dtype)
1198
# First time this root storage is seen: allocate it (on the 'meta' device
# when a fake mode is active) and pass it through restore_location. Its
# contents are read from the file later, after unpickling has finished.
if root_key not in deserialized_objects:
1199
if torch._guards.active_fake_mode() is not None:
1200
obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
1202
obj = cast(Storage, torch.UntypedStorage(nbytes))
1203
obj._torch_load_uninitialized = True
1204
obj = restore_location(obj, location)
1205
# TODO: Once we decide to break serialization FC, we can
1206
# stop wrapping with TypedStorage
1207
typed_storage = torch.storage.TypedStorage(
1211
deserialized_objects[root_key] = typed_storage
1213
typed_storage = deserialized_objects[root_key]
1214
if typed_storage._data_ptr() == 0:
1215
typed_storage = torch.storage.TypedStorage(
1216
device=typed_storage._untyped_storage.device,
1220
# A view is described by (view_key, offset, view_size); offset and size are
# scaled by the element size to slice the root's untyped storage in bytes.
if view_metadata is not None:
1221
view_key, offset, view_size = view_metadata
1222
offset_bytes = offset * torch._utils._element_size(dtype)
1223
view_size_bytes = view_size * torch._utils._element_size(dtype)
1224
if view_key not in deserialized_objects:
1225
# TODO: Once we decide to break serialization FC, we can
1226
# stop wrapping with TypedStorage
1227
deserialized_objects[view_key] = torch.storage.TypedStorage(
1228
wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
1231
res = deserialized_objects[view_key]
1237
raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
1240
f_should_read_directly = _should_read_directly(f)
1242
if f_should_read_directly and f.tell() == 0:
1243
# legacy_load requires that f has fileno()
1244
# only if offset is zero we can attempt the legacy tar file loader
1246
return legacy_load(f)
1247
except tarfile.TarError:
1249
# .zip is used for torch.jit.save and will throw an un-pickling error here
1251
f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
1252
# if not a tarfile, reset file offset and proceed
1255
if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
1257
"torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
1258
f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
1261
# Stream header: magic number, protocol version, then a system-info object
# (loaded into ``_sys_info``).
magic_number = pickle_module.load(f, **pickle_load_args)
1262
if magic_number != MAGIC_NUMBER:
1263
raise RuntimeError("Invalid magic number; corrupt file?")
1264
protocol_version = pickle_module.load(f, **pickle_load_args)
1265
if protocol_version != PROTOCOL_VERSION:
1266
raise RuntimeError(f"Invalid protocol version: {protocol_version}")
1268
_sys_info = pickle_module.load(f, **pickle_load_args)
1269
unpickler = UnpicklerWrapper(f, **pickle_load_args)
1270
unpickler.persistent_load = persistent_load
1271
result = unpickler.load()
1273
# The pickled object is followed by the list of storage keys; the loop below
# reads each key's data from ``f`` in that order.
deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
1275
# Storage contents are read only when no fake mode is active.
if torch._guards.active_fake_mode() is None:
1276
offset = f.tell() if f_should_read_directly else None
1277
for key in deserialized_storage_keys:
1278
assert key in deserialized_objects
1279
typed_storage = deserialized_objects[key]
1280
typed_storage._untyped_storage._set_from_file(
1281
f, offset, f_should_read_directly,
1282
torch._utils._element_size(typed_storage.dtype))
1283
if offset is not None:
1286
torch._utils._validate_loaded_sparse_tensors()
1291
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
1292
# When using encoding='bytes' in Py3, some **internal** keys stored as
1293
# strings in Py2 are loaded as bytes. This function decodes them with
1294
# ascii encoding, one that Py3 uses by default.
1296
# NOTE: This should only be used on internal keys (e.g., `typename` and
1297
# `location` in `persistent_load` below!
1298
if isinstance(bytes_str, bytes):
1299
return bytes_str.decode('ascii')
1303
def _get_restore_location(map_location):
1304
if map_location is None:
1305
restore_location = default_restore_location
1306
elif isinstance(map_location, dict):
1307
def restore_location(storage, location):
1308
location = map_location.get(location, location)
1309
return default_restore_location(storage, location)
1310
elif isinstance(map_location, (str, bytes)):
1311
def restore_location(storage, location):
1312
return default_restore_location(storage, map_location)
1313
elif isinstance(map_location, torch.device):
1314
def restore_location(storage, location):
1315
return default_restore_location(storage, str(map_location))
1317
def restore_location(storage, location):
1318
result = map_location(storage, location)
1320
result = default_restore_location(storage, location)
1322
return restore_location
1326
def __init__(self, name):
# Store the result of _get_dtype_from_pickle_storage_type(name) on the
# instance as ``_dtype``.
# NOTE(review): the helper's name suggests ``name`` is a pickled storage
# class name mapped to a dtype - confirm against torch.storage.
1327
self._dtype = _get_dtype_from_pickle_storage_type(name)
1334
return f'StorageType(dtype={self.dtype})'
1337
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
# Load an object from a zipfile-format checkpoint.
#
# ``zip_file`` is the opened archive reader; the code below calls its
# has_record, get_record, get_record_offset, get_storage_from_record and
# serialization_id methods. The pickle stream is read from the record named
# ``pickle_file`` and storages are fetched on demand through
# ``persistent_load`` / ``load_tensor``.
#
# ``map_location`` is converted to a ``restore_location`` callable through
# ``_get_restore_location``. ``overall_storage``, when given, is sliced at
# each record's offset instead of reading the record through the zip reader.
# ``pickle_load_args`` are forwarded to the unpickler.
#
# Raises ValueError for an unknown 'byteorder' record value or an invalid
# default load endianness.
1338
restore_location = _get_restore_location(map_location)
1340
loaded_storages = {}
1342
# check if byteswapping is needed
1343
byteordername = 'byteorder'
1344
byteorderdata = None
1345
if zip_file.has_record(byteordername):
1346
byteorderdata = zip_file.get_record(byteordername)
1347
if byteorderdata not in [b'little', b'big']:
1348
raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
1349
# No 'byteorder' record: fall back to the configured default; an unset
# default (None) is treated the same as LITTLE.
elif get_default_load_endianness() == LoadEndianness.LITTLE or \
1350
get_default_load_endianness() is None:
1351
byteorderdata = b'little'
1352
elif get_default_load_endianness() == LoadEndianness.BIG:
1353
byteorderdata = b'big'
1354
elif get_default_load_endianness() == LoadEndianness.NATIVE:
1357
raise ValueError('Invalid load endianness type')
1359
if not zip_file.has_record(byteordername) and \
1360
get_default_load_endianness() is None and \
1361
sys.byteorder == 'big':
1362
# Default behaviour was changed
1363
# See https://github.com/pytorch/pytorch/issues/101688
1364
warnings.warn("The default load endianness for checkpoints without a byteorder mark "
1365
"on big endian machines was changed from 'native' to 'little' endian, "
1366
"to avoid this behavior please use "
1367
"torch.serialization.set_default_load_endianness to set "
1368
"the desired default load endianness",
1371
# Produce the storage stored under 'data/<key>' and wrap it in a TypedStorage
# after applying restore_location. The storage comes from one of three places:
# a 'meta' allocation when a fake mode is detected, a slice of
# ``overall_storage`` when one was supplied, or the zip record itself.
# NOTE(review): persistent_load below passes a byte count (``nbytes``) as
# ``numel``, yet the fake-mode branch multiplies by the element size again -
# confirm whether that is intended.
def load_tensor(dtype, numel, key, location):
1372
name = f'data/{key}'
1373
if torch._guards.detect_fake_mode(None) is not None:
1374
nbytes = numel * torch._utils._element_size(dtype)
1375
storage = torch.UntypedStorage(nbytes, device='meta')
1376
elif overall_storage is not None:
1377
storage_offset = zip_file.get_record_offset(name)
1378
storage = overall_storage[storage_offset:storage_offset + numel]
1380
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
1381
# swap here if byteswapping is needed
1382
if byteorderdata is not None:
1383
if byteorderdata.decode() != sys.byteorder:
1384
storage.byteswap(dtype)
1386
# TODO: Once we decide to break serialization FC, we can
1387
# stop wrapping with TypedStorage
1388
typed_storage = torch.storage.TypedStorage(
1389
wrap_storage=restore_location(storage, location),
1393
# Only storages with a non-null data pointer are cached for reuse by key.
if typed_storage._data_ptr() != 0:
1394
loaded_storages[key] = typed_storage
1396
return typed_storage
1398
# Resolve a pickled persistent id: a tuple whose first element must be
# 'storage'; ``data`` unpacks as (storage_type, key, location, numel).
# Storages already present in loaded_storages are reused.
def persistent_load(saved_id):
1399
assert isinstance(saved_id, tuple)
1400
typename = _maybe_decode_ascii(saved_id[0])
1403
assert typename == 'storage', \
1404
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
1405
storage_type, key, location, numel = data
1406
if storage_type is torch.UntypedStorage:
1409
dtype = storage_type.dtype
1411
if key in loaded_storages:
1412
typed_storage = loaded_storages[key]
1414
nbytes = numel * torch._utils._element_size(dtype)
1415
typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
1417
return typed_storage
1419
# Old module path -> current module path, applied in find_class below.
load_module_mapping: Dict[str, str] = {
1420
# See https://github.com/pytorch/pytorch/pull/51633
1421
'torch.tensor': 'torch._tensor'
1424
# Need to subclass Unpickler instead of directly monkey-patching the find_class method
1425
# because it's marked readonly in pickle.
1426
# The type: ignore is because mypy can't statically determine the type of this class.
1427
class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
1428
# from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
1429
# Lets us override the imports that pickle uses when unpickling an object.
1430
# This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
1431
def find_class(self, mod_name, name):
1432
if type(name) is str and 'Storage' in name:
1434
return StorageType(name)
1437
mod_name = load_module_mapping.get(mod_name, mod_name)
1438
return super().find_class(mod_name, name)
1440
# Load the data (which may in turn use `persistent_load` to load tensors)
1441
data_file = io.BytesIO(zip_file.get_record(pickle_file))
1443
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1444
unpickler.persistent_load = persistent_load
1445
result = unpickler.load()
1447
torch._utils._validate_loaded_sparse_tensors()
1448
# Report the archive's serialization id through the API-usage metadata hook.
torch._C._log_api_usage_metadata(
1449
"torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
1454
def _is_torchscript_zip(zip_file):
1455
return 'constants.pkl' in zip_file.get_all_records()