pytorch

Форк
0
/
_config_module.py 
369 строк · 11.9 Кб
1
import contextlib
2

3
import copy
4
import hashlib
5
import inspect
6
import io
7
import pickle
8
import tokenize
9
import unittest
10
import warnings
11
from types import FunctionType, ModuleType
12
from typing import Any, Dict, Optional, Set, Union
13
from unittest import mock
14

15
# Value types that may be stored in (and saved/loaded from) a config entry.
# A module-level value is treated as a plain config setting when it is an
# instance of one of these.
CONFIG_TYPES = (
    int,
    float,
    bool,
    type(None),
    str,
    list,
    set,
    tuple,
    dict,
)
17

18

19
def install_config_module(module):
    """
    Converts a module-level config into a `ConfigModule()`.

    Every module-level value whose type is in ``CONFIG_TYPES`` becomes a config
    entry and is deleted from the module's ``__dict__``, so later attribute
    access falls through to ``ConfigModule.__getattr__``.  A nested
    ``class Foo:`` body becomes a subconfig: its entries are stored under the
    dotted key ``"Foo.<key>"`` and reached through a ``SubConfigProxy``
    (``module.Foo.<key>``).  Subconfigs may be nested more than one level deep.

    Dunder names, submodules, functions and ``typing`` objects are skipped; any
    other value raises AssertionError.

    See _config_typing.pyi for instructions on how to get the converted module to typecheck.
    """

    class ConfigModuleInstance(ConfigModule):
        # Attributes written onto the module object itself rather than into _config.
        _bypass_keys = {"_is_dirty", "_hash_digest"}

    def visit(source, dest, prefix):
        """Walk the module structure and move everything to module._config"""
        for key, value in list(source.__dict__.items()):
            if (
                key.startswith("__")
                or isinstance(value, (ModuleType, FunctionType))
                or (hasattr(value, "__module__") and value.__module__ == "typing")
            ):
                continue

            name = f"{prefix}{key}"
            if isinstance(value, CONFIG_TYPES):
                config[name] = value
                default[name] = value
                if dest is module:
                    delattr(module, key)
            elif isinstance(value, type):
                assert value.__module__ == module.__name__
                # a subconfig with `class Blah:` syntax
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                if dest is module:
                    setattr(dest, key, proxy)
                else:
                    # `dest` is a SubConfigProxy: its __setattr__ would forward
                    # to the (not yet converted) module and create an unreachable
                    # dotted attribute.  Store the child proxy directly on the
                    # parent proxy so `parent.child` is found by normal lookup.
                    dest.__dict__[key] = proxy
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    config: Dict[str, Any] = dict()
    default: Dict[str, Any] = dict()

    compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

    visit(module, module, "")
    module._config = config
    module._default = default
    module._allowed_keys = set(config.keys())
    module._compile_ignored_keys = compile_ignored_keys
    module.__class__ = ConfigModuleInstance
    # Set after the class switch; both names are in _bypass_keys.
    module._is_dirty = True
    module._hash_digest = None
67

68

69
COMPILE_IGNORED_MARKER = "@compile_ignored"


# Gets all the keys (i.e. assignments) with a @compile_ignored comment
def get_assignments_with_compile_ignored_comments(module):
    """
    Return the config keys whose assignment sits on the line directly after a
    comment containing ``@compile_ignored`` in *module*'s source.

    Assignments inside ``class Foo:`` subconfigs are returned with their dotted
    prefix (``"Foo.key"``, ``"Foo.Bar.key"``), i.e. in the same form as the
    keys ``install_config_module`` stores in ``_config``.  For annotated
    assignments (``foo: Bar = ...``) the first name (``foo``) is used.

    Raises AssertionError when a marker comment is not consumed by an
    assignment on the very next line.
    """
    source_code = inspect.getsource(module)
    assignments = set()

    # Tokenize the source code to retrieve comments
    tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline)
    current_comment = "", -1
    prev_name = ""
    prev_name_col = 0
    # Enclosing `class Foo:` subconfigs as (name, column of the `class` keyword).
    # Entries are dropped lazily once a statement starts at or left of that
    # column, which means the class body has ended.
    class_stack = []
    class_keyword_col = None  # set while waiting for the name after `class`

    def leave_classes_from(col):
        # Drop subconfigs whose `class` keyword is at or right of `col`.
        while class_stack and class_stack[-1][1] >= col:
            class_stack.pop()

    for token in tokens:
        if token.type == tokenize.COMMENT:
            prev_name = ""
            maybe_current = token.string.strip()
            if COMPILE_IGNORED_MARKER in maybe_current:
                assert current_comment == (
                    "",
                    -1,
                ), f"unconsumed {COMPILE_IGNORED_MARKER}"
                current_comment = maybe_current, token.start[0]
        elif token.type == tokenize.NAME:
            if class_keyword_col is not None:
                # This NAME is the class being defined: it becomes a subconfig.
                leave_classes_from(class_keyword_col)
                class_stack.append((token.string, class_keyword_col))
                class_keyword_col = None
            elif token.string == "class":
                class_keyword_col = token.start[1]
            # Only accept the first name token, to handle if you have
            # something like foo: Bar = ...
            if not prev_name:
                prev_name = token.string
                prev_name_col = token.start[1]
        elif token.type == tokenize.OP and token.string == "=":
            # Check if the current assignment follows a comment
            # with COMPILE_IGNORED_MARKER
            if (
                COMPILE_IGNORED_MARKER in current_comment[0]
                and current_comment[1] == token.start[0] - 1
            ):
                leave_classes_from(prev_name_col)
                prefix = "".join(f"{cls_name}." for cls_name, _ in class_stack)
                assignments.add(prefix + prev_name)
                current_comment = "", -1  # reset
            prev_name = ""
    assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
    return assignments
