pytorch

Форк
0
/
_sources.py 
138 строк · 4.3 Кб
1
# mypy: allow-untyped-defs
2
import ast
3
import functools
4
import inspect
5
from textwrap import dedent
6
from typing import Any, List, NamedTuple, Optional, Tuple
7

8
from torch._C import ErrorReport
9
from torch._C._jit_tree_views import SourceRangeFactory
10

11

12
def get_source_lines_and_file(
13
    obj: Any,
14
    error_msg: Optional[str] = None,
15
) -> Tuple[List[str], int, Optional[str]]:
16
    """
17
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.
18

19
    Returns: (sourcelines, file_lino, filename)
20
    """
21
    filename = None  # in case getsourcefile throws
22
    try:
23
        filename = inspect.getsourcefile(obj)
24
        sourcelines, file_lineno = inspect.getsourcelines(obj)
25
    except OSError as e:
26
        msg = (
27
            f"Can't get source for {obj}. TorchScript requires source access in "
28
            "order to carry out compilation, make sure original .py files are "
29
            "available."
30
        )
31
        if error_msg:
32
            msg += "\n" + error_msg
33
        raise OSError(msg) from e
34

35
    return sourcelines, file_lineno, filename
36

37

38
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
39
    """
40
    This helper function accepts a list of source lines. It finds the
41
    indentation level of the function definition (`def`), then it indents
42
    all lines in the function body to a point at or greater than that
43
    level. This allows for comments and continued string literals that
44
    are at a lower indentation than the rest of the code.
45
    Args:
46
        sourcelines: function source code, separated into lines by
47
                        the '\n' character
48
    Returns:
49
        A list of source lines that have been correctly aligned
50
    """
51

52
    def remove_prefix(text, prefix):
53
        return text[text.startswith(prefix) and len(prefix) :]
54

55
    # Find the line and line number containing the function definition
56
    idx = None
57
    for i, l in enumerate(sourcelines):
58
        if l.lstrip().startswith("def"):
59
            idx = i
60
            break
61

62
    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
63
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
64
    # `parse_def()`, but we might want to handle this case in the future.
65
    if idx is None:
66
        return sourcelines
67

68
    # Get a string representing the amount of leading whitespace
69
    fn_def = sourcelines[idx]
70
    whitespace = fn_def.split("def")[0]
71

72
    # Add this leading whitespace to all lines before and after the `def`
73
    aligned_prefix = [
74
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
75
    ]
76
    aligned_suffix = [
77
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
78
    ]
79

80
    # Put it together again
81
    aligned_prefix.append(fn_def)
82
    return aligned_prefix + aligned_suffix
83

84

85
# Thin wrapper around SourceRangeFactory to store extra metadata
86
# about the function-to-be-compiled.
87
class SourceContext(SourceRangeFactory):
88
    def __init__(
89
        self,
90
        source,
91
        filename,
92
        file_lineno,
93
        leading_whitespace_len,
94
        uses_true_division=True,
95
        funcname=None,
96
    ):
97
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
98
        self.uses_true_division = uses_true_division
99
        self.filename = filename
100
        self.funcname = funcname
101

102

103
@functools.lru_cache(maxsize=None)
104
def make_source_context(*args):
105
    return SourceContext(*args)
106

107

108
def fake_range():
109
    return SourceContext("", None, 0, 0).make_raw_range(0, 1)
110

111

112
class ParsedDef(NamedTuple):
113
    ast: ast.Module
114
    ctx: SourceContext
115
    source: str
116
    filename: Optional[str]
117
    file_lineno: int
118

119

120
def parse_def(fn):
121
    sourcelines, file_lineno, filename = get_source_lines_and_file(
122
        fn, ErrorReport.call_stack()
123
    )
124
    sourcelines = normalize_source_lines(sourcelines)
125
    source = "".join(sourcelines)
126
    dedent_src = dedent(source)
127
    py_ast = ast.parse(dedent_src)
128
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
129
        raise RuntimeError(
130
            f"Expected a single top-level function: {filename}:{file_lineno}"
131
        )
132
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
133
        dedent_src.split("\n", 1)[0]
134
    )
135
    ctx = make_source_context(
136
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
137
    )
138
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.