pytorch

Форк
0
/
_traceback.py 
255 строк · 10.1 Кб
1
# mypy: allow-untyped-defs
2
from types import TracebackType
3
from typing import List, Optional
4
import tempfile
5
import traceback
6
import contextlib
7
import inspect
8
import os.path
9

10
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, cf.
#   https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily, only at exception
#   time.  Arguably, you *should* be willing to write out your
#   generated Python code to the file system, but in some situations
#   (esp. library code) it would violate user expectations to write
#   to the file system, so we try to avoid it.  When we do write a file,
#   we'd like to keep it around, so users can open up the files
#   mentioned in the trace; but if nobody ever sees a file, we want to
#   avoid clogging up the file system with it.
#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement the functionality in this file: instead of using
#   eval/exec directly, just always write a Python file to the file system
#   and compile that.
#
# - You have control over a context where the compiled code will get
#   executed, so that we can interpose while the stack is unwinding
#   (otherwise, we have no way to interpose on the exception printing
#   process).
#
# There are two things you have to do to make use of the utilities here:
#
# - When you compile your source code, compile it with the filename
#   "<string>" (the default for exec()/eval() of a string), and save its
#   string source in the globals the code runs with (its f_globals) under
#   the magic name "__compile_source__".
#
# - Before running the compiled code, enter the
#   report_compile_source_on_error() context manager.
46

47
@contextlib.contextmanager
def report_compile_source_on_error():
    """Re-raise exceptions so that compile()'d frames point at real source files.

    Wrap code that runs dynamically compiled source in this context manager.
    If an ``Exception`` escapes the ``with`` body, its traceback is walked, and
    every frame whose code was compiled with the filename ``"<string>"`` and
    whose globals contain ``__compile_source__`` is replaced by a synthetic
    frame attributed to a temporary ``.py`` file holding that source.  The same
    exception object is then re-raised with the rewritten traceback.

    The temporary files are intentionally not deleted so users can open them.
    Only one file is written per distinct ``__compile_source__`` string per
    exception.  If a file cannot be written, the affected frames keep their
    original ``"<string>"`` entries rather than letting the write error
    replace the user's exception.
    """
    try:
        yield
    except Exception as exc:
        # Distinct source text -> temp file path (None if writing it failed),
        # so frames generated from the same source share a single file.
        source_files = {}

        # Walk the traceback, looking for frames that have source attached.
        stack = []
        tb = exc.__traceback__
        while tb is not None:
            frame = tb.tb_frame
            source = frame.f_globals.get("__compile_source__")
            filename = None
            if frame.f_code.co_filename == "<string>" and source is not None:
                if source not in source_files:
                    source_files[source] = _write_source_tempfile(source)
                filename = source_files[source]
            if filename is None:
                stack.append(tb)
            else:
                stack.append(_make_fake_traceback(tb, filename))
            tb = tb.tb_next

        # Reconstruct the linked list.  Note this rewrites tb_next on the
        # original traceback objects in place rather than copying them.
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904


def _write_source_tempfile(source):
    """Write ``source`` to a new, persistent ``.py`` temp file and return its path.

    The file is not deleted, so it outlives the traceback that refers to it.
    It is written as UTF-8 because ``linecache`` (used by the traceback
    printer) reads ``.py`` files via ``tokenize.open``, which assumes UTF-8
    when there is no coding cookie.  Returns ``None`` if the file cannot be
    created or written, so the caller can fall back to the original frame.
    """
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", delete=False, suffix=".py"
        ) as f:
            f.write(source)
    except (OSError, UnicodeError):
        return None
    return f.name


def _make_fake_traceback(tb, filename):
    """Build a stand-in for traceback entry ``tb`` whose code claims ``filename``.

    The returned ``TracebackType`` has ``tb_next=None`` (the caller relinks the
    chain) and reuses ``tb``'s ``tb_lasti`` and ``tb_lineno``.
    """
    # What black magic are we doing here?  Intuitively, what we would like to
    # do is overwrite the co_filename on any frames that were generated from
    # exec/eval so that they point to a temporary file that has the actual
    # line information, so Python's default error printer can print useful
    # line information on it.
    #
    # Writing out the temporary file is easy.  But overwriting co_filename is
    # not!  You can't modify the code object associated with a frame.  You
    # can, however, reconstruct a traceback with entirely new frames from
    # scratch, so that's what we do.  But there's another problem, which is
    # how to make the frame?
    #
    # The black magic is we make a frankenstein frame and code object which
    # resembles the original frame/code enough so that it will print properly
    # under traceback and the default error printer, but IT IS NOT THE
    # ORIGINAL FRAME (you couldn't, e.g., execute its code with different
    # variables and expect it to work.)
    frame = tb.tb_frame
    # Python doesn't let you construct FrameType directly, so compile a tiny
    # expression under the temp file's name that returns its own frame.
    code = compile('__inspect_currentframe()', filename, 'eval')
    code = code.replace(co_name=frame.f_code.co_name)
    # Relevant for Python 3.11+ (see below).
    if hasattr(frame.f_code, 'co_linetable'):
        # We can't copy ALL of the metadata over, because you can cause
        # Python to segfault this way.  What exactly do we need?  We need
        # enough information for traceback to be able to print the exception
        # correctly.  Code reading Lib/traceback.py reveals that traceback
        # calls code.co_positions() in order to get the augmented line/col
        # numbers.  Objects/codeobject.c, specifically
        # _PyCode_InitAddressRange, reveals that this iterator is initialized
        # from co_linetable and co_firstlineno.  So copy these we must!
        code = code.replace(  # type: ignore[call-arg]
            co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
            co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
        )
    # Evaluate with the original frame's globals and a copy of its locals so
    # the fake frame exposes the same variables.
    fake_frame = eval(
        code,
        frame.f_globals,
        {
            **frame.f_locals,
            '__inspect_currentframe': inspect.currentframe,
        },
    )
    return TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno)
