pytorch

Форк
0
/
utils.py 
519 строк · 16.1 Кб
1
from __future__ import annotations
2

3
import contextlib
4
import functools
5
import hashlib
6
import os
7
import re
8
import sys
9
import textwrap
10
from dataclasses import fields, is_dataclass
11
from enum import auto, Enum
12
from pathlib import Path
13
from typing import (
14
    Any,
15
    Callable,
16
    Generic,
17
    Iterable,
18
    Iterator,
19
    Literal,
20
    NoReturn,
21
    Sequence,
22
    TYPE_CHECKING,
23
    TypeVar,
24
)
25
from typing_extensions import Self
26

27
from torchgen.code_template import CodeTemplate
28

29

30
if TYPE_CHECKING:
31
    from argparse import Namespace
32

33

34
# Absolute path of the parent of the directory that contains this file.
# Used below as the base for repo-relative generator paths.
REPO_ROOT = Path(__file__).absolute().parent.parent
35

36

37
# Many of these functions share logic for defining both the definition
38
# and declaration (for example, the function signature is the same), so
39
# we organize them into one function that takes a Target to say which
40
# code we want.
41
#
42
# This is an OPEN enum (we may add more cases to it in the future), so be sure
43
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
44
# what targets are valid for your use.
45
class Target(Enum):
    """Kinds of output a shared codegen function can be asked to produce.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # Emitted at the top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()

    # Emitted inside TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()

    # Emitted inside namespace { ... }
    ANONYMOUS_DEFINITION = auto()

    # Emitted inside namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
56

57

58
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
59
# occurrence of a parameter in the derivative formula
60
# NOTE: this is a pattern *template*, not a finished regex: the `{}` must be
# replaced with the identifier to search for before it is compiled
# (presumably via str.format -- confirm at the call sites).
IDENT_REGEX = r"(^|\W){}($|\W)"
61

62

63
# TODO: Use a real parser here; this will get bamboozled
64
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Split a schema string into its base name and its parameter strings.

    The schema is expected to start with ``name(params)`` or
    ``name.overload(params)``; the ``.overload`` part is matched but dropped.
    The parameter text is split on ``", "``, so an empty parameter list
    yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not match the expected shape.
    """
    matched = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if matched is not None:
        base_name, param_text = matched.group(1, 3)
        return base_name, param_text.split(", ")
    raise RuntimeError(f"Unsupported function schema: {schema}")
70

71

72
# Generic element/input type shared by the helpers and containers in this file.
T = TypeVar("T")
73
# Generic result type produced by the mapping helpers in this file.
S = TypeVar("S")
74

75
# These two functions purposely return generators in analogy to map()
76
# so that you don't mix up when you need to list() them
77

78

79
# Map over function that may return None; omit Nones from output sequence
80
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is dropped; other falsy results such as ``0`` or ``""`` are
    yielded.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result
85

86

87
# Map over function that returns sequences and cat them all together
88
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for chunk in map(func, xs):
        for element in chunk:
            yield element
91

92

93
# Conveniently add error context to exceptions raised.  Lets us
94
# easily say that an error occurred while processing a specific
95
# context.
96
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Append lazily computed context to any ``Exception`` raised in the body.

    ``msg_fn`` is only called when an exception occurs. Its result is indented
    by two spaces and placed on a new line after the exception's first
    argument (or becomes the first argument if there was none). The same
    exception object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), "  ")
        if err.args:
            note = f"{err.args[0]}\n{note}"
        err.args = (note,) + err.args[1:]
        raise
107

108

109
# A little trick from https://github.com/python/mypy/issues/6366
110
# for getting mypy to do exhaustiveness checking
111
# TODO: put this somewhere else, maybe
112
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the fall-through branch of an exhaustive type/enum check so
    that a type checker can flag unhandled cases.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
114

115

