pytorch

Форк
0
/
serialization.py 
1455 строк · 61.1 Кб
1
import difflib
2
import os
3
import io
4
import shutil
5
import struct
6
import sys
7
import torch
8
import tarfile
9
import tempfile
10
import warnings
11
from contextlib import closing, contextmanager
12
from enum import Enum
13
from ._utils import _import_dotted_name
14
from torch._sources import get_source_lines_and_file
15
from torch.types import Storage
16
from torch.storage import _get_dtype_from_pickle_storage_type
17
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
18
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
19
import copyreg
20
import pickle
21
import torch._weights_only_unpickler as _weights_only_unpickler
22

23
DEFAULT_PROTOCOL = 2
24

25
LONG_SIZE = struct.Struct('=l').size
26
INT_SIZE = struct.Struct('=i').size
27
SHORT_SIZE = struct.Struct('=h').size
28

29
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
30
PROTOCOL_VERSION = 1001
31
STORAGE_KEY_SEPARATOR = ','
32

33
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
34
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
35
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
36

37
# Public names exported by this module via ``from ... import *``.
__all__ = [
    'SourceChangeWarning', 'mkdtemp', 'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device', 'validate_hpu_device',
    'location_tag', 'default_restore_location',
    'normalize_storage_type', 'storage_to_tensor_type',
    'save', 'load', 'StorageType',
    'LoadEndianness', 'get_default_load_endianness', 'set_default_load_endianness',
]
55

56

57
class SourceChangeWarning(Warning):
    """Warning category related to changes in serialized module source code.

    NOTE(review): the code that emits this warning is outside this view;
    presumably it is raised at load time when a saved ``nn.Module`` class's
    source differs from the current definition — confirm.
    """
59

60

