import contextlib
import copy
import hashlib
import inspect
import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Dict, Optional, Set, Union
from unittest import mock

# Types saved/loaded in configs: only module-level values of these types are
# turned into config entries by install_config_module().
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
def install_config_module(module):
    """
    Converts a module-level config into a `ConfigModule()`.

    Every module attribute whose name does not start with ``__`` and whose
    value is an instance of ``CONFIG_TYPES`` is recorded in
    ``module._config`` (current value) and ``module._default`` (initial
    value), and removed from the module ``__dict__`` so that attribute access
    falls through to ``ConfigModule.__getattr__``.  A nested ``class Foo:``
    defined in the same module becomes a subconfig whose entries use dotted
    keys (``"Foo.bar"``) and is exposed through a ``SubConfigProxy``.
    Imported modules, functions and ``typing`` constructs are skipped.
    Finally the module is re-classed to a ``ConfigModule`` subclass.

    See _config_typing.pyi for instructions on how to get the converted module to typecheck.

    Raises:
        AssertionError: if a value has an unsupported type, or a nested class
            was defined in a different module.
    """

    class ConfigModuleInstance(ConfigModule):
        # Attributes written directly on the module object rather than into
        # _config (see ConfigModule.__setattr__).
        _bypass_keys = set({"_is_dirty", "_hash_digest"})

    def visit(source, dest, prefix):
        """Walk the module structure and move everything to module._config"""
        # Iterate over a snapshot: entries are deleted from the module below.
        for key, value in list(source.__dict__.items()):
            if (
                key.startswith("__")
                or isinstance(value, (ModuleType, FunctionType))
                or (hasattr(value, "__module__") and value.__module__ == "typing")
            ):
                # not a config entry (dunder, import, helper, type annotation)
                continue

            name = f"{prefix}{key}"
            if isinstance(value, CONFIG_TYPES):
                config[name] = value
                default[name] = value
                if dest is module:
                    # Remove the plain attribute so lookups reach
                    # ConfigModule.__getattr__ and read from _config.
                    delattr(module, key)
            elif isinstance(value, type):
                assert value.__module__ == module.__name__
                # a subconfig with `class Blah:` syntax
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                setattr(dest, key, proxy)
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    config: Dict[str, Any] = dict()
    default: Dict[str, Any] = dict()

    compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

    visit(module, module, "")
    module._config = config
    module._default = default
    module._allowed_keys = set(config.keys())
    module._compile_ignored_keys = compile_ignored_keys
    module.__class__ = ConfigModuleInstance
    # No hash has been computed yet.
    module._is_dirty = True
    module._hash_digest = None
# Marker which, in a comment on the line directly above a config assignment,
# excludes that config from ConfigModule.get_hash().
COMPILE_IGNORED_MARKER = "@compile_ignored"


# Gets all the keys (i.e. assignments) with a @compile_ignored comment
def get_assignments_with_compile_ignored_comments(module):
    """
    Return the set of names assigned on the line directly below a standalone
    comment containing ``COMPILE_IGNORED_MARKER`` in ``module``'s source.

    Only the first NAME token of the assignment is recorded, so an annotated
    assignment such as ``foo: Bar = ...`` yields ``"foo"``.  A marker placed
    as a trailing comment (``x = 1  # @compile_ignored``) is not counted.

    Raises:
        AssertionError: if a marker comment is still unconsumed when another
            marker comment or the end of the source is reached.
    """
    source_code = inspect.getsource(module)
    assignments = set()

    # Tokenize the source code to retrieve comments
    tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline)
    # (marker comment text, line number) awaiting an assignment, or ("", -1)
    current_comment = "", -1
    prev_name = ""

    for token in tokens:
        if token.type == tokenize.COMMENT:
            prev_name = ""
            maybe_current = token.string.strip()
            # Only a comment on its own line annotates the assignment below it;
            # a trailing comment would otherwise be attributed to whatever is
            # assigned on the *next* line.
            standalone = not token.line[: token.start[1]].strip()
            if COMPILE_IGNORED_MARKER in maybe_current and standalone:
                if current_comment != ("", -1):
                    raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
                current_comment = maybe_current, token.start[0]
        elif token.type == tokenize.NAME:
            # Only accept the first name token, to handle if you have
            # something like foo: Bar = ...
            if not prev_name:
                prev_name = token.string
        elif token.type == tokenize.OP and token.string == "=":
            # Check if the current assignment follows a comment
            # with COMPILE_IGNORED_MARKER
            if (
                COMPILE_IGNORED_MARKER in current_comment[0]
                and current_comment[1] == token.start[0] - 1
            ):
                assignments.add(prev_name)
                current_comment = "", -1  # reset
            prev_name = ""
    # Explicit raise (not assert) so the check survives `python -O`.
    if current_comment != ("", -1):
        raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
    return assignments
