pytorch

Форк
0
/
_traceback.py 
254 строки · 10.0 Кб
1
from types import TracebackType
2
from typing import List, Optional
3
import tempfile
4
import traceback
5
import contextlib
6
import inspect
7
import os.path
8

9
# This file contains utilities for ensuring dynamically compile()'d
10
# code fragments display their line numbers in backtraces.
11
#
12
# The constraints:
13
#
14
# - We don't have control over the user exception printer (in particular,
15
#   we cannot assume the linecache trick will work, c.f.
16
#   https://stackoverflow.com/q/50515651/23845 )
17
#
18
# - We don't want to create temporary files every time we compile()
19
#   some code; file creation should happen lazily only at exception
20
#   time.  Arguably, you *should* be willing to write out your
21
#   generated Python code to file system, but in some situations
22
#   (esp. library code) it would violate user expectation to write
23
#   to the file system, so we try to avoid it.  In particular, we'd
24
#   like to keep the files around, so users can open up the files
25
#   mentioned in the trace; if the file is invisible, we want to
26
#   avoid clogging up the filesystem.
27
#
28
#   If this is not a constraint for you, there is a substantially simpler
29
#   way to implement the functionality in this PR: instead of using
30
#   eval/exec directly, just always write a Python file to filesystem
31
#   and compile that.
32
#
33
# - You have control over a context where the compiled code will get
34
#   executed, so that we can interpose while the stack is unwinding
35
#   (otherwise, we have no way to interpose on the exception printing
36
#   process.)
37
#
38
# There are two things you have to do to make use of the utilities here:
39
#
40
# - When you compile your source code, you must save its string source
41
#   in its f_globals under the magic name "__compile_source__"
42
#
43
# - Before running the compiled code, enter the
44
#   report_compile_source_on_error() context manager.
45

46
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Make exec/eval'd code show its real source lines in tracebacks.

    If an exception escapes the managed block, every traceback entry whose
    code object has ``co_filename == "<string>"`` and whose frame globals
    contain ``"__compile_source__"`` is replaced by a synthetic entry that
    points at a temporary ``.py`` file holding that source.  The exception is
    then re-raised with the rebuilt traceback.

    Temporary files are only written when an exception actually occurs, one
    per distinct source string, and are intentionally left on disk so the
    user can open the files mentioned in the trace.

    Raises:
        The exception that escaped the managed block (same object), with its
        traceback rewritten.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        # Maps each distinct __compile_source__ string to the temporary file
        # already written for it, so that frames sharing a source also share
        # one file instead of creating a new file per frame.
        source_files = {}
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it
                source_path = source_files.get(source)
                if source_path is None:
                    # Write explicitly as UTF-8 (the default encoding of
                    # Python source files) so that non-ASCII source cannot
                    # fail to encode under the locale's default encoding and
                    # mask the exception we are trying to report.
                    with tempfile.NamedTemporaryFile(
                        mode='w', delete=False, suffix=".py", encoding="utf-8"
                    ) as f:
                        f.write(source)
                    source_path = source_files[source] = f.name
                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                code = compile('__inspect_currentframe()', source_path, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Evaluating the stub code returns the frame it runs in:
                # a frame whose code object carries the temp file name and
                # the original function name.
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: TRY200
132

133
def shorten_filename(fn, *, base=None):
    """
    Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    Args:
        fn: The path to shorten.
        base: Directory to compare against.  Defaults to the parent of the
            directory containing this file.

    Returns:
        ``fn`` with the leading path it shares with ``base`` (and the
        separator that follows it) removed.  ``fn`` is returned unchanged
        when the two paths share no prefix or cannot be compared.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        # commonpath refuses e.g. a mix of absolute and relative paths.
        return fn
    if not prefix:
        # Two relative paths with nothing in common: there is nothing to
        # strip, and slicing would eat the first character of fn.
        return fn
    # A prefix that is just the root (e.g. "/") already ends with the
    # separator, so only skip an extra character when it does not.
    cut = len(prefix) if prefix.endswith(os.path.sep) else len(prefix) + 1
    return fn[cut:]
144

145
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly, without the full absolute path or code.

    The result is meant to fit on a single line:
    ``<short filename>:<lineno> in <name>``.  When ``line`` is true, the
    frame's source line followed by ``"  # "`` is put in front of it.
    ``base`` is forwarded to shorten_filename.
    """
    source_prefix = f"{frame.line}  # " if line else ""
    short_name = shorten_filename(frame.filename, base=base)
    return f"{source_prefix}{short_name}:{frame.lineno} in {frame.name}"
155

156
def format_traceback_short(tb):
    """Render a TracebackType compactly, showing only its inner-most frame."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
159

160
class CapturedTraceback:
    """
    A stack trace captured via torch._C._profiler that is symbolized lazily.

    Slots:
        tb: The raw traceback object returned by gather_traceback, or None
            after cleanup() or after unpickling.
        skip: Passed to _extract_symbolized_tb as the number of leading
            symbolized entries to drop when building a summary.
    """
    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the reference to the raw traceback; later summaries are empty."""
        self.tb = None

    def summary(self):
        """Symbolize the captured traceback and return it as a traceback.StackSummary."""
        import torch._C._profiler

        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # (dict state, slots state): this class only has slots, so the dict
        # part is None.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the raw
        traceback gathered by torch._C._profiler; format it with
        CapturedTraceback.format (or CapturedTraceback.format_all for many).

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to
        the output of traceback.format_list.  Note that if you have many
        CapturedTracebacks with C++ traces, it is better not to use this
        method and use the batch formatting API CapturedTraceback.format_all
        to amortize the cost of symbolization
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.
        """
        import torch._C._profiler

        # Tracebacks that were cleaned up format to an empty list right away;
        # the rest are remembered by index and symbolized in one batch.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch result directly; calling summary() here would
            # symbolize every traceback a second time, one at a time.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
244

245

246
def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (as produced by symbolize_tracebacks)
    into a traceback.StackSummary.

    The first ``skip`` entries of ``tb`` are dropped and the remaining ones
    are emitted in reverse order.  Each entry is indexed by 'filename',
    'line' and 'name'.
    """
    summary = traceback.StackSummary()
    kept_entries = tb[skip:]
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept_entries)
    )
    return summary
255

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.