pytorch
1from types import TracebackType2from typing import List, Optional3import tempfile4import traceback5import contextlib6import inspect7import os.path8
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, c.f.
#   https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily only at exception
#   time.  Arguably, you *should* be willing to write out your
#   generated Python code to file system, but in some situations
#   (esp. library code) it would violate user expectation to write
#   to the file system, so we try to avoid it.  In particular, we'd
#   like to keep the files around, so users can open up the files
#   mentioned in the trace; if the file is invisible, we want to
#   avoid clogging up the filesystem.
#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement the functionality in this file: instead of using
#   eval/exec directly, just always write a Python file to filesystem
#   and compile that.
#
# - You have control over a context where the compiled code will get
#   executed, so that we can interpose while the stack is unwinding
#   (otherwise, we have no way to interpose on the exception printing
#   process.)
#
# There are two things you have to do to make use of the utilities here:
#
# - When you compile your source code, you must save its string source
#   in its f_globals under the magic name "__compile_source__"
#
# - Before running the compiled code, enter the
#   report_compile_source_on_error() context manager.
45
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Make tracebacks through exec/eval'd code point at a real source file.

    If an ``Exception`` propagates out of the managed block, every traceback
    entry whose code object has the filename ``"<string>"`` and whose frame
    globals contain ``"__compile_source__"`` is replaced by a synthetic entry
    whose code object names a temporary ``.py`` file holding that source.
    The exception is then re-raised with the rebuilt traceback.

    Nothing is written to the file system unless an exception is raised.
    The temporary files are deliberately not deleted so they can be opened
    afterwards; one file is written per distinct source string seen in the
    traceback, not one per frame.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Maps each distinct __compile_source__ string to the temporary file
        # it has been written to, so that many frames sharing one source
        # (e.g. recursion inside generated code) share a single file.
        source_files = {}

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it
                source_path = source_files.get(source)
                if source_path is None:
                    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                        f.write(source)
                    source_path = source_files[source] = f.name
                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                code = compile('__inspect_currentframe()', source_path, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Evaluating the stub returns its own (new) frame object,
                # whose f_code carries the temporary file name.
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: TRY200
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    Args:
        fn: The path to shorten.
        base: Directory whose shared prefix with ``fn`` is stripped.  Defaults
            to the parent of the directory containing this file.

    Returns:
        ``fn`` with the longest path prefix it shares with ``base`` (and the
        separator that follows it) removed.  ``fn`` is returned unchanged when
        the two paths cannot be compared (``os.path.commonpath`` raises
        ``ValueError``, e.g. mixing absolute and relative paths) or when they
        share no prefix at all.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    if not prefix:
        # Two relative paths with nothing in common: nothing to strip.
        return fn
    # The common prefix may itself end with a separator (a bare root such as
    # "/" or "C:\\"), so strip the separator explicitly instead of always
    # skipping one extra character, which would eat the first letter of the
    # remaining path.
    separators = os.sep + (os.altsep or "")
    return fn[len(prefix):].lstrip(separators)
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly, omitting the full absolute path and code.

    The result is meant to fit on a single line.  When ``line`` is true the
    frame's source line is placed in front, followed by `` # ``.  ``base`` is
    forwarded to shorten_filename().
    """
    prefix = f"{frame.line} # " if line else ""
    short_name = shorten_filename(frame.filename, base=base)
    return f"{prefix}{short_name}:{frame.lineno} in {frame.name}"
def format_traceback_short(tb):
    """Render a TracebackType compactly, showing only its inner-most frame."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
class CapturedTraceback:
    """
    A traceback captured via torch._C._profiler.gather_traceback.

    The captured object is kept as-is in ``tb``; it is only symbolized (via
    torch._C._profiler.symbolize_tracebacks) when summary(), format() or
    format_all() is called.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        # tb: the captured traceback object, or None once cleanup() has run
        # skip: number of leading symbolized entries dropped when summarizing
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the captured traceback; later summaries will be empty."""
        self.tb = None

    def summary(self):
        """Symbolize the captured traceback and return it as a traceback.StackSummary."""
        import torch._C._profiler

        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # Returned in the (dict state, slots state) form; the class has no
        # __dict__, hence the leading None.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the object
        produced by torch._C._profiler.gather_traceback; format it with
        CapturedTraceback.format (or CapturedTraceback.format_all for many).

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.  ``skip`` must be 0 when script or cpp is set.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to
        the output of traceback.format_list.  Note that if the
        CapturedTraceback has C++ traces, it is better not to use this
        function and use the batch formatting API CapturedTraceback.format_all
        to amortize the cost of symbolization
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        All tracebacks that still hold a captured object are symbolized in a
        single symbolize_tracebacks call; entries whose ``tb`` is None yield
        an empty list.
        """
        import torch._C._profiler

        # Tracebacks that were cleaned up need no symbolization; fill them in
        # directly and remember the positions of the ones that do.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch result instead of calling summary(), which would
            # symbolize every traceback a second time, one by one.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
245
def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (as returned by symbolize_tracebacks)
    into a traceback.StackSummary.

    The first ``skip`` entries are dropped and the remaining ones are added in
    reverse order; each entry supplies 'filename', 'line' and 'name'.
    """
    kept = tb[skip:]
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept)
    )
    return summary