class ConfigModule(ModuleType):
    """A module whose config values live in ``_config`` instead of ``__dict__``.

    Never constructed directly (``__init__`` raises): install_config_module()
    re-classes an already-imported module to a subclass of this type.
    """

    # The default values of the configuration settings.  This can be used to
    # determine if the config has been changed or not.
    _default: Dict[str, Any]
    # The actual configuration settings.  Subconfig entries use dotted keys,
    # e.g. "triton.cudagraphs".
    _config: Dict[str, Any]
    # Keys that may be assigned through attribute access.
    _allowed_keys: Set[str]
    # Attribute names stored directly on the module object, not in _config.
    _bypass_keys: Set[str]
    # Keys excluded from get_hash().
    _compile_ignored_keys: Set[str]
    # True when _hash_digest may be stale and must be recomputed.
    _is_dirty: bool
    # Cached result of get_hash().
    _hash_digest: Optional[bytes]

    def __init__(self):
        raise NotImplementedError(
            f"use {__name__}.install_config_module(sys.modules[__name__])"
        )

    def _mark_dirty(self, keys) -> None:
        """Invalidate the cached hash if any of ``keys`` participates in it."""
        if any(k not in self._compile_ignored_keys for k in keys):
            # Write straight to the module dict; does not depend on
            # "_is_dirty" being listed in _bypass_keys.
            ModuleType.__setattr__(self, "_is_dirty", True)

    def __setattr__(self, name, value):
        if name in self._bypass_keys:
            super().__setattr__(name, value)
        elif name not in self._allowed_keys:
            raise AttributeError(f"{self.__name__}.{name} does not exist")
        else:
            self._config[name] = value
            # otherwise get_hash() would keep returning the old digest
            self._mark_dirty((name,))

    def __getattr__(self, name):
        try:
            return self._config[name]
        except KeyError as e:
            # make hasattr() work properly
            raise AttributeError(f"{self.__name__}.{name} does not exist") from e

    def __delattr__(self, name):
        # must support delete because unittest.mock.patch deletes
        # then recreates things
        del self._config[name]
        self._mark_dirty((name,))

    def save_config(self) -> bytes:
        """Convert config to a pickled blob"""
        config = dict(self._config)
        for key in config.get("_save_config_ignore", ()):
            # tolerate ignore-list entries that are not config keys
            config.pop(key, None)
        return pickle.dumps(config, protocol=2)

    def codegen_config(self) -> str:
        """Convert config to Python statements that replicate current config.

        This does NOT include config settings that are at default values.
        """
        lines = []
        mod = self.__name__
        for k, v in self._config.items():
            if k in self._config.get("_save_config_ignore", ()):
                continue
            if v == self._default[k]:
                continue
            lines.append(f"{mod}.{k} = {v!r}")
        return "\n".join(lines)

    def get_hash(self) -> bytes:
        """Hashes the configs that are not compile_ignored"""
        if self._is_dirty or self._hash_digest is None:
            dict_to_hash = {
                k: v
                for k, v in self._config.items()
                if k not in self._compile_ignored_keys
            }
            string_to_hash = repr(sorted(dict_to_hash.items()))
            self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest()
            self._is_dirty = False
        return self._hash_digest

    def to_dict(self) -> Dict[str, Any]:
        warnings.warn(
            "config.to_dict() has been deprecated. It may no longer change the underlying config."
            " use config.shallow_copy_dict() or config.get_config_copy() instead",
            DeprecationWarning,
        )
        return self.shallow_copy_dict()

    def shallow_copy_dict(self) -> Dict[str, Any]:
        return {**self._config}

    def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None:
        """Restore from a prior call to save_config() or shallow_copy_dict()"""
        if not isinstance(maybe_pickled_config, dict):
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # load blobs produced by save_config() from a trusted source.
            config = pickle.loads(maybe_pickled_config)
        else:
            config = maybe_pickled_config
        self._config.update(config)
        # restored values may differ from those the cached hash was built on
        self._mark_dirty(config)

    def get_config_copy(self) -> Dict[str, Any]:
        return copy.deepcopy(self._config)

    def patch(
        self,
        arg1: Optional[Union[str, Dict[str, Any]]] = None,
        arg2: Any = None,
        **kwargs,
    ):
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2)
            @config.patch({"name1": val1, "name2": val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...

        A ``None`` ``arg2`` is treated as "not given", so the two-argument form
        cannot set a value to ``None``; use ``patch(name=None)`` or
        ``patch({"name": None})`` for that.

        Raises:
            AssertionError: on an invalid combination of arguments.
            KeyError: on entering, if a patched key is not a config entry.
        """
        changes: Dict[str, Any]
        if arg1 is not None:
            if arg2 is not None:
                assert isinstance(arg1, str)
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                assert isinstance(arg1, dict)
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior: Dict[str, Any] = {}
        config = self

        class ConfigPatch(ContextDecorator):
            def __enter__(self):
                assert not prior
                # KeyError on invalid entry; collect everything first so a
                # failure leaves `prior` empty and the patch reusable.
                saved = {key: config._config[key] for key in changes}
                prior.update(saved)
                config._config.update(changes)
                # Dirty if *any* patched key affects the hash; never clear a
                # flag that may already be pending.
                config._mark_dirty(changes)

            def __exit__(self, exc_type, exc_val, exc_tb):
                config._config.update(prior)
                config._mark_dirty(prior)
                # allow the same decorator instance to be entered again
                prior.clear()

        return ConfigPatch()

    def _make_closure_patcher(self, **changes):
        """
        A lower-overhead version of patch() for things on the critical path.

        Usage:

            # do this off the critical path
            change_fn = config._make_closure_patcher(foo=True)

            ...

            revert = change_fn()
            try:
              ...
            finally:
                revert()

        """
        config = self._config
        module = self
        # Decide once, off the critical path, whether these keys affect get_hash().
        affects_hash = any(k not in self._compile_ignored_keys for k in changes)

        def change():
            prior = {k: config[k] for k in changes}
            config.update(changes)
            if affects_hash:
                ModuleType.__setattr__(module, "_is_dirty", True)

            def revert():
                config.update(prior)
                if affects_hash:
                    ModuleType.__setattr__(module, "_is_dirty", True)

            return revert

        return change
class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`: decorating a TestCase subclass returns a subclass
    that enters the context in ``setUpClass`` and exits it in
    ``tearDownClass``, instead of wrapping the class object as a callable.

    Subclasses must implement ``__enter__`` and ``__exit__``.
    """

    def __enter__(self):
        raise NotImplementedError("NYI")

    def __exit__(self, exc_type, exc_val, exc_tb):
        raise NotImplementedError("NYI")

    def __call__(self, func):
        if isinstance(func, type) and issubclass(func, unittest.TestCase):

            class _TestCase(func):  # type: ignore[valid-type, misc]
                @classmethod
                def setUpClass(cls):
                    self.__enter__()
                    try:
                        super().setUpClass()
                    except Exception:
                        # unittest skips tearDownClass when setUpClass fails,
                        # so undo the context here.
                        self.__exit__(None, None, None)
                        raise

                @classmethod
                def tearDownClass(cls):
                    try:
                        super().tearDownClass()
                    finally:
                        self.__exit__(None, None, None)

            # Present the wrapper under the decorated class's identity.
            _TestCase.__name__ = func.__name__
            _TestCase.__qualname__ = func.__qualname__
            _TestCase.__module__ = func.__module__
            _TestCase.__doc__ = func.__doc__

            return _TestCase

        return super().__call__(func)
class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]

    Attribute get/set/delete on the proxy are forwarded to the owning config
    object's ``__getattr__``/``__setattr__``/``__delattr__`` with ``prefix``
    prepended to the attribute name.
    """

    def __init__(self, config, prefix):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name, value):
        return self._config.__setattr__(self._prefix + name, value)

    def __getattr__(self, name):
        return self._config.__getattr__(self._prefix + name)

    def __delattr__(self, name):
        return self._config.__delattr__(self._prefix + name)

    def __repr__(self):
        # _prefix is in the instance __dict__, so this never reaches __getattr__
        return f"{type(self).__name__}({self._prefix!r})"
def patch_object(obj, name, value):
    """
    Workaround `mock.patch.object` issue with ConfigModule

    For a ConfigModule, returns ``obj.patch(...)`` for ``name``; otherwise
    returns ``mock.patch.object(obj, name, value)``.  Either result can be
    used as a decorator or a context manager.
    """
    if isinstance(obj, ConfigModule):
        # Dict form: the positional ("name", value) form of patch() treats a
        # None value as "not given" and fails its argument assertions.
        return obj.patch({name: value})
    return mock.patch.object(obj, name, value)