61
@contextmanager
def mkdtemp():
    """Create a temporary directory for the duration of a ``with`` block.

    Yields:
        The path of a new directory created by :func:`tempfile.mkdtemp`.
        On exit — normal or via an exception — the directory and all of its
        contents are removed with :func:`shutil.rmtree`.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        yield tmp_dir
    finally:
        # Clean up unconditionally, including when the with-body raised.
        shutil.rmtree(tmp_dir)
68

69

70
# (priority, tagger, deserializer) triples registered via register_package,
# which keeps the list sorted; location_tag and default_restore_location walk it in order.
_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = list()
71

72
class LoadEndianness(Enum):
    """Byte-order choices for the fallback used when loading checkpoints."""

    NATIVE = 1
    LITTLE = 2
    BIG = 3


# Current fallback byte order; ``None`` until set_default_load_endianness is called.
_default_load_endian: Optional[LoadEndianness] = None


def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default (``None``), it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian


def set_default_load_endianness(endianness):
    '''
    Set fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default (``None``), it's "native" byte order.

    Args:
        endianness: the new fallback byte order, a ``LoadEndianness``
            member or ``None``

    Raises:
        TypeError: if ``endianness`` is neither ``None`` nor a ``LoadEndianness``.
    '''
    global _default_load_endian
    if endianness is not None and not isinstance(endianness, LoadEndianness):
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness
107

108
def _is_zipfile(f) -> bool:
109
    # This is a stricter implementation than zipfile.is_zipfile().
110
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
111
    # binary. Since we expect the files here to be generated by torch.save or
112
    # torch.jit.save, it's safe to only check the start bytes and avoid
113
    # collisions and assume the zip has only 1 file.
114
    # See bugs.python.org/issue28494.
115

116
    start = f.tell()
117
    # Read the first few bytes and match against the ZIP file signature
118
    local_header_magic_number = b'PK\x03\x04'
119
    read_bytes = f.read(len(local_header_magic_number))
120
    f.seek(start)
121
    return read_bytes == local_header_magic_number
122

123

124
def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Entries that share the same priority are tried in registration order.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Sort on the priority alone. Sorting the raw tuples falls back to
    # comparing the callables whenever two priorities are equal, which raises
    # TypeError ('<' is not supported between function objects). list.sort is
    # stable, so equal-priority entries keep their registration order.
    _package_registry.sort(key=lambda elem: elem[0])
167

168

169
def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    error and exit or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool

    Raises:
        RuntimeError: if the version is missing, malformed or not comparable and
            ``error_if_malformed`` is True. Otherwise a warning is emitted and the
            requirement is assumed to be met.
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        # Read the attributes defensively: the failure may be that __version__
        # (or __name__) does not exist, and an AttributeError raised while
        # building the message would mask the RuntimeError / warning below.
        module_name = getattr(module, '__name__', repr(module))
        module_version_str = getattr(module, '__version__', None)
        message = (
            f"'{module_name}' module version string is malformed '{module_version_str}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met
206

207

208
def _cpu_tag(obj):
209
    if obj.device.type == 'cpu':
210
        return 'cpu'
211

212

213
def _cuda_tag(obj):
214
    if obj.device.type == 'cuda':
215
        return 'cuda:' + str(obj.device.index)
216

217
def _hpu_tag(obj):
218
    if obj.device.type == 'hpu':
219
        return 'hpu:' + str(obj.device.index)
220

221
def _mps_tag(obj):
222
    if obj.device.type == 'mps':
223
        return 'mps'
224

225

226
def _meta_tag(obj):
227
    if obj.device.type == 'meta':
228
        return 'meta'
229

230

231
def _privateuse1_tag(obj):
232
    backend_name = torch._C._get_privateuse1_backend_name()
233
    if obj.device.type == backend_name:
234
        if obj.device.index is None:
235
            return backend_name
236
        else:
237
            return backend_name + ':' + str(obj.device.index)
238

239

240
def _cpu_deserialize(obj, location):
241
    if location == 'cpu':
242
        return obj
243

244

245
def validate_cuda_device(location):
    """Resolve ``location`` to a CUDA device index that exists on this machine.

    Args:
        location: device string such as ``'cuda'`` or ``'cuda:1'``.

    Returns:
        The integer device index.

    Raises:
        RuntimeError: if CUDA is unavailable or the index is not below
            ``torch.cuda.device_count()``.
    """
    index = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError("Attempting to deserialize object on a CUDA "
                           "device but torch.cuda.is_available() is False. "
                           "If you are running on a CPU-only machine, "
                           "please use torch.load with map_location=torch.device('cpu') "
                           "to map your storages to the CPU.")

    available = torch.cuda.device_count()
    if index < available:
        return index
    raise RuntimeError("Attempting to deserialize object on CUDA device "
                       f"{index} but torch.cuda.device_count() is {available}. Please use "
                       "torch.load with map_location to map your storages "
                       "to an existing device.")
261

262

263
def _cuda_deserialize(obj, location):
264
    if location.startswith('cuda'):
265
        device = validate_cuda_device(location)
266
        if getattr(obj, "_torch_load_uninitialized", False):
267
            with torch.cuda.device(device):
268
                return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
269
        else:
270
            return obj.cuda(device)
271

272

273
def validate_hpu_device(location):
    """Resolve ``location`` to an HPU device index that exists on this machine.

    Args:
        location: device string such as ``'hpu'`` or ``'hpu:0'``.

    Returns:
        The integer device index.

    Raises:
        AssertionError: if ``torch.hpu`` is not present.
        RuntimeError: if HPU is unavailable or the index is not below
            ``torch.hpu.device_count()``.
    """
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    index = hpu._utils._get_device_index(location, optional=True)

    if not hpu.is_available():
        raise RuntimeError("Attempting to deserialize object on a HPU "
                           "device but torch.hpu.is_available() is False. "
                           "If you are running on a CPU-only machine, "
                           "please use torch.load with map_location=torch.device('cpu') "
                           "to map your storages to the CPU.")

    available = hpu.device_count()
    if index < available:
        return index
    raise RuntimeError("Attempting to deserialize object on HPU device "
                       f"{index} but torch.hpu.device_count() is {available}. Please use "
                       "torch.load with map_location to map your storages "
                       "to an existing device.")
291

292

293
def _hpu_deserialize(obj, location):
294
    if location.startswith('hpu'):
295
        hpu = getattr(torch, "hpu", None)
296
        assert hpu is not None, "HPU device module is not loaded"
297
        device = validate_hpu_device(location)
298
        if getattr(obj, "_torch_load_uninitialized", False):
299
            with hpu.device(device):
300
                return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
301
        else:
302
            return obj.hpu(device)
303

304

305
def _mps_deserialize(obj, location):
306
    if location.startswith('mps'):
307
        return obj.mps()
308

309

310
def _meta_deserialize(obj, location):
311
    if location == 'meta':
312
        return torch.UntypedStorage(obj.nbytes(), device='meta')
313

314

315
def _validate_privateuse1_device(location, backend_name):
316
    '''
317
    Check whether the device index of privateuse1 is valid
318

319
    Register a device_module of privateuse1 by torch._register_device_module.
320
    Implement the following methods in device_module like cuda:
321
    device_module._utils._get_device_index(location, True),
322
    device_module.device_count().
323

324
    Args:
325
        location: string of device
326
        backend_name: the name of privateuse1, which can be renamed
327

328
    Returns:
329
        device_index: int
330
    '''
331
    if not hasattr(torch, backend_name):
332
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
333
                           'If you are running on a CPU-only machine, '
334
                           'please use torch.load with map_location=torch.device(\'cpu\') '
335
                           'to map your storages to the CPU.')
336
    device_module = getattr(torch, backend_name)
337
    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
338
        device_index = device_module._utils._get_device_index(location, True)
339
    else:
340
        device = torch.device(location)
341
        device_index = device.index if device.index else 0
342
    if hasattr(device_module, 'is_available') and not device_module.is_available():
343
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
344
                           f'device but torch.{backend_name}.is_available() is False. '
345
                           'If you are running on a CPU-only machine, '
346
                           'please use torch.load with map_location=torch.device(\'cpu\') '
347
                           'to map your storages to the CPU.')
348
    if hasattr(device_module, 'device_count'):
349
        device_count = device_module.device_count()
350
        if device_index >= device_count:
351
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
352
                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
353
                               'Please use torch.load with map_location to map your storages '
354
                               'to an existing device.')
355
    return device_index
356

357

358
def _privateuse1_deserialize(obj, location):
359
    backend_name = torch._C._get_privateuse1_backend_name()
360
    if location.startswith(backend_name):
361
        if not hasattr(obj, backend_name):
362
            raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device '
363
                               f'but torch.storage._StorageBase.{backend_name}() or '
364
                               f'torch.storage.TypedStorage.{backend_name}() is not generated. '
365
                               'Please use torch.utils.generate_methods_for_privateuse1_backend '
366
                               f'to generate storage.{backend_name}() method first.')
367
        device_index = _validate_privateuse1_device(location, backend_name)
368
        return getattr(obj, backend_name)(device_index)
369

370

371
# Built-in tagger/deserializer pairs. register_package keeps the registry
# sorted by priority, so lower values are consulted first.
register_package(priority=10, tagger=_cpu_tag, deserializer=_cpu_deserialize)
register_package(priority=20, tagger=_cuda_tag, deserializer=_cuda_deserialize)
register_package(priority=21, tagger=_mps_tag, deserializer=_mps_deserialize)
register_package(priority=22, tagger=_meta_tag, deserializer=_meta_deserialize)
register_package(priority=23, tagger=_privateuse1_tag, deserializer=_privateuse1_deserialize)
register_package(priority=24, tagger=_hpu_tag, deserializer=_hpu_deserialize)
377

378

379
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    """Return the location tag (e.g. ``'cpu'``, ``'cuda:0'``) for ``storage``.

    Registered taggers are tried in registry order and the first truthy
    result is returned.

    Raises:
        RuntimeError: if no registered tagger recognizes ``storage``.
    """
    tags = (tagger(storage) for _priority, tagger, _deserializer in _package_registry)
    first_tag = next((tag for tag in tags if tag), None)
    if first_tag is None:
        raise RuntimeError("don't know how to determine data location of "
                           + torch.typename(storage))
    return first_tag
386

387

388
def default_restore_location(storage, location):
    """Place ``storage`` according to its saved ``location`` tag.

    Registered deserializers are tried in registry order and the first
    non-``None`` result is returned.

    Raises:
        RuntimeError: if no registered deserializer handles ``location``.
    """
    attempts = (deserializer(storage, location)
                for _priority, _tagger, deserializer in _package_registry)
    restored = next((result for result in attempts if result is not None), None)
    if restored is None:
        raise RuntimeError("don't know how to restore data location of "
                           + torch.typename(storage) + " (tagged with "
                           + location + ")")
    return restored
396

397

398
def normalize_storage_type(storage_type):
    """Return the attribute of the top-level ``torch`` module named ``storage_type.__name__``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
