pytorch
1# mypy: allow-untyped-defs
2from types import TracebackType3from typing import List, Optional4import tempfile5import traceback6import contextlib7import inspect8import os.path9
# This file contains utilities for ensuring that dynamically compile()'d
# code fragments display their source lines in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, c.f.
#   https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily, only at exception
#   time. Arguably, you *should* be willing to write out your
#   generated Python code to the file system, but in some situations
#   (esp. library code) writing to the file system would violate user
#   expectations, so we try to avoid it. In particular, we'd like to
#   keep the files around so users can open the files mentioned in the
#   trace; but if a file is never shown to anyone, we want to avoid
#   clogging up the filesystem with it.
#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement this functionality: instead of using eval/exec
#   directly, always write the Python source to a file on the filesystem
#   and compile that.
#
# - You have control over a context where the compiled code will get
#   executed, so that we can interpose while the stack is unwinding
#   (otherwise, we have no way to interpose on the exception printing
#   process).
#
# There are two things you have to do to make use of the utilities here:
#
# - When you compile your source code, use the filename "<string>" (the
#   default for exec()/eval() on a string), and save its string source in
#   its globals under the magic name "__compile_source__".
#
# - Before running the compiled code, enter the
#   report_compile_source_on_error() context manager.
46
def _write_compile_source(source):
    """Write ``source`` to a new temporary ``.py`` file and return its path.

    The file is deliberately not deleted, so the user can open it from the
    printed traceback.
    """
    # Python source files are UTF-8 by default (PEP 3120). Pin the encoding so
    # non-ASCII source is neither rejected nor mis-encoded under a non-UTF-8
    # locale encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", delete=False, suffix=".py"
    ) as f:
        f.write(source)
    return f.name


def _make_fake_tb(tb, filename):
    """Return a traceback entry mimicking ``tb`` whose code object points at ``filename``.

    The returned entry has ``tb_next=None``; the caller relinks it.
    """
    # What black magic are we doing here? Intuitively, what we would like to
    # do is overwrite the co_filename on any frames that were generated from
    # exec/eval so that they point to a temporary file that has the actual
    # line information, so Python's default error printer can print useful
    # line information on it.
    #
    # Writing out the temporary file is easy. But overwriting co_filename is
    # not! You can't modify the code object associated with a frame. You can,
    # however, reconstruct a traceback with entirely new frames from scratch,
    # so that's what we do. But there's another problem, which is how to make
    # the frame?
    #
    # The black magic is we make a frankenstein frame and code object which
    # resembles the original frame/code enough so that it will print properly
    # under traceback and the default error printer, but IT IS NOT THE
    # ORIGINAL FRAME (you couldn't, e.g., execute its code with different
    # variables and expect it to work.)
    frame = tb.tb_frame
    # Python doesn't let you construct FrameType directly, so make one by
    # evaluating code that returns its own frame.
    code = compile("__inspect_currentframe()", filename, "eval")
    code = code.replace(co_name=frame.f_code.co_name)
    # co_linetable exists from Python 3.10 (PEP 626); the column positions
    # traceback prints come from co_positions(), available on 3.11+.
    if hasattr(frame.f_code, "co_linetable"):
        # We can't copy ALL of the metadata over, because you can cause
        # Python to segfault this way. What exactly do we need? We need
        # enough information for traceback to be able to print the exception
        # correctly. Code reading Lib/traceback.py reveals that traceback
        # calls code.co_positions() in order to get the augmented line/col
        # numbers. Objects/codeobject.c, specifically _PyCode_InitAddressRange,
        # reveals that this iterator is initialized from co_linetable and
        # co_firstlineno. So copy these we must!
        code = code.replace(  # type: ignore[call-arg]
            co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
            co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
        )
    # Only evaluates the constant expression compiled above; no external input.
    fake_frame = eval(
        code,
        frame.f_globals,
        {**frame.f_locals, "__inspect_currentframe": inspect.currentframe},
    )
    return TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno)


