1
from __future__ import annotations
10
from dataclasses import fields, is_dataclass
11
from enum import auto, Enum
12
from pathlib import Path
25
from typing_extensions import Self
27
from torchgen.code_template import CodeTemplate
31
from argparse import Namespace
34
# Parent of the directory that contains this file (presumably the repository
# root, going by the name -- confirm against the checkout layout).
REPO_ROOT = Path(__file__).absolute().parent.parent
52
ANONYMOUS_DEFINITION = auto()
54
NAMESPACED_DEFINITION = auto()
55
NAMESPACED_DECLARATION = auto()
60
# Regex template with one `{}` placeholder. Once the placeholder is filled in,
# the pattern matches that text only where it is delimited on each side by the
# start/end of the string or by a non-word character.
IDENT_REGEX = r"(^|\W){}($|\W)"
64
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Split a function schema string into its base name and parameter list.

    The schema must start with ``name(params)`` or ``name.overload(params)``.
    The optional ``.overload`` suffix is matched but not returned.

    Args:
        schema: Schema text, e.g. ``"add.Tensor(Tensor self, Tensor other)"``.

    Returns:
        ``(name, params)`` where ``params`` is the text between the
        parentheses split on ``", "``. An empty parameter list yields
        ``[""]``, because that is what ``"".split(", ")`` returns.

    Raises:
        RuntimeError: If ``schema`` does not match the expected pattern.
    """
    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    # Reject only schemas that failed to match; an unconditional raise here
    # would make the unpacking and return below unreachable.
    if m is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    name, _, params = m.groups()
    return name, params.split(", ")
80
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
88
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
96
@contextlib.contextmanager
97
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
100
except Exception as e:
103
msg = textwrap.indent(msg, " ")
104
msg = f"{e.args[0]}\n{msg}" if e.args else msg
105
e.args = (msg,) + e.args[1:]
112
def assert_never(x: NoReturn) -> NoReturn:
    """Unconditionally fail, reporting the runtime type of ``x``.

    Args:
        x: The value that reached a branch the caller considered impossible.

    Raises:
        AssertionError: Always; the message names ``type(x).__name__``.
    """
    unhandled = type(x).__name__
    raise AssertionError(f"Unhandled type: {unhandled}")
116
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the code template stored at ``template_fn``.

    The function is wrapped in an unbounded ``functools.lru_cache``, so a
    repeated call with the same path reuses the previously returned template
    instead of calling ``CodeTemplate.from_file`` again.

    Args:
        template_fn: Path of the template file, passed through unchanged.

    Returns:
        The template produced by ``CodeTemplate.from_file``.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
122
def string_stable_hash(s: str) -> int:
    """Return a hash of ``s`` that does not vary between interpreter runs.

    The result is the SHA-1 digest of the encoded string, read as a
    little-endian integer.

    Args:
        s: Text to hash.

    Returns:
        A non-negative integer derived from the 20-byte SHA-1 digest.
    """
    try:
        data = s.encode("latin1")
    except UnicodeEncodeError:
        # Latin-1 cannot represent characters above U+00FF and used to raise
        # here. Fall back to UTF-8 so every str is hashable; any string that
        # Latin-1 can encode keeps exactly the hash it had before.
        data = s.encode("utf-8")
    return int.from_bytes(hashlib.sha1(data).digest(), byteorder="little")
136
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
137
self.install_dir = install_dir
138
self.template_dir = template_dir
139
self.filenames = set()
140
self.dry_run = dry_run
142
def _write_if_changed(self, filename: str, contents: str) -> None:
143
old_contents: str | None
145
with open(filename) as f:
146
old_contents = f.read()
149
if contents != old_contents:
151
os.makedirs(os.path.dirname(filename), exist_ok=True)
152
with open(filename, "w") as f:
156
def substitute_with_template(
157
self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
159
template_path = os.path.join(self.template_dir, template_fn)
161
if isinstance(env, dict):
162
if "generated_comment" not in env:
163
generator_default = REPO_ROOT / "torchgen" / "gen.py"
166
sys.modules["__main__"].__file__ or generator_default
168
except (KeyError, AttributeError):
169
generator = generator_default.absolute()
172
generator_path = generator.relative_to(REPO_ROOT).as_posix()
174
generator_path = generator.name
178
"generated_comment": (
179
"@" + f"generated by {generator_path} from {template_fn}"
182
template = _read_template(template_path)
183
return template.substitute(env)
184
elif isinstance(env, str):
189
def write_with_template(
193
env_callable: Callable[[], str | dict[str, Any]],
195
# Outputs are tracked by their path under the install directory.
filename = f"{self.install_dir}/{filename}"
# The message has to be an f-string: as a plain literal it reported the
# placeholder text "{filename}" instead of the path that was written twice.
assert filename not in self.filenames, f"duplicate file write {filename}"
self.filenames.add(filename)
199
substitute_out = self.substitute_with_template(
200
template_fn=template_fn,
201
env_callable=env_callable,
203
self._write_if_changed(filename=filename, contents=substitute_out)
208
env_callable: Callable[[], str | dict[str, Any]],
210
self.write_with_template(filename, filename, env_callable)
217
key_fn: Callable[[T], str],
218
env_callable: Callable[[T], dict[str, list[str]]],
220
base_env: dict[str, Any] | None = None,
221
sharded_keys: set[str],
223
everything: dict[str, Any] = {"shard_id": "Everything"}
224
shards: list[dict[str, Any]] = [
225
{"shard_id": f"_{i}"} for i in range(num_shards)
227
all_shards = [everything] + shards
229
if base_env is not None:
230
for shard in all_shards:
231
shard.update(base_env)
233
for key in sharded_keys:
234
for shard in all_shards:
238
), "sharded keys in base_env must be a list"
239
shard[key] = shard[key].copy()
243
def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
244
for k, v in from_.items():
245
assert k in sharded_keys, f"undeclared sharded key {k}"
254
sid = string_stable_hash(key) % num_shards
255
env = env_callable(item)
257
merge_env(shards[sid], env)
258
merge_env(everything, env)
260
dot_pos = filename.rfind(".")
262
dot_pos = len(filename)
263
base_filename = filename[:dot_pos]
264
extension = filename[dot_pos:]
266
for shard in all_shards:
267
shard_id = shard["shard_id"]
268
self.write_with_template(
269
f"{base_filename}{shard_id}{extension}", filename, lambda: shard
273
self.filenames.discard(
274
f"{self.install_dir}/{base_filename}Everything{extension}"
277
def write_outputs(self, variable_name: str, filename: str) -> None:
278
"""Write a file containing the list of all outputs which are
279
generated by this script."""
280
content = "set({}\n {})".format(
282
"\n ".join('"' + name + '"' for name in sorted(self.filenames)),
284
self._write_if_changed(filename, content)
286
def template_dir_for_comments(self) -> str:
288
This needs to be deterministic. The template dir is an absolute path
289
that varies across builds. So, just use the path relative to this file,
290
which will point to the codegen source but will be stable.
292
return os.path.relpath(self.template_dir, os.path.dirname(__file__))
296
def make_file_manager(
297
options: Namespace, install_dir: str | None = None
299
template_dir = os.path.join(options.source_path, "templates")
300
install_dir = install_dir if install_dir else options.install_dir
302
install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
313
if sys.version_info >= (3, 10):
314
from pprint import pformat
316
return pformat(obj, indent, width)
318
return _pformat(obj, indent=indent, width=width)
325
curr_indent: int = 0,
327
assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
329
class_name = obj.__class__.__name__
331
curr_indent += len(class_name) + 1
333
fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
336
for name, attr in fields_list:
339
_curr_indent = curr_indent + len(name) + 1
340
if is_dataclass(attr):
341
str_repr = _pformat(attr, indent, width, _curr_indent)
342
elif isinstance(attr, dict):
343
str_repr = _format_dict(attr, indent, width, _curr_indent)
344
elif isinstance(attr, (list, set, tuple)):
345
str_repr = _format_list(attr, indent, width, _curr_indent)
347
str_repr = repr(attr)
349
fields_str.append(f"{name}={str_repr}")
351
indent_str = curr_indent * " "
352
body = f",\n{indent_str}".join(fields_str)
353
return f"{class_name}({body})"
357
attr: dict[Any, Any],
362
curr_indent += indent + 3
364
for k, v in attr.items():
367
_pformat(v, indent, width, curr_indent + len(k_repr))
371
dict_repr.append(f"{k_repr}: {v_str}")
373
return _format(dict_repr, indent, width, curr_indent, "{", "}")
377
attr: list[Any] | set[Any] | tuple[Any, ...],
382
curr_indent += indent + 1
384
_pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
387
start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
388
return _format(list_repr, indent, width, curr_indent, start, end)
392
fields_str: list[str],
399
delimiter, curr_indent_str = "", ""
401
if len(repr(fields_str)) >= width:
403
curr_indent_str = " " * curr_indent
405
indent_str = " " * indent
406
body = f", {delimiter}{curr_indent_str}".join(fields_str)
407
return f"{start}{indent_str}{body}{end}"
410
class NamespaceHelper:
411
"""A helper for constructing the namespace open and close strings for a nested set of namespaces.
413
e.g. for namespace_str torch::lazy,
425
self, namespace_str: str, entity_name: str = "", max_level: int = 2
428
cpp_namespaces = namespace_str.split("::")
430
len(cpp_namespaces) <= max_level
431
), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
432
self.cpp_namespace_ = namespace_str
433
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
434
self.epilogue_ = "\n".join(
435
[f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
437
self.namespaces_ = cpp_namespaces
438
self.entity_name_ = entity_name
441
def from_namespaced_entity(
442
namespaced_entity: str, max_level: int = 2
443
) -> NamespaceHelper:
445
Generate helper from nested namespaces as well as class/function name. E.g.: "torch::lazy::add"
447
names = namespaced_entity.split("::")
448
entity_name = names[-1]
449
namespace_str = "::".join(names[:-1])
450
return NamespaceHelper(
451
namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
455
def prologue(self) -> str:
456
return self.prologue_
459
def epilogue(self) -> str:
460
return self.epilogue_
463
def entity_name(self) -> str:
464
return self.entity_name_
467
def get_cpp_namespace(self, default: str = "") -> str:
469
Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
470
Return default if namespace string is empty.
472
return self.cpp_namespace_ if self.cpp_namespace_ else default
475
class OrderedSet(Generic[T]):
476
storage: dict[T, Literal[None]]
478
def __init__(self, iterable: Iterable[T] | None = None) -> None:
482
self.storage = dict.fromkeys(iterable)
484
def __contains__(self, item: T) -> bool:
485
return item in self.storage
487
def __iter__(self) -> Iterator[T]:
488
return iter(self.storage.keys())
490
def update(self, items: OrderedSet[T]) -> None:
491
self.storage.update(items.storage)
493
def add(self, item: T) -> None:
494
self.storage[item] = None
496
def copy(self) -> OrderedSet[T]:
497
ret: OrderedSet[T] = OrderedSet()
498
ret.storage = self.storage.copy()
502
def union(*args: OrderedSet[T]) -> OrderedSet[T]:
508
def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
509
return OrderedSet.union(self, other)
511
def __ior__(self, other: OrderedSet[T]) -> Self:
515
def __eq__(self, other: object) -> bool:
516
if isinstance(other, OrderedSet):
517
return self.storage == other.storage
519
return set(self.storage.keys()) == other