# package_importer.py — torch.package importer (mirrored from the PyTorch repository).
import builtins
import importlib
import importlib.machinery
import importlib.util
import inspect
import io
import linecache
import os
import types
from contextlib import contextmanager
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer

__all__ = ["PackageImporter"]

# This is a list of imports that are implicitly allowed even if they haven't
# been marked as extern. This is to work around the fact that Torch implicitly
# depends on numpy and package can't track it.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # FX GraphModule might depend on builtins module and users usually
    # don't extern builtins. Here we import it here by default.
    "builtins",
]


class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
    """

    """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
    local to this importer.
    """

    modules: Dict[str, types.ModuleType]

    def __init__(
        self,
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
        allowed by ``module_allowed``

        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
                a string, or an ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module
                should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
                does not support. Defaults to allowing anything.

        Raises:
            ImportError: If the package will use a disallowed module.
        """
        torch._C._log_api_usage_once("torch.package.PackageImporter")

        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (os.PathLike, str)):
            self.filename = os.fspath(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        torch._C._log_api_usage_metadata(
            "torch.package.PackageImporter.metadata",
            {
                "serialization_id": self.zip_reader.serialization_id(),
                "file_name": self.filename,
            },
        )

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
        """

        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.
        """
        pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype)._typed_storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor._typed_storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                return torch.storage.TypedStorage(
                    wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
                )
            elif typename == "reduce_package":
                # to fix BC breaking change, objects on this load path
                # will be loaded multiple times erroneously
                if len(data) == 2:
                    func, args = data
                    return func(self, *args)
                reduce_id, func, args = data
                if reduce_id not in loaded_reduces:
                    loaded_reduces[reduce_id] = func(self, *args)
                return loaded_reduces[reduce_id]
            else:
                # Previously this branch evaluated the message as a bare
                # expression and silently returned None; raise so corrupt or
                # unsupported pickles fail loudly instead.
                raise RuntimeError(
                    f"Unknown typename for persistent_load, expected 'storage' or "
                    f"'reduce_package' but got '{typename}'"
                )

        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = persistent_load  # type: ignore[assignment]

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()

        # TODO from zdevito:
        #   This stateful weird function will need to be removed in our efforts
        #   to unify the format. It has a race condition if multiple python
        #   threads try to read independent files
        torch._utils._validate_loaded_sparse_tensors()

        return result

    def id(self):
        """
        Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
                for the names of the files to be included in the zipfile representation. This can also be
                a glob-style pattern, as described in :meth:`PackageExporter.mock`

            exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.

        Returns:
            :class:`Directory`
        """
        return _create_directory_from_file_list(
            self.filename, self.zip_reader.get_all_records(), include, exclude
        )

    def python_version(self):
        """Returns the version of python that was used to create this package.

        Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
        file later on.

        Returns:
            :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
        """
        python_version_path = ".data/python_version"
        return (
            self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
            if self.zip_reader.has_record(python_version_path)
            else None
        )

    def _read_extern(self):
        # One extern module name per line in the archive's metadata record.
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        """Create, register and (if it has source) execute a packaged module."""
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)
        self.modules[name] = module
        module.__name__ = self._mangler.mangle(name)
        ns = module.__dict__
        ns["__spec__"] = spec
        ns["__loader__"] = self
        ns["__file__"] = mangled_filename
        ns["__cached__"] = None
        ns["__builtins__"] = self.patched_builtins
        ns["__torch_package__"] = True

        # Add this module to our private global registry. It should be unique due to mangling.
        assert module.__name__ not in _package_imported_modules
        _package_imported_modules[module.__name__] = module

        # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
        # access sys.modules
        self._install_on_parent(parent, name, module)

        if filename is not None:
            assert mangled_filename is not None
            # pre-emptively install the source in `linecache` so that stack traces,
            # `inspect`, etc. work.
            assert filename not in linecache.cache  # type: ignore[attr-defined]
            linecache.lazycache(mangled_filename, ns)

            code = self._compile_source(filename, mangled_filename)
            exec(code, ns)

        return module

    def _load_module(self, name: str, parent: str):
        """Walk the package tree for ``name`` and materialize the module."""
        cur: _PathNode = self.root
        for atom in name.split("."):
            if not isinstance(cur, _PackageNode) or atom not in cur.children:
                if name in IMPLICIT_IMPORT_ALLOWLIST:
                    module = self.modules[name] = importlib.import_module(name)
                    return module
                raise ModuleNotFoundError(
                    f'No module named "{name}" in self-contained archive "{self.filename}"'
                    f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
                    name=name,
                )
            cur = cur.children[atom]
            if isinstance(cur, _ExternNode):
                module = self.modules[name] = importlib.import_module(name)
                return module
        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]

    def _compile_source(self, fullpath: str, mangled_filename: str):
        source = self.zip_reader.get_record(fullpath)
        source = _normalize_line_endings(source)
        return compile(source, mangled_filename, "exec", dont_inherit=True)

    # note: named `get_source` so that linecache can find the source
    # when this is the __loader__ of a module.
    def get_source(self, module_name) -> str:
        # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
        module = self.import_module(demangle(module_name))
        return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")

    # note: named `get_resource_reader` so that importlib.resources can find it.
    # This is otherwise considered an internal method.
    def get_resource_reader(self, fullname):
        try:
            package = self._get_package(fullname)
        except ImportError:
            return None
        if package.__loader__ is not self:
            return None
        return _PackageResourceReader(self, fullname)

    def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
        if not parent:
            return
        # Set the module as an attribute on its parent.
        parent_module = self.modules[parent]
        if parent_module.__loader__ is self:
            setattr(parent_module, name.rpartition(".")[2], module)

    # note: copied from cpython's import code, with call to create module replaced with _make_module
    def _do_find_and_load(self, name):
        path = None
        parent = name.rpartition(".")[0]
        module_name_no_parent = name.rpartition(".")[-1]
        if parent:
            if parent not in self.modules:
                self._gcd_import(parent)
            # Crazy side-effects!
            if name in self.modules:
                return self.modules[name]
            parent_module = self.modules[parent]

            try:
                path = parent_module.__path__  # type: ignore[attr-defined]

            except AttributeError:
                # when we attempt to import a package only containing pybinded files,
                # the parent directory isn't always a package as defined by python,
                # so we search if the package is actually there or not before calling the error.
                if isinstance(
                    parent_module.__loader__,
                    importlib.machinery.ExtensionFileLoader,
                ):
                    if name not in self.extern_modules:
                        # NOTE: the original message was built with a backslash
                        # line continuation, which embedded the indentation and a
                        # stray "}" into the user-visible error text; use implicit
                        # string concatenation instead.
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a c extension module which was not externed. "
                            "C extension modules need to be externed by the PackageExporter "
                            "in order to be used as we do not support interning them."
                        ).format(name, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                    if not isinstance(
                        parent_module.__dict__.get(module_name_no_parent),
                        types.ModuleType,
                    ):
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a c extension package which does not contain {!r}."
                        ).format(name, parent, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                else:
                    msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
                    raise ModuleNotFoundError(msg, name=name) from None

        module = self._load_module(name, parent)

        self._install_on_parent(parent, name, module)

        return module

    # note: copied from cpython's import code
    def _find_and_load(self, name):
        module = self.modules.get(name, _NEEDS_LOADING)
        if module is _NEEDS_LOADING:
            return self._do_find_and_load(name)

        if module is None:
            message = f"import of {name} halted; None in sys.modules"
            raise ModuleNotFoundError(message, name=name)

        # To handle https://github.com/pytorch/pytorch/issues/57490, where the
        # stdlib's creation of fake submodules via the hacking of sys.modules
        # is not import friendly
        if name == "os":
            self.modules["os.path"] = cast(Any, module).path
        elif name == "typing":
            self.modules["typing.io"] = cast(Any, module).io
            self.modules["typing.re"] = cast(Any, module).re

        return module

    def _gcd_import(self, name, package=None, level=0):
        """Import and return the module based on its name, the package the call is
        being made from, and the level adjustment.

        This function represents the greatest common denominator of functionality
        between import_module and __import__. This includes setting __package__ if
        the loader did not.

        """
        _sanity_check(name, package, level)
        if level > 0:
            name = _resolve_name(name, package, level)

        return self._find_and_load(name)

    # note: copied from cpython's import code
    def _handle_fromlist(self, module, fromlist, *, recursive=False):
        """Figure out what __import__ should return.

        The import_ parameter is a callable which takes the name of module to
        import. It is required to decouple the function from assuming importlib's
        import implementation is desired.

        """
        module_name = demangle(module.__name__)
        # The hell that is fromlist ...
        # If a package was imported, try to import stuff from fromlist.
        if hasattr(module, "__path__"):
            for x in fromlist:
                if not isinstance(x, str):
                    if recursive:
                        where = module_name + ".__all__"
                    else:
                        where = "``from list''"
                    raise TypeError(
                        f"Item in {where} must be str, " f"not {type(x).__name__}"
                    )
                elif x == "*":
                    if not recursive and hasattr(module, "__all__"):
                        self._handle_fromlist(module, module.__all__, recursive=True)
                elif not hasattr(module, x):
                    from_name = f"{module_name}.{x}"
                    try:
                        self._gcd_import(from_name)
                    except ModuleNotFoundError as exc:
                        # Backwards-compatibility dictates we ignore failed
                        # imports triggered by fromlist for modules that don't
                        # exist.
                        if (
                            exc.name == from_name
                            and self.modules.get(from_name, _NEEDS_LOADING) is not None
                        ):
                            continue
                        raise
        return module

    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)
        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported.  If the passed or imported module
        object is not a package, raise an exception.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package.__spec__.name!r} is not a package")
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package!r} is not a package")
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        """Map a package (plus optional resource name) to its path in the zip."""
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return f"{name.replace('.', '/')}"

    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                name = ".".join(atoms[:i])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file. Will ignore files in the .data directory.

        Args:
            filename (str): the name of the file inside of the package archive to be added
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            # Restore the offending filename in the message (the f-string had
            # lost its placeholder).
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

    def _add_extern(self, extern_name: str):
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()


