5
from textwrap import dedent
6
from typing import Any, List, NamedTuple, Optional, Tuple
8
from torch._C import ErrorReport
9
from torch._C._jit_tree_views import SourceRangeFactory
12
def get_source_lines_and_file(
14
error_msg: Optional[str] = None,
15
) -> Tuple[List[str], int, Optional[str]]:
17
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
19
Returns: (sourcelines, file_lino, filename)
23
filename = inspect.getsourcefile(obj)
24
sourcelines, file_lineno = inspect.getsourcelines(obj)
27
f"Can't get source for {obj}. TorchScript requires source access in "
28
"order to carry out compilation, make sure original .py files are "
32
msg += "\n" + error_msg
33
raise OSError(msg) from e
35
return sourcelines, file_lineno, filename
38
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
40
This helper function accepts a list of source lines. It finds the
41
indentation level of the function definition (`def`), then it indents
42
all lines in the function body to a point at or greater than that
43
level. This allows for comments and continued string literals that
44
are at a lower indentation than the rest of the code.
46
sourcelines: function source code, separated into lines by
49
A list of source lines that have been correctly aligned
52
def remove_prefix(text, prefix):
53
return text[text.startswith(prefix) and len(prefix) :]
57
for i, l in enumerate(sourcelines):
58
if l.lstrip().startswith("def"):
69
fn_def = sourcelines[idx]
70
whitespace = fn_def.split("def")[0]
74
whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
77
whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
81
aligned_prefix.append(fn_def)
82
return aligned_prefix + aligned_suffix
87
class SourceContext(SourceRangeFactory):
93
leading_whitespace_len,
94
uses_true_division=True,
97
super().__init__(source, filename, file_lineno, leading_whitespace_len)
98
self.uses_true_division = uses_true_division
99
self.filename = filename
100
self.funcname = funcname
103
@functools.lru_cache(maxsize=None)
104
def make_source_context(*args):
105
return SourceContext(*args)
109
return SourceContext("", None, 0, 0).make_raw_range(0, 1)
112
class ParsedDef(NamedTuple):
116
filename: Optional[str]
121
sourcelines, file_lineno, filename = get_source_lines_and_file(
122
fn, ErrorReport.call_stack()
124
sourcelines = normalize_source_lines(sourcelines)
125
source = "".join(sourcelines)
126
dedent_src = dedent(source)
127
py_ast = ast.parse(dedent_src)
128
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
130
f"Expected a single top-level function: {filename}:{file_lineno}"
132
leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
133
dedent_src.split("\n", 1)[0]
135
ctx = make_source_context(
136
source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
138
return ParsedDef(py_ast, ctx, source, filename, file_lineno)