3
import importlib.machinery
9
from contextlib import contextmanager
10
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
11
from weakref import WeakValueDictionary
14
from torch.serialization import _get_restore_location, _maybe_decode_ascii
16
from ._directory_reader import DirectoryReader
17
from ._importlib import (
19
_normalize_line_endings,
24
from ._mangling import demangle, PackageMangler
25
from ._package_unpickler import PackageUnpickler
26
from .file_structure_representation import _create_directory_from_file_list, Directory
27
from .glob_group import GlobPattern
28
from .importer import Importer
30
__all__ = ["PackageImporter"]
33
# This is a list of imports that are implicitly allowed even if they haven't
34
# been marked as extern. This is to work around the fact that Torch implicitly
35
# depends on numpy and package can't track it.
36
# https://github.com/pytorch/MultiPy/issues/46
37
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
40
"numpy.core._multiarray_umath",
41
# FX GraphModule might depend on builtins module and users usually
42
# don't extern builtins. Here we import it here by default.
47
class PackageImporter(Importer):
48
"""Importers allow you to load code written to packages by :class:`PackageExporter`.
49
Code is loaded in a hermetic way, using files from the package
50
rather than the normal python import system. This allows
51
for the packaging of PyTorch model code and data so that it can be run
52
on a server or used in the future for transfer learning.
54
The importer for packages ensures that code in the module can only be loaded from
55
within the package, except for modules explicitly listed as external during export.
56
The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
57
This prevents "implicit" dependencies where the package runs locally because it is importing
58
a locally-installed package, but then fails when the package is copied to another machine.
61
"""The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
62
local to this importer.
65
modules: Dict[str, types.ModuleType]
69
file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
70
module_allowed: Callable[[str], bool] = lambda module_name: True,
72
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
73
allowed by ``module_allowed``
76
file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
77
a string, or an ``os.PathLike`` object containing a filename.
78
module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module
79
should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
80
does not support. Defaults to allowing anything.
83
ImportError: If the package will use a disallowed module.
85
torch._C._log_api_usage_once("torch.package.PackageImporter")
88
if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
89
self.filename = "<pytorch_file_reader>"
90
self.zip_reader = file_or_buffer
91
elif isinstance(file_or_buffer, (os.PathLike, str)):
92
self.filename = os.fspath(file_or_buffer)
93
if not os.path.isdir(self.filename):
94
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
96
self.zip_reader = DirectoryReader(self.filename)
98
self.filename = "<binary>"
99
self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)
101
torch._C._log_api_usage_metadata(
102
"torch.package.PackageImporter.metadata",
104
"serialization_id": self.zip_reader.serialization_id(),
105
"file_name": self.filename,
109
self.root = _PackageNode(None)
111
self.extern_modules = self._read_extern()
113
for extern_module in self.extern_modules:
114
if not module_allowed(extern_module):
116
f"package '{file_or_buffer}' needs the external module '{extern_module}' "
117
f"but that module has been disallowed"
119
self._add_extern(extern_module)
121
for fname in self.zip_reader.get_all_records():
122
self._add_file(fname)
124
self.patched_builtins = builtins.__dict__.copy()
125
self.patched_builtins["__import__"] = self.__import__
126
# Allow packaged modules to reference their PackageImporter
127
self.modules["torch_package_importer"] = self # type: ignore[assignment]
129
self._mangler = PackageMangler()
131
# used for reduce deserializaiton
132
self.storage_context: Any = None
133
self.last_map_location = None
135
# used for torch.serialization._load
136
self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
138
def import_module(self, name: str, package=None):
    """Load a module from the package if it hasn't already been loaded, and then return
    the module. Modules are loaded locally
    to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

    Args:
        name (str): Fully qualified name of the module to load. A name mangled by
            this importer is accepted and demangled first.
        package ([type], optional): Unused, but present to match the signature of
            importlib.import_module. Defaults to ``None``.

    Returns:
        types.ModuleType: The (possibly already) loaded module.
    """
    # We should always be able to support importing modules from this package.
    # This is to support something like:
    #   obj = importer.load_pickle(...)
    #   importer.import_module(obj.__module__)  <- this string will be mangled
    #
    # Note that _mangler.demangle will not demangle any module names
    # produced by a different PackageImporter instance.
    name = self._mangler.demangle(name)

    return self._gcd_import(name)
161
def load_binary(self, package: str, resource: str) -> bytes:
    """Load raw bytes.

    Args:
        package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
        resource (str): The unique name for the resource.

    Returns:
        bytes: The loaded data.
    """
    # Resolve (package, resource) to the record path inside the archive, then
    # read that record straight from the underlying reader.
    path = self._zipfile_path(package, resource)
    return self.zip_reader.get_record(path)
179
encoding: str = "utf-8",
180
errors: str = "strict",
185
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
186
resource (str): The unique name for the resource.
187
encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
188
errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
191
str: The loaded text.
193
data = self.load_binary(package, resource)
194
return data.decode(encoding, errors)
196
def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
197
"""Unpickles the resource from the package, loading any modules that are needed to construct the objects
198
using :meth:`import_module`.
201
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
202
resource (str): The unique name for the resource.
203
map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.
206
Any: The unpickled object.
208
pickle_file = self._zipfile_path(package, resource)
209
restore_location = _get_restore_location(map_location)
212
storage_context = torch._C.DeserializationStorageContext()
214
def load_tensor(dtype, size, key, location, restore_location):
215
name = f"{key}.storage"
217
if storage_context.has_storage(name):
218
storage = storage_context.get_storage(name, dtype)._typed_storage()
220
tensor = self.zip_reader.get_storage_from_record(
221
".data/" + name, size, dtype
223
if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
224
storage_context.add_storage(name, tensor)
225
storage = tensor._typed_storage()
226
loaded_storages[key] = restore_location(storage, location)
228
def persistent_load(saved_id):
229
assert isinstance(saved_id, tuple)
230
typename = _maybe_decode_ascii(saved_id[0])
233
if typename == "storage":
234
storage_type, key, location, size = data
235
dtype = storage_type.dtype
237
if key not in loaded_storages:
242
_maybe_decode_ascii(location),
245
storage = loaded_storages[key]
246
# TODO: Once we decide to break serialization FC, we can
247
# stop wrapping with TypedStorage
248
return torch.storage.TypedStorage(
249
wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
251
elif typename == "reduce_package":
252
# to fix BC breaking change, objects on this load path
253
# will be loaded multiple times erroneously
256
return func(self, *args)
257
reduce_id, func, args = data
258
if reduce_id not in loaded_reduces:
259
loaded_reduces[reduce_id] = func(self, *args)
260
return loaded_reduces[reduce_id]
262
f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'"
264
# Load the data (which may in turn use `persistent_load` to load tensors)
265
data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
266
unpickler = self.Unpickler(data_file)
267
unpickler.persistent_load = persistent_load # type: ignore[assignment]
270
def set_deserialization_context():
271
# to let reduce_package access deserializaiton context
272
self.storage_context = storage_context
273
self.last_map_location = map_location
277
self.storage_context = None
278
self.last_map_location = None
280
with set_deserialization_context():
281
result = unpickler.load()
284
# This stateful weird function will need to be removed in our efforts
285
# to unify the format. It has a race condition if multiple python
286
# threads try to read independent files
287
torch._utils._validate_loaded_sparse_tensors()
293
Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
298
return self._mangler.parent_name()
301
self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
303
"""Returns a file structure representation of package's zipfile.
306
include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
307
for the names of the files to be included in the zipfile representation. This can also be
308
a glob-style pattern, as described in :meth:`PackageExporter.mock`
310
exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
315
return _create_directory_from_file_list(
316
self.filename, self.zip_reader.get_all_records(), include, exclude
319
def python_version(self):
320
"""Returns the version of python that was used to create this package.
322
Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
326
:class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
328
python_version_path = ".data/python_version"
330
self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
331
if self.zip_reader.has_record(python_version_path)
335
def _read_extern(self):
337
self.zip_reader.get_record(".data/extern_modules")
339
.splitlines(keepends=False)
343
self, name: str, filename: Optional[str], is_package: bool, parent: str
345
mangled_filename = self._mangler.mangle(filename) if filename else None
346
spec = importlib.machinery.ModuleSpec(
348
self, # type: ignore[arg-type]
349
origin="<package_importer>",
350
is_package=is_package,
352
module = importlib.util.module_from_spec(spec)
353
self.modules[name] = module
354
module.__name__ = self._mangler.mangle(name)
356
ns["__spec__"] = spec
357
ns["__loader__"] = self
358
ns["__file__"] = mangled_filename
359
ns["__cached__"] = None
360
ns["__builtins__"] = self.patched_builtins
361
ns["__torch_package__"] = True
363
# Add this module to our private global registry. It should be unique due to mangling.
364
assert module.__name__ not in _package_imported_modules
365
_package_imported_modules[module.__name__] = module
367
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
369
self._install_on_parent(parent, name, module)
371
if filename is not None:
372
assert mangled_filename is not None
373
# pre-emptively install the source in `linecache` so that stack traces,
374
# `inspect`, etc. work.
375
assert filename not in linecache.cache # type: ignore[attr-defined]
376
linecache.lazycache(mangled_filename, ns)
378
code = self._compile_source(filename, mangled_filename)
383
def _load_module(self, name: str, parent: str):
384
cur: _PathNode = self.root
385
for atom in name.split("."):
386
if not isinstance(cur, _PackageNode) or atom not in cur.children:
387
if name in IMPLICIT_IMPORT_ALLOWLIST:
388
module = self.modules[name] = importlib.import_module(name)
390
raise ModuleNotFoundError(
391
f'No module named "{name}" in self-contained archive "{self.filename}"'
392
f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
395
cur = cur.children[atom]
396
if isinstance(cur, _ExternNode):
397
module = self.modules[name] = importlib.import_module(name)
399
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
401
def _compile_source(self, fullpath: str, mangled_filename: str):
    """Read a source record from the archive and compile it into a code object.

    Args:
        fullpath (str): Path of the source record inside the archive.
        mangled_filename (str): Filename stored on the resulting code object
            (it is what tracebacks will report for this module).

    Returns:
        The code object produced by ``compile(..., "exec")``.
    """
    source = self.zip_reader.get_record(fullpath)
    # Normalize line endings (via the _importlib helper) before compiling.
    source = _normalize_line_endings(source)
    # dont_inherit=True: do not pick up __future__ flags from this module.
    return compile(source, mangled_filename, "exec", dont_inherit=True)
406
# note: named `get_source` so that linecache can find the source
# when this is the __loader__ of a module.
def get_source(self, module_name) -> str:
    """Return the UTF-8-decoded source text of ``module_name`` from the archive.

    Args:
        module_name: Module name, possibly mangled by this importer.

    Returns:
        str: The module's source code.
    """
    # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
    module = self.import_module(demangle(module_name))
    # ``__file__`` holds the mangled filename; demangle it to get the archive record path.
    return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
413
# note: named `get_resource_reader` so that importlib.resources can find it.
414
# This is otherwise considered an internal method.
415
def get_resource_reader(self, fullname):
417
package = self._get_package(fullname)
420
if package.__loader__ is not self:
422
return _PackageResourceReader(self, fullname)
424
def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
427
# Set the module as an attribute on its parent.
428
parent_module = self.modules[parent]
429
if parent_module.__loader__ is self:
430
setattr(parent_module, name.rpartition(".")[2], module)
432
# note: copied from cpython's import code, with call to create module replaced with _make_module
433
def _do_find_and_load(self, name):
435
parent = name.rpartition(".")[0]
436
module_name_no_parent = name.rpartition(".")[-1]
438
if parent not in self.modules:
439
self._gcd_import(parent)
440
# Crazy side-effects!
441
if name in self.modules:
442
return self.modules[name]
443
parent_module = self.modules[parent]
446
path = parent_module.__path__ # type: ignore[attr-defined]
448
except AttributeError:
449
# when we attempt to import a package only containing pybinded files,
450
# the parent directory isn't always a package as defined by python,
451
# so we search if the package is actually there or not before calling the error.
453
parent_module.__loader__,
454
importlib.machinery.ExtensionFileLoader,
456
if name not in self.extern_modules:
459
+ "; {!r} is a c extension module which was not externed. C extension modules \
460
need to be externed by the PackageExporter in order to be used as we do not support interning them.}."
462
raise ModuleNotFoundError(msg, name=name) from None
464
parent_module.__dict__.get(module_name_no_parent),
469
+ "; {!r} is a c extension package which does not contain {!r}."
470
).format(name, parent, name)
471
raise ModuleNotFoundError(msg, name=name) from None
473
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
474
raise ModuleNotFoundError(msg, name=name) from None
476
module = self._load_module(name, parent)
478
self._install_on_parent(parent, name, module)
482
# note: copied from cpython's import code
483
def _find_and_load(self, name):
484
module = self.modules.get(name, _NEEDS_LOADING)
485
if module is _NEEDS_LOADING:
486
return self._do_find_and_load(name)
489
message = f"import of {name} halted; None in sys.modules"
490
raise ModuleNotFoundError(message, name=name)
492
# To handle https://github.com/pytorch/pytorch/issues/57490, where std's
493
# creation of fake submodules via the hacking of sys.modules is not import
496
self.modules["os.path"] = cast(Any, module).path
497
elif name == "typing":
498
self.modules["typing.io"] = cast(Any, module).io
499
self.modules["typing.re"] = cast(Any, module).re
503
def _gcd_import(self, name, package=None, level=0):
504
"""Import and return the module based on its name, the package the call is
505
being made from, and the level adjustment.
507
This function represents the greatest common denominator of functionality
508
between import_module and __import__. This includes setting __package__ if
512
_sanity_check(name, package, level)
514
name = _resolve_name(name, package, level)
516
return self._find_and_load(name)
518
# note: copied from cpython's import code
519
def _handle_fromlist(self, module, fromlist, *, recursive=False):
520
"""Figure out what __import__ should return.
522
The import_ parameter is a callable which takes the name of module to
523
import. It is required to decouple the function from assuming importlib's
524
import implementation is desired.
527
module_name = demangle(module.__name__)
528
# The hell that is fromlist ...
529
# If a package was imported, try to import stuff from fromlist.
530
if hasattr(module, "__path__"):
532
if not isinstance(x, str):
534
where = module_name + ".__all__"
536
where = "``from list''"
538
f"Item in {where} must be str, " f"not {type(x).__name__}"
541
if not recursive and hasattr(module, "__all__"):
542
self._handle_fromlist(module, module.__all__, recursive=True)
543
elif not hasattr(module, x):
544
from_name = f"{module_name}.{x}"
546
self._gcd_import(from_name)
547
except ModuleNotFoundError as exc:
548
# Backwards-compatibility dictates we ignore failed
549
# imports triggered by fromlist for modules that don't
552
exc.name == from_name
553
and self.modules.get(from_name, _NEEDS_LOADING) is not None
559
def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
561
module = self._gcd_import(name)
563
globals_ = globals if globals is not None else {}
564
package = _calc___package__(globals_)
565
module = self._gcd_import(name, package, level)
567
# Return up to the first dot in 'name'. This is complicated by the fact
568
# that 'name' may be relative.
570
return self._gcd_import(name.partition(".")[0])
574
# Figure out where to slice the module's name up to the first dot
576
cut_off = len(name) - len(name.partition(".")[0])
577
# Slice end needs to be positive to alleviate need to special-case
578
# when ``'.' not in name``.
579
module_name = demangle(module.__name__)
580
return self.modules[module_name[: len(module_name) - cut_off]]
582
return self._handle_fromlist(module, fromlist)
584
def _get_package(self, package):
585
"""Take a package name or module object and return the module.
587
If a name, the module is imported. If the passed or imported module
588
object is not a package, raise an exception.
590
if hasattr(package, "__spec__"):
591
if package.__spec__.submodule_search_locations is None:
592
raise TypeError(f"{package.__spec__.name!r} is not a package")
596
module = self.import_module(package)
597
if module.__spec__.submodule_search_locations is None:
598
raise TypeError(f"{package!r} is not a package")
602
def _zipfile_path(self, package, resource=None):
603
package = self._get_package(package)
604
assert package.__loader__ is self
605
name = demangle(package.__name__)
606
if resource is not None:
607
resource = _normalize_path(resource)
608
return f"{name.replace('.', '/')}/{resource}"
610
return f"{name.replace('.', '/')}"
612
def _get_or_create_package(
613
self, atoms: List[str]
614
) -> "Union[_PackageNode, _ExternNode]":
616
for i, atom in enumerate(atoms):
617
node = cur.children.get(atom, None)
619
node = cur.children[atom] = _PackageNode(None)
620
if isinstance(node, _ExternNode):
622
if isinstance(node, _ModuleNode):
623
name = ".".join(atoms[:i])
625
f"inconsistent module structure. module {name} is not a package, but has submodules"
627
assert isinstance(node, _PackageNode)
631
def _add_file(self, filename: str):
632
"""Assembles a Python module out of the given file. Will ignore files in the .data directory.
635
filename (str): the name of the file inside of the package archive to be added
637
*prefix, last = filename.split("/")
638
if len(prefix) > 1 and prefix[0] == ".data":
640
package = self._get_or_create_package(prefix)
641
if isinstance(package, _ExternNode):
643
f"inconsistent module structure. package contains a module file {filename}"
644
f" that is a subpackage of a module marked external."
646
if last == "__init__.py":
647
package.source_file = filename
648
elif last.endswith(".py"):
649
package_name = last[: -len(".py")]
650
package.children[package_name] = _ModuleNode(filename)
652
def _add_extern(self, extern_name: str):
    """Register ``extern_name`` as an extern module in the package node tree.

    Intermediate packages are created as needed. If an ancestor of
    ``extern_name`` is already marked extern, nothing is added.

    Args:
        extern_name (str): Dotted name of a module listed as extern.
    """
    *prefix, last = extern_name.split(".")
    package = self._get_or_create_package(prefix)
    if isinstance(package, _ExternNode):
        return  # the shorter extern covers this extern case
    package.children[last] = _ExternNode()
660
# Sentinel default for ``self.modules.get(...)``: distinguishes "not loaded yet"
# from an entry that is explicitly ``None`` (which halts the import).
_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
# Format template; callers fill the ``{!r}`` slot with the module name.
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"
669
class _PackageNode(_PathNode):
    """Package node in the importer's tree: optional ``__init__.py`` plus children."""

    def __init__(self, source_file: Optional[str]):
        # Archive path of the package's ``__init__.py``; ``None`` when none has been seen.
        self.source_file = source_file
        # Child name (one dotted-name atom) -> package, module, or extern node.
        self.children: Dict[str, _PathNode] = {}
675
class _ModuleNode(_PathNode):
    """Leaf node in the importer's tree for a plain ``.py`` module file."""

    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        # Archive path of the module's ``.py`` file.
        self.source_file = source_file
682
class _ExternNode(_PathNode):
686
# A private global registry of all modules that have been package-imported.
# Values are held weakly, so entries vanish once a module is no longer referenced.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def _patched_getfile(object):
    """Drop-in replacement for :func:`inspect.getfile`.

    For a class whose ``__module__`` is in ``_package_imported_modules``, return
    that module's ``__file__``. Otherwise defer to the original ``inspect.getfile``.
    The parameter keeps the name ``object`` to match ``inspect.getfile``'s signature.
    """
    if inspect.isclass(object):
        # Single lookup: the registry holds weak references, so an entry can die
        # between an ``in`` test and a subscript, raising a spurious KeyError.
        module = _package_imported_modules.get(object.__module__)
        if module is not None:
            return module.__file__
    return _orig_getfile(object)


inspect.getfile = _patched_getfile
704
class _PackageResourceReader:
705
"""Private class used to support PackageImporter.get_resource_reader().
707
Confirms to the importlib.abc.ResourceReader interface. Allowed to access
708
the innards of PackageImporter.
711
def __init__(self, importer, fullname):
712
self.importer = importer
713
self.fullname = fullname
715
def open_resource(self, resource):
716
from io import BytesIO
718
return BytesIO(self.importer.load_binary(self.fullname, resource))
720
def resource_path(self, resource):
721
# The contract for resource_path is that it either returns a concrete
722
# file system path or raises FileNotFoundError.
724
self.importer.zip_reader, DirectoryReader
725
) and self.importer.zip_reader.has_record(
726
os.path.join(self.fullname, resource)
729
self.importer.zip_reader.directory, self.fullname, resource
731
raise FileNotFoundError
733
def is_resource(self, name):
734
path = self.importer._zipfile_path(self.fullname, name)
735
return self.importer.zip_reader.has_record(path)
738
from pathlib import Path
740
filename = self.fullname.replace(".", "/")
742
fullname_path = Path(self.importer._zipfile_path(self.fullname))
743
files = self.importer.zip_reader.get_all_records()
745
for filename in files:
747
relative = Path(filename).relative_to(fullname_path)
750
# If the path of the file (which is relative to the top of the zip
751
# namespace), relative to the package given when the resource
752
# reader was created, has a parent, then it's a name in a
753
# subdirectory and thus we skip it.
754
parent_name = relative.parent.name
755
if len(parent_name) == 0:
757
elif parent_name not in subdirs_seen:
758
subdirs_seen.add(parent_name)