pytorch

Форк
0
/
importer.py 
237 строк · 8.7 Кб
1
import importlib
2
from abc import ABC, abstractmethod
3
from pickle import (  # type: ignore[attr-defined]  # type: ignore[attr-defined]
4
    _getattribute,
5
    _Pickler,
6
    whichmodule as _pickle_whichmodule,
7
)
8
from types import ModuleType
9
from typing import Any, Dict, List, Optional, Tuple
10

11
from ._mangling import demangle, get_mangle_prefix, is_mangled
12

13
# Names exported by `from <this module> import *`.
__all__ = [
    "ObjNotFoundError",
    "ObjMismatchError",
    "Importer",
    "OrderedImporter",
]
14

15

16
class ObjNotFoundError(Exception):
    """Raised when an importer fails to locate an object by looking up its name."""
20

21

22
class ObjMismatchError(Exception):
    """Raised when the object an importer finds under a name is not the
    user-provided object that carries that same name."""
26

27

28
class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will succeed or raise an
    ObjNotFoundError/ObjMismatchError:
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = module
        for attr in obj_name.split("."):  # obj_name may be a dotted __qualname__
            obj2 = getattr(obj2, attr)
        assert obj is obj2
    """

    # Module name -> module, scanned by the default whichmodule() when an object
    # has no __module__. It is not initialised here: subclasses that rely on the
    # default whichmodule() must provide it. Entries whose value is None are
    # skipped by that scan.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """
        pass

    @staticmethod
    def _resolve_qualname(root: Any, qualname: str) -> Any:
        """Resolve the dotted `qualname` one attribute at a time, starting at `root`.

        This reproduces the historical ``pickle._getattribute(obj, name)``
        lookup without depending on that private helper. Its signature has
        changed in newer Python versions (it takes a pre-split path and returns
        only the object), which broke the tuple-unpacking callers here.

        Raises:
            AttributeError: if any path component is missing, or if the path
                passes through ``<locals>`` (function-local objects cannot be
                reached by name).
        """
        current = root
        for part in qualname.split("."):
            if part == "<locals>":
                raise AttributeError(
                    f"Can't get local attribute {qualname!r} on {root!r}"
                )
            try:
                current = getattr(current, part)
            except AttributeError:
                raise AttributeError(
                    f"Can't get attribute {qualname!r} on {root!r}"
                ) from None
        return current

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.
        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            `attr_name` may be dotted (a __qualname__ such as 'Outer.Inner'); resolve it one
            attribute at a time:
                mod = importer.import_module(parent_module_name)
                obj = mod
                for attr in attr_name.split("."):
                    obj = getattr(obj, attr)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
            AttributeError: `name` was not given and `obj` provides no string
                __reduce__ result, no __qualname__ and no __name__.
        """
        # Use `is not None` rather than truthiness. bool(obj) can raise for some
        # objects. Pickle also honours a string __reduce__ result whether or not
        # the object is falsy.
        if (
            name is None
            and obj is not None
            and _Pickler.dispatch.get(type(obj)) is None
        ):
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort only: if __reduce__ fails, fall back to the
                    # __qualname__/__name__ lookup below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object.
        try:
            module = self.import_module(module_name)
            obj2 = self._resolve_qualname(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        def describe_origin(candidate: Any, attr_name: str) -> Tuple[str, str, str]:
            # Returns (module name, human-readable location, importer label) for
            # `candidate`. A mangled module name is attributed to the package
            # identified by its mangle prefix. Anything else is attributed to
            # the regular Python environment / 'sys_importer'.
            origin_module = self.whichmodule(candidate, attr_name)
            if is_mangled(origin_module):
                prefix = get_mangle_prefix(origin_module)
                return origin_module, prefix, f"the importer for {prefix}"
            return (
                origin_module,
                "the current Python environment",
                "'sys_importer'",
            )

        obj_module_name, obj_location, obj_importer_name = describe_origin(obj, name)
        obj2_module_name, obj2_location, obj2_importer_name = describe_origin(
            obj2, name
        )
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules.
        Instead, this scans ``self.modules``. Resolution order:
          1. ``obj.__module__`` if present and not None;
          2. the first entry of ``self.modules`` (excluding '__main__',
             '__mp_main__' and None values) where `name` resolves to `obj`;
          3. '__main__' as the fallback.
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for candidate_name, module in self.modules.copy().items():
            if (
                candidate_name == "__main__"
                or candidate_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if self._resolve_qualname(module, name) is obj:
                    return candidate_name
            except AttributeError:
                pass

        return "__main__"
168

169

170
class _SysImporter(Importer):
    """Importer backed by the running interpreter itself.

    Modules are loaded with importlib.import_module, and module ownership is
    resolved with pickle's own whichmodule.
    """

    def import_module(self, module_name: str):
        # Same contract as Importer.import_module: delegate straight to importlib.
        imported = importlib.import_module(module_name)
        return imported

    def whichmodule(self, obj: Any, name: str) -> str:
        # Use pickle's search (which consults sys.modules) instead of the
        # self.modules scan implemented by Importer.whichmodule.
        owner = _pickle_whichmodule(obj, name)
        return owner
178

179

180
# Shared importer instance that delegates to the interpreter's own import
# machinery. Error messages in this module refer to it as 'sys_importer'.
sys_importer: Importer = _SysImporter()
181

182

183
class OrderedImporter(Importer):
    """Compound importer that consults a sequence of importers in order.

    Each lookup is delegated to the wrapped importers one after another; the
    first importer that produces a usable answer wins.
    """

    def __init__(self, *args):
        # Caller order is preserved: earlier importers take precedence.
        self._importers: List[Importer] = [*args]

    def _is_torchpackage_dummy(self, module):
        """Return True iff `module` is an empty placeholder package from a torch.package.

        Interning `a.b` without ever using `a` leaves `a` as an empty module
        with no source. Accepting that placeholder as a successful import would
        end the search too early when re-packaging an object that later gains
        a real dependency on `a`, so import_module skips it and keeps looking.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        # Short-circuits in the same order as the checks it encodes: a
        # torch.package module, that is a package (has __path__), with no
        # __file__ or a __file__ of None.
        return (
            bool(getattr(module, "__torch_package__", False))
            and hasattr(module, "__path__")
            and getattr(module, "__file__", None) is None
        )

    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from the first importer that can provide a real module.

        Placeholder packages (see _is_torchpackage_dummy) are skipped.

        Raises:
            TypeError: if a wrapped object is not an Importer (checked lazily,
                as each importer is reached).
            ModuleNotFoundError: the most recent one raised by a wrapped
                importer, or a fresh one if no importer raised (e.g. every
                importer returned a placeholder).
        """
        pending_error = None
        for candidate in self._importers:
            if not isinstance(candidate, Importer):
                raise TypeError(
                    f"{candidate} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                found = candidate.import_module(module_name)
                if not self._is_torchpackage_dummy(found):
                    return found
                # A placeholder package: fall through to the next importer.
            except ModuleNotFoundError as exc:
                pending_error = exc

        if pending_error is None:
            raise ModuleNotFoundError(module_name)
        raise pending_error

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first answer other than '__main__' among the wrapped
        importers' whichmodule results, or '__main__' if none has one."""
        # Lazy generator: stops querying importers as soon as one answers.
        answers = (importer.whichmodule(obj, name) for importer in self._importers)
        return next((found for found in answers if found != "__main__"), "__main__")
238

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.