pytorch
1import _compat_pickle
2import pickle
3
4from .importer import Importer
5
6
class PackageUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Package-aware unpickler.

    This behaves the same as a normal unpickler, except it uses `importer` to
    find any global names that it encounters while unpickling.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create an unpickler that resolves globals through ``importer``.

        Args:
            importer: Object whose ``import_module(name)`` method is used to
                obtain every module referenced by the pickle stream.
            *args, **kwargs: Forwarded unchanged to ``pickle._Unpickler``
                (e.g. the binary file object to read from).
        """
        super().__init__(*args, **kwargs)
        self._importer = importer

    def find_class(self, module, name):
        """Return the global ``name`` from module ``module``.

        Mirrors ``pickle._Unpickler.find_class``, except that the module is
        obtained from ``self._importer.import_module`` instead of the regular
        import system.

        Raises:
            AttributeError: if ``name`` cannot be found on the module.
        """
        # Subclasses may override this.
        if self.proto < 3 and self.fix_imports:  # type: ignore[attr-defined]
            # Old (Python 2 era) pickles may reference modules/globals that
            # were renamed in Python 3 (e.g. ``__builtin__``); translate them.
            if (module, name) in _compat_pickle.NAME_MAPPING:
                module, name = _compat_pickle.NAME_MAPPING[(module, name)]
            elif module in _compat_pickle.IMPORT_MAPPING:
                module = _compat_pickle.IMPORT_MAPPING[module]
        mod = self._importer.import_module(module)
        if self.proto >= 4 and "." in name:  # type: ignore[attr-defined]
            # Protocol 4+ records qualified names (e.g. ``Outer.Inner`` for a
            # nested class). A single getattr on the dotted string would
            # always fail, so walk the path like the stdlib unpickler does.
            obj = mod
            try:
                for part in name.split("."):
                    obj = getattr(obj, part)
            except AttributeError as err:
                raise AttributeError(
                    f"Can't resolve path {name!r} on module {module!r}"
                ) from err
            return obj
        return getattr(mod, name)
27