133

134
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    ``fn`` is returned relative to the longest common path it shares with
    ``base``.  By default, ``base`` is the directory two levels above this
    module's file.  ``fn`` is returned unchanged when there is no meaningful
    shared prefix.  This happens when the paths cannot be compared (e.g. one
    is absolute and the other relative), or when they share only a filesystem
    root or nothing at all.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        # e.g. mixed absolute/relative paths, or paths on different drives.
        return fn
    if os.path.dirname(prefix) == prefix:
        # The only shared prefix is a root ("/", "C:\\") or empty ("").
        # Slicing off len(prefix) + 1 characters here would cut into the
        # first path component (e.g. "/home/u/x.py" -> "ome/u/x.py").
        return fn
    return fn[len(prefix) + 1:]
145

146
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly, without the absolute path or code block.

    The output has the shape ``<path>:<lineno> in <name>``, where ``<path>``
    is ``frame.filename`` passed through ``shorten_filename`` with ``base``.
    When ``line`` is true, the frame's source text and ``"  # "`` are placed
    in front.  The intent is for the whole result to fit on one line.
    """
    pieces = []
    if line:
        pieces.append(f"{frame.line}  # ")
    pieces.append(f"{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}")
    return "".join(pieces)
156

157
def format_traceback_short(tb):
    """Describe traceback ``tb`` in one short line, using only its innermost frame.

    The innermost entry of ``traceback.extract_tb(tb)`` is rendered with
    ``format_frame``.
    """
    summaries = traceback.extract_tb(tb)
    innermost = summaries[-1]
    return format_frame(innermost)
160

161
class CapturedTraceback:
    """A cheaply captured stack trace that is symbolized only on demand.

    Wraps a raw traceback handle from ``torch._C._profiler.gather_traceback``
    (see ``extract``).  The handle is turned into ``traceback``-module objects
    only by ``summary``, ``format`` or ``format_all``.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        # tb: raw traceback handle, or None after cleanup()/unpickling.
        # skip: number of leading (innermost) symbolized frames to drop.
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the raw traceback; later summaries/formats will be empty."""
        self.tb = None

    def summary(self):
        """Symbolize this traceback and return it as a ``traceback.StackSummary``.

        Frames come out outermost first (like ``traceback.extract_stack``),
        minus the ``skip`` innermost ones.  An empty ``StackSummary`` is
        returned if the raw traceback has been dropped.
        """
        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        import torch._C._profiler

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # Pickled as (dict_state, slot_state); default unpickling assigns the
        # slots from the second element.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the raw handle
        from torch._C._profiler.gather_traceback; format it with format(), or with
        CapturedTraceback.format_all() for many tracebacks at once.

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.

        ``skip`` drops that many additional innermost frames; it is only
        supported when neither ``script`` nor ``cpp`` is set.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to the
        output of traceback.format_list.  Note that if this traceback has C++
        frames, it is better to format many tracebacks together with the batch
        API CapturedTraceback.format_all to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        Every traceback that still holds a raw handle is symbolized in a single
        ``symbolize_tracebacks`` call; tracebacks that were cleaned up format
        as an empty list.
        """
        # Cleaned-up tracebacks are filled in directly; the rest are delayed
        # so they can be symbolized together.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        if delayed_idxs:
            import torch._C._profiler

            stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
            for i, stb in zip(delayed_idxs, stbs):
                # Use the batch result; calling tbs[i].summary() here would
                # symbolize each traceback a second time.
                rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
245

246

247
def _extract_symbolized_tb(tb, skip):
    """
    Turn one symbolized traceback into a ``traceback.StackSummary``.

    ``tb`` is an element of the result of ``symbolize_tracebacks``: a sequence
    of mapping-like entries with ``'filename'``, ``'line'`` (used as the line
    number) and ``'name'`` keys.  The first ``skip`` entries are discarded and
    the remaining ones are emitted in reverse order.
    """
    kept = tb[skip:]
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept)
    )
    return summary
256

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.