@contextlib.contextmanager
def report_compile_source_on_error():
    """Make exec/eval'd code show its source lines in tracebacks.

    Any ``Exception`` escaping the ``with`` body is intercepted. Each
    traceback entry whose code has ``co_filename == "<string>"`` and whose
    globals define ``__compile_source__`` is replaced by a synthetic entry
    whose code object points at a temporary ``.py`` file containing that
    source. The same exception object is then re-raised with the rewritten
    traceback.

    Temporary files are not deleted, so the user can open them. Each distinct
    ``__compile_source__`` string is written at most once per intercepted
    exception. If a temporary file cannot be written (``OSError``), the
    affected entries are left unchanged rather than masking the original
    exception.
    """
    try:
        yield
    except Exception as exc:
        # Temp file path per distinct __compile_source__ string; "" marks a
        # source we failed to write out.
        source_paths = {}
        stack = []

        # Walk the traceback, looking for frames that have source attached.
        tb = exc.__traceback__
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                if source not in source_paths:
                    try:
                        source_paths[source] = _write_compile_source(source)
                    except OSError:
                        # Don't replace the user's exception with an error
                        # from this debugging aid; keep the original frame.
                        source_paths[source] = ""
                path = source_paths[source]
                stack.append(_make_fake_tb(tb, path) if path else tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list. Note that this relinks the ORIGINAL
        # traceback objects of real frames in place (not only the fake ones),
        # so anything still holding the original head sees the rewritten chain.
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    The longest common path prefix of ``fn`` and ``base`` is removed, together
    with the separator(s) that follow it. ``base`` defaults to the directory
    two levels above this file. If the two paths cannot be compared (e.g. one
    is absolute and the other relative, or they are on different Windows
    drives), ``fn`` is returned unchanged.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    # Drop the common prefix and the separator(s) after it. Blindly skipping
    # exactly one extra character is wrong when the prefix already ends in a
    # separator (e.g. the filesystem root "/") or is empty, which used to eat
    # the first real character of the remaining path.
    separators = os.sep + (os.altsep or "")
    return fn[len(prefix):].lstrip(separators)
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary as a compact, single-line string.

    The result is ``"<path>:<lineno> in <name>"``, where the path is shortened
    with shorten_filename (relative to ``base``) instead of printed in full.
    When ``line`` is true, the frame's source text followed by ``" # "`` is
    prepended.
    """
    location = f"{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
    if not line:
        return location
    return f"{frame.line} # {location}"
def format_traceback_short(tb):
    """Format a TracebackType compactly, reporting only its innermost (last) frame."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
class CapturedTraceback:
    """A cheaply captured traceback whose symbolization is deferred.

    Instances are normally created with :meth:`extract` and turned into text
    with :meth:`format` (one traceback) or :meth:`format_all` (a batch).
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        # tb: opaque handle from torch._C._profiler.gather_traceback, or None
        #     after cleanup() / unpickling.
        # skip: number of leading entries of the symbolized traceback to drop
        #       when summarizing (see _extract_symbolized_tb).
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Release the captured traceback; later summaries will be empty."""
        self.tb = None

    def summary(self):
        """Symbolize this traceback and return it as a traceback.StackSummary.

        Returns an empty StackSummary if the traceback has been released.
        """
        import torch._C._profiler

        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip,
        )

    def __getstate__(self):
        # The native traceback handle is not picklable, so it is dropped; an
        # unpickled instance behaves like one that was cleaned up.
        # (None, slot_state) is pickle's state format for __slots__ classes.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution. It returns a CapturedTraceback object that must be
        formatted specially, with :meth:`format` or :meth:`format_all`.

        By default, this only reports Python backtraces (like extract_stack). You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting. ``skip`` elides that many additional frames; it is not yet
        supported together with script/cpp.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames. If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1,
        )

    def format(self):
        """
        Format this traceback into a list of strings equivalent to the output
        of traceback.format_list. Note that if you pass it a CapturedTraceback
        with C++ traces, it is better not to use this method and instead use
        the batch formatting API format_all, to amortize the cost of
        symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format. Returns a list of list of strings.

        All live tracebacks are symbolized with a single symbolize_tracebacks
        call; released tracebacks (tb is None) format to an empty list.
        """
        import torch._C._profiler

        # Released tracebacks format to [] directly; the rest are queued for
        # one batched symbolization call below.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch-symbolized result; calling tbs[i].summary() here
            # would symbolize every traceback a second time, one at a time.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
246
def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (an entry of symbolize_tracebacks' result)
    into a traceback.StackSummary.

    The first ``skip`` entries of ``tb`` are dropped and the remaining entries
    are emitted in reverse order. Each entry is read as a mapping with
    ``"filename"``, ``"line"`` (used as the line number) and ``"name"`` keys.
    """
    kept = tb[skip:]
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept)
    )
    return summary