400

401

402
def storage_to_tensor_type(storage):
    """Map a storage instance to the tensor type defined alongside its class.

    ``FooStorage`` in module ``m`` maps to ``m.FooTensor``.
    """
    storage_cls = type(storage)
    tensor_type_name = storage_cls.__name__.replace('Storage', 'Tensor')
    owning_module = _import_dotted_name(storage_cls.__module__)
    return getattr(owning_module, tensor_type_name)
406

407

408
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
409
    return isinstance(name_or_buffer, (str, os.PathLike))
410

411

412
class _opener:
413
    def __init__(self, file_like):
414
        self.file_like = file_like
415

416
    def __enter__(self):
417
        return self.file_like
418

419
    def __exit__(self, *args):
420
        pass
421

422

423
class _open_file(_opener):
    """Open ``name`` with the builtin :func:`open` in ``mode`` and close it on exit."""

    def __init__(self, name, mode):
        handle = open(name, mode)
        super().__init__(handle)

    def __exit__(self, *args):
        self.file_like.close()
429

430

431
class _open_buffer_reader(_opener):
    """Wrap a caller-owned readable buffer after checking it is seekable.

    The buffer is left open on exit (inherited no-op ``__exit__``).
    """

    def __init__(self, buffer):
        super().__init__(buffer)
        # Raises with a pre-load hint when the buffer cannot tell/seek.
        _check_seekable(buffer)