109

110

111
class ConfigModule(ModuleType):
    # NOTE: This should be kept in sync with _config_typing.pyi.

    # The default values of the configuration settings.  This can be used to
    # determine if the config has been changed or not.
    _default: Dict[str, Any]
    # The actual configuration settings.  E.g., torch._dynamo.config.debug
    # would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
    # maps as "triton.cudagraphs"
    _config: Dict[str, Any]
    # Keys that may be assigned through attribute syntax (`config.key = value`).
    _allowed_keys: Set[str]
    # Attribute names stored on the module object itself instead of in _config.
    _bypass_keys: Set[str]
    # Keys left out of get_hash(); changing them never invalidates the hash.
    _compile_ignored_keys: Set[str]
    # True when _hash_digest may be stale and get_hash() must recompute it.
    _is_dirty: bool
    # Cached result of get_hash(), or None if not computed yet.
    _hash_digest: Optional[bytes]

    def __init__(self):
        raise NotImplementedError(
            f"use {__name__}.install_config_module(sys.modules[__name__])"
        )

    def __setattr__(self, name, value):
        """Write a config entry; bypass keys go straight onto the module object.

        Raises AttributeError for names that are not known config keys.
        """
        if name in self._bypass_keys:
            super().__setattr__(name, value)
        elif name not in self._allowed_keys:
            raise AttributeError(f"{self.__name__}.{name} does not exist")
        else:
            self._config[name] = value
            # Invalidate the cached hash unless this key is excluded from it.
            if name not in self._compile_ignored_keys:
                self._is_dirty = True

    def __getattr__(self, name):
        """Read a config entry (only reached when normal attribute lookup fails)."""
        try:
            return self._config[name]
        except KeyError as e:
            # make hasattr() work properly
            raise AttributeError(f"{self.__name__}.{name} does not exist") from e

    def __delattr__(self, name):
        # must support delete because unittest.mock.patch deletes
        # then recreate things
        del self._config[name]
        if name not in self._compile_ignored_keys:
            self._is_dirty = True

    def save_config(self) -> bytes:
        """Convert config to a pickled blob

        Keys listed in the `_save_config_ignore` entry (if present) are omitted;
        listing a key that does not exist raises KeyError.
        """
        config = dict(self._config)
        for key in config.get("_save_config_ignore", ()):
            config.pop(key)
        return pickle.dumps(config, protocol=2)

    def codegen_config(self) -> str:
        """Convert config to Python statements that replicate current config.
        This does NOT include config settings that are at default values.

        NOTE: a value is considered "at default" when it compares equal to the
        entry in `_default`; a mutable default changed in place may share its
        object with `_default` and would then be omitted.
        """
        lines = []
        mod = self.__name__
        ignored = self._config.get("_save_config_ignore", ())
        for k, v in self._config.items():
            if k in ignored:
                continue
            if v == self._default[k]:
                continue
            lines.append(f"{mod}.{k} = {v!r}")
        return "\n".join(lines)

    def get_hash(self) -> bytes:
        """Hashes the configs that are not compile_ignored

        The md5 digest is cached and only recomputed when `_is_dirty` is set
        or no digest has been computed yet.
        """
        if self._is_dirty or self._hash_digest is None:
            dict_to_hash = {
                k: v
                for k, v in self._config.items()
                if k not in self._compile_ignored_keys
            }
            string_to_hash = repr(sorted(dict_to_hash.items()))
            self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest()
            self._is_dirty = False
        return self._hash_digest

    def to_dict(self) -> Dict[str, Any]:
        """Deprecated alias of shallow_copy_dict(); emits a DeprecationWarning."""
        warnings.warn(
            "config.to_dict() has been deprecated. It may no longer change the underlying config."
            " use config.shallow_copy_dict() or config.get_config_copy() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.shallow_copy_dict()

    def shallow_copy_dict(self) -> Dict[str, Any]:
        """Return a new dict with the same (not copied) config values."""
        return {**self._config}

    def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None:
        """Restore from a prior call to save_config() or shallow_copy_dict()

        NOTE: non-dict input is passed to `pickle.loads`; only load blobs that
        come from a trusted source.
        """
        if not isinstance(maybe_pickled_config, dict):
            config = pickle.loads(maybe_pickled_config)
        else:
            config = maybe_pickled_config
        self._config.update(config)
        # The bulk update bypasses __setattr__, so invalidate the cached hash here.
        self._is_dirty = True

    def get_config_copy(self) -> Dict[str, Any]:
        """Return a deep copy of all config entries."""
        return copy.deepcopy(self._config)

    def patch(
        self,
        arg1: Optional[Union[str, Dict[str, Any]]] = None,
        arg2: Any = None,
        **kwargs,
    ):
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2)
            @config.patch({"name1": val1, "name2": val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...

        `patch("name", None)` is not supported (a None second argument is
        treated as absent); use `patch({"name": None})` or `patch(name=None)`.
        Patching a key that does not exist raises KeyError on entry.
        """
        changes: Dict[str, Any]
        if arg1 is not None:
            if arg2 is not None:
                assert isinstance(arg1, str)
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                assert isinstance(arg1, dict)
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior: Dict[str, Any] = {}
        config = self
        dirty = False

        class ConfigPatch(ContextDecorator):
            def __enter__(self):
                nonlocal dirty
                assert not prior
                try:
                    for key in changes.keys():
                        # KeyError on invalid entry
                        prior[key] = config._config[key]
                except KeyError:
                    # Leave no partial state behind so a retry still reports the
                    # KeyError rather than tripping `assert not prior`.
                    prior.clear()
                    raise
                # Invalidate the hash if ANY patched key takes part in it.
                dirty = any(key not in config._compile_ignored_keys for key in changes)
                config._config.update(changes)
                # Only ever set the flag: clearing it could discard an
                # invalidation that happened before this patch.
                if dirty:
                    config._is_dirty = True

            def __exit__(self, exc_type, exc_val, exc_tb):
                config._config.update(prior)
                prior.clear()
                if dirty:
                    config._is_dirty = True

        return ConfigPatch()

    def _make_closure_patcher(self, **changes):
        """
        A lower-overhead version of patch() for things on the critical path.

        Usage:

            # do this off the critical path
            change_fn = config._make_closure_patcher(foo=True)

            ...

            revert = change_fn()
            try:
                ...
            finally:
                revert()

        Unknown keys raise KeyError when the change function is called.
        """
        config = self._config
        module = self
        # Decided once, off the critical path: only keys that take part in
        # get_hash() need to invalidate the cached digest.
        dirty = any(key not in self._compile_ignored_keys for key in changes)

        def change():
            prior = {k: config[k] for k in changes}
            config.update(changes)
            if dirty:
                module._is_dirty = True

            def revert():
                config.update(prior)
                if dirty:
                    module._is_dirty = True

            return revert

        return change
299

300

301
class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`

    Decorating a TestCase subclass returns a subclass of it that enters the
    context in setUpClass and exits it in tearDownClass (or immediately if the
    original setUpClass raises).  Subclasses must implement __enter__/__exit__.
    """

    def __enter__(self):
        raise NotImplementedError("NYI")

    def __exit__(self, exc_type, exc_val, exc_tb):
        raise NotImplementedError("NYI")

    def __call__(self, func):
        decorating_test_case = isinstance(func, type) and issubclass(
            func, unittest.TestCase
        )
        if not decorating_test_case:
            # Plain callables get the stock contextlib behaviour.
            return super().__call__(func)

        manager = self

        class _TestCase(func):  # type: ignore[valid-type, misc]
            @classmethod
            def setUpClass(cls):
                manager.__enter__()
                try:
                    super().setUpClass()
                except Exception:
                    # Class setup failed, so tearDownClass will not run: undo now.
                    manager.__exit__(None, None, None)
                    raise

            @classmethod
            def tearDownClass(cls):
                try:
                    super().tearDownClass()
                finally:
                    manager.__exit__(None, None, None)

        # Make the wrapper indistinguishable by name from the decorated class.
        for attr in ("__name__", "__qualname__", "__module__"):
            setattr(_TestCase, attr, getattr(func, attr))

        return _TestCase
340

341

342
class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]

    `config` is the owning config module and `prefix` the dotted prefix
    (e.g. "triton."); attribute get/set/delete on the proxy is forwarded to
    `config` with `prefix` prepended to the name.
    """

    def __init__(self, config, prefix):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name, value):
        return self._config.__setattr__(self._prefix + name, value)

    def __getattr__(self, name):
        if name in ("_config", "_prefix"):
            # Only reachable when __init__ never ran on this instance (e.g. it
            # was created via __new__ by copy/pickle).  Without this guard the
            # `self._config` lookup below would re-enter __getattr__ forever.
            raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")
        return self._config.__getattr__(self._prefix + name)

    def __delattr__(self, name):
        return self._config.__delattr__(self._prefix + name)
361

362

363
def patch_object(obj, name, value):
    """
    Workaround `mock.patch.object` issue with ConfigModule

    ConfigModule instances are patched through their own `patch()` method;
    every other object is handed to `mock.patch.object` unchanged.
    """
    if not isinstance(obj, ConfigModule):
        return mock.patch.object(obj, name, value)
    return obj.patch(name, value)
370

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.