# Sentinel used by PackageImporter._find_and_load to distinguish "module not
# yet loaded" from a legitimate ``None`` entry in ``self.modules``.
_NEEDS_LOADING = object()
# Shared prefix/template for ModuleNotFoundError messages (mirrors CPython's).
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    """Base class for nodes in the importer's package/module/extern tree."""


class _PackageNode(_PathNode):
    """Tree node for a package; ``source_file`` is its ``__init__.py`` record
    (or None for a namespace-style package) and ``children`` maps each child
    atom to its node."""

    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    """Leaf tree node for a plain (non-package) module backed by a .py record."""

    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    """Tree node marking a module that is resolved by the host's import system."""


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def _patched_getfile(object):
    """Like ``inspect.getfile`` but aware of package-imported modules.

    For a class whose (mangled) ``__module__`` is in our private registry,
    return that module's ``__file__``; otherwise defer to the original
    ``inspect.getfile``.
    """
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = _patched_getfile


class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        """Return a binary file-like object for ``resource`` within the package."""
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if isinstance(
            self.importer.zip_reader, DirectoryReader
        ) and self.importer.zip_reader.has_record(
            os.path.join(self.fullname, resource)
        ):
            return os.path.join(
                self.importer.zip_reader.directory, self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        """Return True if ``name`` exists as a record inside this package."""
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        """Yield the names of files and immediate subdirectories of the package."""
        from pathlib import Path

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        # (A dead `filename = self.fullname.replace(".", "/")` local, which was
        # immediately shadowed by the loop variable, has been removed.)
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                continue
            # If the path of the file (which is relative to the top of the zip
            # namespace), relative to the package given when the resource
            # reader was created, has a parent, then it's a name in a
            # subdirectory and thus we skip it.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name

# NOTE: removed website cookie-banner boilerplate that was accidentally
# captured when this file was scraped from a code-hosting web page.