435

436

437
class _open_buffer_writer(_opener):
    """Wrap a caller-owned writable buffer; flush (but do not close) it on exit."""

    def __exit__(self, *args):
        buffer = self.file_like
        buffer.flush()
440

441

442
def _open_file_like(name_or_buffer, mode):
    """Return the context manager appropriate for ``name_or_buffer`` and ``mode``.

    Paths are opened (and later closed) by this module; buffers are wrapped
    as writers when ``'w'`` is in ``mode``, otherwise as readers when ``'r'``
    is in ``mode``.

    Raises:
        RuntimeError: for a buffer whose ``mode`` contains neither ``'r'`` nor ``'w'``.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if 'w' in mode:
        return _open_buffer_writer(name_or_buffer)
    if 'r' in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
452

453

454
class _open_zipfile_reader(_opener):
    """Wrap ``name_or_buffer`` in a ``torch._C.PyTorchFileReader``."""

    def __init__(self, name_or_buffer) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
457

458

459
class _open_zipfile_writer_file(_opener):
    """Zip-archive writer targeting a filesystem path.

    ``PyTorchFileWriter`` only accepts ASCII file names; for any other name the
    file is opened here with :class:`io.FileIO` and the writer is given that
    stream instead. On exit the archive is finalized with
    ``write_end_of_file()`` and any stream opened here is closed.
    """

    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            try:
                super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
            except BaseException:
                # __exit__ will never run if construction fails, so release
                # the descriptor we just opened before propagating.
                self.file_stream.close()
                raise
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        try:
            self.file_like.write_end_of_file()
        finally:
            # Close the stream we opened even if finalizing the archive failed.
            if self.file_stream is not None:
                self.file_stream.close()
478

479

480
class _open_zipfile_writer_buffer(_opener):
    """Zip-archive writer over a caller-supplied buffer with a callable ``write``.

    On exit the archive is finalized and the buffer is flushed (not closed).
    """

    def __init__(self, buffer) -> None:
        write_attr = getattr(buffer, "write", None)
        if not callable(write_attr):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            # Missing attribute -> AttributeError; present but not callable -> TypeError.
            error_cls = TypeError if hasattr(buffer, "write") else AttributeError
            raise error_cls(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()
493

494

495
def _open_zipfile_writer(name_or_buffer):
    """Return a zip-writer context manager: file-backed for paths, buffer-backed otherwise."""
    container: Type[_opener] = (
        _open_zipfile_writer_file if _is_path(name_or_buffer) else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)
502

503

504
def _is_compressed_file(f) -> bool:
505
    compress_modules = ['gzip']
506
    try:
507
        return f.__module__ in compress_modules
508
    except AttributeError:
509
        return False
510

511

512
def _should_read_directly(f):
    """
    Return True if ``f`` should be read directly: it is not a compressed
    file (e.g. gzip) and it is backed by a real file (``fileno()`` succeeds
    and is non-negative).
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except (io.UnsupportedOperation, AttributeError):
        # In-memory buffers raise UnsupportedOperation; objects without
        # fileno() raise AttributeError. Neither is backed by a real file.
        return False
526

527

528
def _check_seekable(f) -> bool:
529

530
    def raise_err_msg(patterns, e):
531
        for p in patterns:
532
            if p in str(e):
533
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
534
                                + " Please pre-load the data into a buffer like io.BytesIO and"
535
                                + " try to load from it instead.")
536
                raise type(e)(msg)
537
        raise e
538

539
    try:
540
        f.seek(f.tell())
541
        return True
542
    except (io.UnsupportedOperation, AttributeError) as e:
543
        raise_err_msg(["seek", "tell"], e)
544
    return False
545

546

547
def _check_dill_version(pickle_module) -> None:
548
    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
549
    If dill version is lower than 0.3.1, a ValueError is raised.
550

551
    Args:
552
        pickle_module: module used for pickling metadata and objects
553