116
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template file at ``template_fn``.

    Results are memoized per path string with an unbounded cache, so each
    distinct path is loaded through ``CodeTemplate.from_file`` only once.
    """
    loaded = CodeTemplate.from_file(template_fn)
    return loaded
119

120

121
# String hash that's stable across different executions, unlike builtin hash
122
def string_stable_hash(s: str) -> int:
    """Return an integer hash of ``s`` that is the same in every process.

    The string is encoded as latin1, hashed with SHA-1, and the digest is
    interpreted as a little-endian integer.
    """
    encoded = s.encode("latin1")
    digest = hashlib.sha1(encoded).digest()
    return int.from_bytes(digest, "little")
125

126

127
# A small abstraction for writing out generated files and keeping track
128
# of what files have been written (so you can write out a list of output
129
# files)
130
class FileManager:
    """Writes generated files and keeps track of which files were written.

    Output paths are formed as ``f"{install_dir}/{filename}"`` and recorded in
    ``filenames`` so that a list of outputs can be emitted afterwards with
    ``write_outputs``. When ``dry_run`` is true the names are still recorded,
    but templates are not rendered and generated files are not written.
    """

    # Directory prefix prepended to every generated file name.
    install_dir: str
    # Directory that template names are resolved against.
    template_dir: str
    # When true, record output names without rendering or writing them.
    dry_run: bool
    # Full paths of every output registered so far.
    filenames: set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        A file that cannot be opened for reading is treated as having no
        contents, so it is (re)written.
        """
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist. For a bare file
            # name os.path.dirname() returns "", and os.makedirs("") raises
            # FileNotFoundError, so only create a directory when there is one.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
    ) -> str:
        """Render the contents for one output file.

        ``env_callable`` is invoked once. If it returns a string, that string
        is returned as-is and the template is not read. If it returns a dict,
        the template ``template_dir/template_fn`` is loaded and substituted
        with it; a ``generated_comment`` entry is added to a copy of the dict
        when the caller did not provide one.
        """
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            if "generated_comment" not in env:
                generator_default = REPO_ROOT / "torchgen" / "gen.py"
                try:
                    # Prefer the script that is actually running; fall back to
                    # the default generator when __main__ has no usable file.
                    generator = Path(
                        sys.modules["__main__"].__file__ or generator_default
                    ).absolute()
                except (KeyError, AttributeError):
                    generator = generator_default.absolute()

                try:
                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
                except ValueError:
                    # The generator lives outside REPO_ROOT; use its file name.
                    generator_path = generator.name

                env = {
                    **env,  # copy the original dict instead of mutating it
                    "generated_comment": (
                        "@" + f"generated by {generator_path} from {template_fn}"
                    ),
                }
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Register ``install_dir/filename`` as an output and, unless this is a
        dry run, render ``template_fn`` and write the result if it changed.

        Raises:
            AssertionError: if the same output file was already registered.
        """
        filename = f"{self.install_dir}/{filename}"
        # The message must be an f-string so the offending path is reported.
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        """Distribute ``items`` over ``num_shards`` output files.

        Each item is assigned to a shard by the stable hash of ``key_fn(item)``
        and its ``env_callable(item)`` lists are appended to that shard's
        environment and to an extra "Everything" environment. One file per
        environment is written from the template ``filename``, with the shard
        id (``_0``, ``_1``, ... or ``Everything``) inserted before the
        extension. The "Everything" file is written but removed from
        ``filenames`` afterwards.

        Raises:
            AssertionError: if ``env_callable`` returns a key that is not in
                ``sharded_keys``, or ``base_env`` holds a non-list value for a
                sharded key.
        """
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so shards do not share (and mutate) base_env's list.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        # Split "name.ext" at the last dot; a name without a dot has no extension.
        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # The lambda is invoked inside this call (via
            # substitute_with_template), so it sees the current `shard`.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The content has the form ``set(<variable_name>`` followed by one
        quoted, sorted output path per line and a closing ``)``. It is written
        even when ``dry_run`` is true.
        """
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
293

294

295
# Helper function to generate file manager
296
def make_file_manager(
    options: Namespace, install_dir: str | None = None
) -> FileManager:
    """Build a ``FileManager`` from parsed command-line options.

    Templates are looked up in ``<options.source_path>/templates``. A falsy
    ``install_dir`` (``None`` or ``""``) falls back to ``options.install_dir``.
    """
    templates = os.path.join(options.source_path, "templates")
    if not install_dir:
        install_dir = options.install_dir
    return FileManager(
        install_dir=install_dir,
        template_dir=templates,
        dry_run=options.dry_run,
    )
304

305

306
# Helper function to create a pretty representation for dataclasses
307
def dataclass_repr(
308
    obj: Any,
309
    indent: int = 0,
310
    width: int = 80,
311
) -> str:
312
    # built-in pprint module support dataclasses from python 3.10
313
    if sys.version_info >= (3, 10):
314
        from pprint import pformat
315

316
        return pformat(obj, indent, width)
317

318
    return _pformat(obj, indent=indent, width=width)
319

320

321
def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    """Render a dataclass instance as ``ClassName(field=value, ...)``.

    Each field is placed on its own line, aligned under the first field.
    Nested dataclasses, dicts and list/set/tuple values are rendered
    recursively; every other value uses ``repr``. Fields declared with
    ``repr=False`` are skipped.

    Raises:
        AssertionError: if ``obj`` is not a dataclass.
    """
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    cls_name = obj.__class__.__name__
    # Continuation lines start just past "ClassName(".
    curr_indent += len(cls_name) + 1

    rendered = []
    for field in fields(obj):
        if not field.repr:
            continue
        value = getattr(obj, field.name)
        # Nested values start just past "name=";
        # dict, list, set and tuple also add indent as done in pprint
        nested_indent = curr_indent + len(field.name) + 1
        if is_dataclass(value):
            text = _pformat(value, indent, width, nested_indent)
        elif isinstance(value, dict):
            text = _format_dict(value, indent, width, nested_indent)
        elif isinstance(value, (list, set, tuple)):
            text = _format_list(value, indent, width, nested_indent)
        else:
            text = repr(value)
        rendered.append(f"{field.name}={text}")

    separator = ",\n" + " " * curr_indent
    return f"{cls_name}({separator.join(rendered)})"