554
    '''
555
    if pickle_module is not None and pickle_module.__name__ == 'dill':
556
        required_dill_version = (0, 3, 1)
557
        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
558
            raise ValueError((
559
                "'torch' supports dill >= {}, but you have dill {}."
560
                " Please upgrade dill or switch to 'pickle'"
561
            ).format(
562
                '.'.join([str(num) for num in required_dill_version]),
563
                pickle_module.__version__
564
            ))
565

566

567
def _check_save_filelike(f):
    """Raise AttributeError unless ``f`` is a path or an object with a ``write`` attribute."""
    if _is_path(f) or hasattr(f, 'write'):
        return
    raise AttributeError(
        "expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute")
572

573

574
def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if not _use_new_zipfile_serialization:
        # Legacy (pre-1.6) format, used only when the caller opts out.
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
        return

    with _open_zipfile_writer(f) as opened_zipfile:
        _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
633

634

635
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Write ``obj`` to ``f`` in the legacy (non-zipfile) torch.save format.

    Values written to ``f``, in order: ``MAGIC_NUMBER``, ``PROTOCOL_VERSION``,
    a ``sys_info`` dict (protocol version, endianness, C integer sizes), the
    pickled ``obj`` (storages and ``nn.Module`` subclasses are handled by a
    custom ``persistent_id``), the sorted list of storage keys, and finally the
    raw bytes of each storage in key order.

    Args:
        obj: object to serialize.
        f: writable file-like object.
        pickle_module: module providing ``dump`` and ``Pickler``.
        pickle_protocol: pickle protocol used for every dump.

    Raises:
        RuntimeError: if two saved storages share data but have different dtypes.
        TypeError: if a storage object has an unrecognized type.
    """
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # Only the first occurrence of each module class gets a persistent id.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f'type not recognized: {type(obj)}')

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`.  The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurance of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case.  We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will loaded and the UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error.  If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data.  We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, storage_dtype)

            # Storage views are never produced any more: the former `is_view`
            # test compared `storage._cdata` with itself and was always False.
            # The view-metadata slot (which used to carry a zero offset) is
            # therefore always None, but it stays in the tuple for backwards
            # compatibility with the legacy format.
            view_metadata: Optional[Tuple[str, int, int]] = None

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel,
                    view_metadata)
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    # Whether f is backed by a real file cannot change between storages, so
    # decide it once rather than on every iteration.
    read_directly = _should_read_directly(f)
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(f, read_directly, True, torch._utils._element_size(dtype))
777

778

779
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    """Serialize ``obj`` into ``zip_file`` using the zipfile-based format.

    Records written to the archive:

    * ``data.pkl`` -- the pickled object graph, where every storage is replaced
      by the persistent id ``('storage', storage_type, key, location, numel)``;
    * ``byteorder`` -- the value of ``sys.byteorder`` (skipped when
      ``_disable_byteorder_record`` is true);
    * ``data/<key>`` -- the raw bytes of each distinct storage, copied to the
      CPU first when it lives elsewhere.

    Args:
        obj: object to serialize.
        zip_file: archive writer; only ``write_record(name, data, size)`` is used.
        pickle_module: module providing a ``Pickler`` class.
        pickle_protocol: pickle protocol passed to the pickler.
        _disable_byteorder_record: when true, no ``byteorder`` record is written.

    Raises:
        RuntimeError: if two saved storages share an allocated data pointer but
            were saved with different dtypes.
        ValueError: if ``sys.byteorder`` is neither ``'little'`` nor ``'big'``.
    """
    # key -> untyped storage whose bytes end up in data/<key>
    storages_by_key = {}
    # storage._cdata -> key; keys are handed out as '0', '1', '2', ... in
    # order of first appearance
    keys_by_cdata: Dict[int, str] = {}
    # Loading storages that view the same data with different dtypes is not
    # supported, so remember the dtype seen for every data_ptr and reject a
    # conflicting one.
    # TODO: This feature could be added in the future
    dtype_by_ptr: Dict[int, torch.dtype] = {}

    def persistent_id(item):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        is_typed = isinstance(item, torch.storage.TypedStorage)
        if not (is_typed or torch.is_storage(item)):
            # Not a storage: let pickle serialize it normally.
            return None

        if is_typed:
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            untyped = item._untyped_storage
            dtype = item.dtype
            storage_type = getattr(torch, item._pickle_storage_type())
            numel = item._size()
        else:
            untyped = item
            dtype = torch.uint8
            storage_type = normalize_storage_type(type(item))
            numel = untyped.nbytes()

        # Unallocated storages (data_ptr() == 0) are exempt from the dtype check.
        ptr = untyped.data_ptr()
        if ptr != 0:
            first_dtype = dtype_by_ptr.setdefault(ptr, dtype)
            if first_dtype != dtype:
                raise RuntimeError(
                    'Cannot save multiple tensors or storages that '
                    'view the same data as different types')

        key = keys_by_cdata.setdefault(untyped._cdata, str(len(keys_by_cdata)))
        location = location_tag(untyped)
        storages_by_key[key] = untyped
        return ('storage', storage_type, key, location, numel)

    # Pickle the object graph first; this is what fills storages_by_key.
    buffer = io.BytesIO()
    pickler = pickle_module.Pickler(buffer, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    pickled = buffer.getvalue()
    zip_file.write_record('data.pkl', pickled, len(pickled))

    # Record the writer's byte order so a loader can decide whether to byteswap.
    if not _disable_byteorder_record:
        byteorder = sys.byteorder
        if byteorder not in ('little', 'big'):
            raise ValueError('Unknown endianness type: ' + byteorder)
        zip_file.write_record('byteorder', byteorder, len(byteorder))

    # One record per storage, in sorted (string) key order.
    for key in sorted(storages_by_key):
        payload = storages_by_key[key]
        # Records are written from CPU memory, so a storage on another device
        # must support .cpu() in order to be serialized.
        if payload.device.type != 'cpu':
            payload = payload.cpu()
        zip_file.write_record(f'data/{key}', payload, payload.nbytes())
863

864

865
def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = False,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize file). Must not be
            given together with ``weights_only=True``.
        weights_only: Indicates whether unpickler should be restricted to
            loading only tensors, primitive types and dictionaries. It is forced to
            ``True`` when the ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` environment variable is
            set to ``1``, ``y``, ``yes`` or ``true`` (case-insensitive).
        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped. This requires ``f`` to be a
            file path and the file to be in the zipfile format written by :func:`torch.save`.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        :func:`torch.load()` unless `weights_only` parameter is set to `True`,
        uses ``pickle`` module implicitly, which is known to be insecure.
        It is possible to construct malicious pickle data which will execute arbitrary code
        during unpickling. Never load data that could have come from an untrusted
        source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``.  This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3.  If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load('tensors.pt', weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
        # Load tensor from io.BytesIO object
        # Loading from a buffer setting weights_only=False, warning this can be unsafe
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading from a module setting weights_only=False, warning this can be unsafe
        >>> torch.load('module.pt', encoding='ascii', weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    # Prefix for errors raised by the restricted (weights-only) unpickler.
    # The separating space before "Do it" was previously missing.
    UNSAFE_MESSAGE = (
        "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
        " will likely succeed, but it can result in arbitrary code execution."
        " Do it only if you get the file from a trusted source. WeightsUnpickler error: "
    )
    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
        weights_only = True

    if weights_only:
        # The restricted _weights_only_unpickler is always used in this mode,
        # so a caller-supplied pickle module cannot be honored.
        if pickle_module is not None:
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    else:
        if pickle_module is None:
            pickle_module = pickle

    # make flipping default BC-compatible
    if mmap is None:
        mmap = False

    _check_dill_version(pickle_module)

    # Decode Python 2 byte strings as UTF-8 unless the caller picked an encoding.
    pickle_load_args.setdefault('encoding', 'utf-8')

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    # mmap needs a real path on disk, not an arbitrary buffer.
                    if not _is_path(f):
                        raise ValueError("f must be a file path in order to use the mmap argument")
                    size = os.path.getsize(f)
                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), False, size)
                if weights_only:
                    try:
                        return _load(opened_zipfile,
                                     map_location,
                                     _weights_only_unpickler,
                                     overall_storage=overall_storage,
                                     **pickle_load_args)
                    except RuntimeError as e:
                        raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
                return _load(opened_zipfile,
                             map_location,
                             pickle_module,
                             overall_storage=overall_storage,
                             **pickle_load_args)
        # Not a zip archive: only the legacy loaders remain, and they do not support mmap.
        if mmap:
            raise RuntimeError("mmap can only be used with files saved with "
                               "`torch.save(_use_new_zipfile_serialization=True)`, "
                               "please torch.save your checkpoint with this option in order to use mmap.")
        if weights_only:
            try:
                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
            except RuntimeError as e:
                raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
1040

1041

1042
# Pickling support for ``torch.layout`` instances (torch.strided,
# torch.sparse_coo, ...): a layout is reduced to its string form and
# resolved back to the corresponding ``torch`` object when unpickled.
def _get_layout(name):
    """Return the ``torch.layout`` object whose ``str()`` equals ``name``.

    On the first call, the lookup table stored on the function's ``cache``
    attribute is filled by scanning the ``torch`` namespace for layout objects.

    Raises:
        KeyError: if no layout in ``torch`` has that string representation.
    """
    table = _get_layout.cache  # type: ignore[attr-defined]
    if not table:
        table.update(
            (str(candidate), candidate)
            for candidate in torch.__dict__.values()
            if isinstance(candidate, torch.layout)
        )
    return table[name]

# Function attributes cannot be type-annotated yet, see
# https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda layout: (_get_layout, (str(layout),)))
1057

1058

1059
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the pre-zipfile ("legacy") formats.

    Two layouts are handled:

    * the tar-based format, read by the nested ``legacy_load``. It is only
      attempted when ``f`` can be read directly and is positioned at offset 0.
    * the stream format: magic number, protocol version, a ``sys_info`` record,
      the pickled object graph, the list of storage keys, and then the raw
      bytes of each storage in that order.

    Args:
        f: seekable file-like object positioned at the start of the data.
        map_location: passed to ``_get_restore_location`` to decide where
            restored storages are placed.
        pickle_module: module providing ``load`` and ``Unpickler``.
        **pickle_load_args: extra keyword arguments forwarded to
            ``pickle_module.load`` and ``pickle_module.Unpickler``.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: for an invalid magic number or protocol version, an
            unknown persistent id type, a zip archive reaching the tar path, or
            a file-like object without ``readinto`` on Python 3.8.0/3.8.1.
    """
    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        def find_class(self, mod_name, name):
            # Legacy pickles refer to classes such as ``torch.FloatStorage``.
            # Resolve recognized names to dtype-only StorageType placeholders
            # and fall back to the regular lookup for everything else.
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        # Warn when the source saved with a container class differs from the
        # source of the class currently importable; optionally dump a reverse
        # patch next to the working directory.
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as patch_file:
                        # Write into an empty file; an existing file is only
                        # accepted when it already holds exactly this patch.
                        # Compare contents, not sizes: a text-mode seek/tell
                        # offset is not a character count, so it differs from
                        # len(lines) for non-ASCII source or translated newlines.
                        if patch_file.seek(0, 2) == 0:
                            patch_file.write(lines)
                        else:
                            patch_file.seek(0)
                            if patch_file.read() != lines:
                                raise OSError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except (OSError, UnicodeError):
                    # The patch is a best-effort convenience; an encoding
                    # problem must not abort the whole load.
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Reader for the tar-based format: members 'storages', 'tensors' and
        # 'pickle', extracted into a temporary directory.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                # Views are stored as (target, root, offset, numel), with
                # offset and numel counted in elements of the root's dtype.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # Storage key (as written in the stream format) -> TypedStorage; filled
    # by persistent_load below and read back when storage bytes are loaded.
    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                # Allocate now; the bytes are filled in after unpickling,
                # once the list of storage keys has been read.
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj,
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    # Read to advance past the sys_info record; its contents are not used.
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Storage payloads follow in the order of deserialized_storage_keys.
    # Under fake mode the storages live on 'meta' and no bytes are read.
    if torch._guards.active_fake_mode() is None:
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f, offset, f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype))
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
1289

1290

1291
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    """Return ``bytes_str`` as ``str``, ASCII-decoding it when it is ``bytes``.

    Loading a Python 2 checkpoint with ``encoding='bytes'`` turns some internal
    keys that Py2 stored as ``str`` into ``bytes``. This converts them back
    using ASCII, the encoding Py3 uses by default for such keys. Any other
    value is returned unchanged.

    Only use this on internal keys (e.g. ``typename`` and ``location`` in the
    ``persistent_load`` hooks), never on user data.

    Raises:
        UnicodeDecodeError: if ``bytes_str`` is ``bytes`` with non-ASCII content.
    """
    return bytes_str.decode('ascii') if isinstance(bytes_str, bytes) else bytes_str
1301

1302

1303
def _get_restore_location(map_location):
    """Turn a ``map_location`` argument into a ``restore_location(storage, location)`` callback.

    * ``None`` -- ``default_restore_location`` itself.
    * ``dict`` -- each saved location tag is looked up in the dict (falling back
      to the tag itself) before calling ``default_restore_location``.
    * ``str`` / ``bytes`` -- every storage is restored to that fixed location.
    * ``torch.device`` -- every storage is restored to ``str(map_location)``.
    * anything else is called as ``map_location(storage, location)``; a ``None``
      result falls back to ``default_restore_location(storage, location)``.
    """
    if map_location is None:
        return default_restore_location

    if isinstance(map_location, dict):
        def _remap_via_table(storage, location):
            return default_restore_location(storage, map_location.get(location, location))
        return _remap_via_table

    if isinstance(map_location, (str, bytes)):
        def _to_fixed_location(storage, location):
            return default_restore_location(storage, map_location)
        return _to_fixed_location

    if isinstance(map_location, torch.device):
        def _to_device(storage, location):
            return default_restore_location(storage, str(map_location))
        return _to_device

    def _via_user_callable(storage, location):
        restored = map_location(storage, location)
        if restored is not None:
            return restored
        return default_restore_location(storage, location)
    return _via_user_callable
1323

1324

1325
class StorageType:
    """Placeholder for a legacy ``torch.<Name>Storage`` class referenced in a pickle.

    Only the dtype encoded in the storage type name (e.g. ``'FloatStorage'``)
    is kept. Construction propagates any error raised by
    ``_get_dtype_from_pickle_storage_type`` for an unrecognized name.
    """

    def __init__(self, name):
        resolved = _get_dtype_from_pickle_storage_type(name)
        self._dtype = resolved

    @property
    def dtype(self):
        """The ``torch.dtype`` corresponding to the storage type name."""
        return self._dtype

    def __str__(self):
        return 'StorageType(dtype={})'.format(self.dtype)
1335

1336

1337
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
1338
    restore_location = _get_restore_location(map_location)
1339

1340
    loaded_storages = {}
1341

1342
    # check if byteswapping is needed
1343
    byteordername = 'byteorder'
1344
    byteorderdata = None
1345
    if zip_file.has_record(byteordername):
1346
        byteorderdata = zip_file.get_record(byteordername)
1347
        if byteorderdata not in [b'little', b'big']:
1348
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
1349
    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
1350
            get_default_load_endianness() is None:
1351
        byteorderdata = b'little'
1352
    elif get_default_load_endianness() == LoadEndianness.BIG:
1353
        byteorderdata = b'big'
1354
    elif get_default_load_endianness() == LoadEndianness.NATIVE:
1355
        pass
1356
    else:
1357
        raise ValueError('Invalid load endianness type')
1358

1359
    if not zip_file.has_record(byteordername) and \
1360
            get_default_load_endianness() is None and \
1361
            sys.byteorder == 'big':
1362
        # Default behaviour was changed
1363
        # See https://github.com/pytorch/pytorch/issues/101688
1364
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
1365
                      "on big endian machines was changed from 'native' to 'little' endian, "
1366
                      "to avoid this behavior please use "
1367
                      "torch.serialization.set_default_load_endianness to set "
1368
                      "the desired default load endianness",
1369
                      UserWarning)
1370

1371
    def load_tensor(dtype, numel, key, location):
1372
        name = f'data/{key}'
1373
        if torch._guards.detect_fake_mode(None) is not None:
1374
            nbytes = numel * torch._utils._element_size(dtype)
1375
            storage = torch.UntypedStorage(nbytes, device='meta')
1376
        elif overall_storage is not None:
1377
            storage_offset = zip_file.get_record_offset(name)
1378
            storage = overall_storage[storage_offset:storage_offset + numel]
1379
        else:
1380
            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
1381
        # swap here if byteswapping is needed
1382
        if byteorderdata is not None:
1383
            if byteorderdata.decode() != sys.byteorder:
1384
                storage.byteswap(dtype)
1385

1386
        # TODO: Once we decide to break serialization FC, we can
1387
        # stop wrapping with TypedStorage
1388
        typed_storage = torch.storage.TypedStorage(
1389
            wrap_storage=restore_location(storage, location),
1390
            dtype=dtype,
1391
            _internal=True)
1392

1393
        if typed_storage._data_ptr() != 0:
1394
            loaded_storages[key] = typed_storage
1395

1396
        return typed_storage
1397

1398
    def persistent_load(saved_id):
1399
        assert isinstance(saved_id, tuple)
1400
        typename = _maybe_decode_ascii(saved_id[0])
1401
        data = saved_id[1:]
1402

1403
        assert typename == 'storage', \
1404
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
1405
        storage_type, key, location, numel = data
1406
        if storage_type is torch.UntypedStorage:
1407
            dtype = torch.uint8
1408
        else:
1409
            dtype = storage_type.dtype
1410

1411
        if key in loaded_storages:
1412
            typed_storage = loaded_storages[key]
1413
        else:
1414
            nbytes = numel * torch._utils._element_size(dtype)
1415
            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
1416

1417
        return typed_storage
1418

1419
    load_module_mapping: Dict[str, str] = {
1420
        # See https://github.com/pytorch/pytorch/pull/51633
1421
        'torch.tensor': 'torch._tensor'
1422
    }
1423

1424
    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
1425
    # because it's marked readonly in pickle.
1426
    # The type: ignore is because mypy can't statically determine the type of this class.
1427
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
1428
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
1429
        # Lets us override the imports that pickle uses when unpickling an object.
1430
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
1431
        def find_class(self, mod_name, name):
1432
            if type(name) is str and 'Storage' in name:
1433
                try:
1434
                    return StorageType(name)
1435
                except KeyError:
1436
                    pass
1437
            mod_name = load_module_mapping.get(mod_name, mod_name)
1438
            return super().find_class(mod_name, name)
1439

1440
    # Load the data (which may in turn use `persistent_load` to load tensors)
1441
    data_file = io.BytesIO(zip_file.get_record(pickle_file))
1442

1443
    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1444
    unpickler.persistent_load = persistent_load
1445
    result = unpickler.load()
1446

1447
    torch._utils._validate_loaded_sparse_tensors()
1448
    torch._C._log_api_usage_metadata(
1449
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
1450
    )
1451
    return result
1452

1453

1454
def _is_torchscript_zip(zip_file):
1455
    return 'constants.pkl' in zip_file.get_all_records()
1456

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.