354

355

356
def _format_dict(
    attr: dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a dict as ``{key: value, ...}``.

    Keys use ``repr``; dataclass values are rendered with ``_pformat`` and all
    other values with ``repr``. Line breaking is delegated to ``_format``.
    """
    curr_indent += indent + 3
    entries = []
    for key, value in attr.items():
        key_repr = repr(key)
        if is_dataclass(value):
            value_repr = _pformat(value, indent, width, curr_indent + len(key_repr))
        else:
            value_repr = repr(value)
        entries.append(f"{key_repr}: {value_repr}")

    return _format(entries, indent, width, curr_indent, "{", "}")
374

375

376
def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list, set or tuple.

    Dataclass elements are rendered with ``_pformat`` and all other elements
    with ``repr``. Lists are wrapped in ``[...]``; sets and tuples are both
    wrapped in ``(...)``. Line breaking is delegated to ``_format``.
    """
    curr_indent += indent + 1
    elements = []
    for element in attr:
        if is_dataclass(element):
            elements.append(_pformat(element, indent, width, curr_indent))
        else:
            elements.append(repr(element))

    if isinstance(attr, list):
        start, end = "[", "]"
    else:
        start, end = "(", ")"
    return _format(elements, indent, width, curr_indent, start, end)
389

390

391
def _format(
    fields_str: list[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join already-rendered elements between ``start`` and ``end``.

    If ``repr(fields_str)`` is at least ``width`` characters long, each
    element after the first goes on its own line, indented by ``curr_indent``
    spaces; otherwise everything stays on one line. ``indent`` spaces are
    inserted right after ``start`` in both cases.
    """
    # if it exceed the max width then we place one element per line
    if len(repr(fields_str)) >= width:
        separator = ", \n" + " " * curr_indent
    else:
        separator = ", "

    return start + " " * indent + separator.join(fields_str) + end
408

409

410
class NamespaceHelper:
    """Builds the opening and closing lines for a nested C++ namespace.

    For example, for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        # namespace_str can be a colon joined string such as torch::lazy
        parts = namespace_str.split("::")
        assert (
            len(parts) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."

        opening_lines = [f"namespace {part} {{" for part in parts]
        # Namespaces are closed innermost-first.
        closing_lines = [f"}} // namespace {part}" for part in parts[::-1]]

        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join(opening_lines)
        self.epilogue_ = "\n".join(closing_lines)
        self.namespaces_ = parts
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """
        Build a helper from a namespace-qualified class/function name,
        e.g. "torch::lazy::add". The last "::"-separated component becomes
        the entity name and the components before it form the namespace.
        """
        *namespace_parts, entity = namespaced_entity.split("::")
        return NamespaceHelper(
            namespace_str="::".join(namespace_parts),
            entity_name=entity,
            max_level=max_level,
        )

    @property
    def prologue(self) -> str:
        """The ``namespace x {`` lines, outermost first, joined by newlines."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """The ``} // namespace x`` lines, innermost first, joined by newlines."""
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        """The unqualified entity name, or "" if none was given."""
        return self.entity_name_

    # Only allow certain level of namespaces
    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the "::"-joined namespace string exactly as it was passed in
        (hence no leading "::"), or ``default`` when that string is empty.
        """
        return self.cpp_namespace_ or default
473

474

475
class OrderedSet(Generic[T]):
    """A small insertion-ordered set backed by a dict whose values are all ``None``.

    Iteration yields the items in the order they were first added.
    """

    # Items are the keys; the values are unused placeholders.
    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        self.storage = {} if iterable is None else dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage)

    def update(self, items: OrderedSet[T]) -> None:
        """Add every item of another ``OrderedSet``, keeping existing positions."""
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        """Add ``item``; an item that is already present keeps its position."""
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        """Return a new ``OrderedSet`` with its own copy of the storage."""
        clone: OrderedSet[T] = OrderedSet()
        clone.storage = dict(self.storage)
        return clone

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
        """Return a new set with the items of all arguments, in first-seen order.

        At least one argument is required.
        """
        merged = args[0].copy()
        for other in args[1:]:
            merged.update(other)
        return merged

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        # Against another OrderedSet compare the backing dicts; against
        # anything else compare as a plain set of items.
        if not isinstance(other, OrderedSet):
            return set(self.storage) == other
        return self.storage == other.storage
520

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.