tensor-sensor

Форк
0
/
analysis.py 
538 строк · 26.1 Кб
1
"""
2
MIT License
3

4
Copyright (c) 2021 Terence Parr
5

6
Permission is hereby granted, free of charge, to any person obtaining a copy
7
of this software and associated documentation files (the "Software"), to deal
8
in the Software without restriction, including without limitation the rights
9
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
copies of the Software, and to permit persons to whom the Software is
11
furnished to do so, subject to the following conditions:
12

13
The above copyright notice and this permission notice shall be included in all
14
copies or substantial portions of the Software.
15

16
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
SOFTWARE.
23
"""
24
import os
25
import sys
26
import traceback
27
import inspect
28
import hashlib
29
from pathlib import Path
30

31
import matplotlib.pyplot as plt
32

33
import tsensor
34

35

36
class clarify:
    # Prevent nested clarify() calls from processing exceptions.
    # See https://github.com/parrt/tensor-sensor/issues/18
    # Probably will fail with Python `threading` package due to this class var
    # but only if multiple threads call clarify().
    # Multiprocessing forks new processes so not a problem. Each vm has its own class var.
    # Bump in __enter__, drop in __exit__
    nesting = 0

    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 show:(None,'viz')='viz',
                 hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
        Also display a visual representation of the offending Python line that
        shows the shape of tensors referenced by the code. All you have to do is wrap
        the outermost level of your code and clarify() will activate upon exception.

        Visualizations pop up in a separate window unless running from a notebook,
        in which case the visualization appears as part of the cell execution output.

        There is no runtime overhead associated with clarify() unless an exception occurs.

        The offending code is executed a second time, to identify which sub expressions
        are to blame. This implies that code with side effects could conceivably cause
        a problem, but since an exception has been generated, results are suspicious
        anyway.

        Example:

        import numpy as np
        import tsensor

        b = np.array([9, 10]).reshape(2, 1)
        with tsensor.clarify():
            np.dot(b,b) # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code
        :param fontsize: The font size used to display Python code; default is 13.
                         Also use this to increase the size of the generated figure;
                         larger font size means larger image.
        :param dimfontname:  The name of the font used to display the dimensions on the matrix and vector boxes
        :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
        :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                               text is when plotted in matplotlib. In fact there's probably
                               no hope to discover this information accurately in all cases.
                               Certainly, I gave up after spending huge effort. We have a
                               situation here where the font should be constant width, so
                               we can just use a simple scaler times the font size to get
                               a reasonable approximation to the width and height of a
                               character box; the default of 1.8 seems to work reasonably
                               well for a wide range of fonts, but you might have to tweak it
                               when you change the font size.
        :param fontcolor:  The color of the Python code.
        :param underline_color:  The color of the lines that underscore tensor subexpressions; default is grey
        :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
        :param error_op_color: The color to use for characters associated with the erroneous operator
        :param show: Show visualization upon tensor error if show='viz'; pass None to only
                     augment the exception message.
        :param hush_errors: Normally, error messages from true syntax errors but also
                            unhandled code caught by my parser are ignored. Turn this off
                            to see what the error messages are coming from my parser.
        :param dtype_colors: map from dtype w/o precision like 'int' to color
        :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
        :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                                  and the alpha channel is used to show precision; the
                                  smaller the bit size, the lower the alpha channel. You
                                  can play with the range to get better visual dynamic range
                                  depending on how many precisions you want to display.
        """
        self.show = show
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        self.frame = sys._getframe().f_back # where do we start tracking? Hmm...not sure we use this
        clarify.nesting += 1
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        On a tensor-related exception in the outermost clarify(), visualize the
        offending statement and append an explanation to the exception message.
        Always returns None so the exception keeps propagating.
        """
        clarify.nesting -= 1
        if clarify.nesting > 0:
            return  # only the outermost clarify() processes the exception
        if exc_type is None:
            return
        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is None and not is_interesting_exception(exc_value):
            return
        _, _, _, _, code = info(exc_frame)
        if code is None:
            return  # no source line available for the failing frame
        self.view = tsensor.viz.pyviz(code, exc_frame,
                                      self.fontname, self.fontsize, self.dimfontname,
                                      self.dimfontsize,
                                      self.char_sep_scale, self.fontcolor,
                                      self.underline_color, self.ignored_color,
                                      self.error_op_color,
                                      hush_errors=self.hush_errors,
                                      dtype_colors=self.dtype_colors,
                                      dtype_precisions=self.dtype_precisions,
                                      dtype_alpha_range=self.dtype_alpha_range)
        if self.view is None:
            return  # Ignore if we can't process code causing exception (I use a subparser)
        if self.show == 'viz':
            self.view.show()
        # The view may not have pinpointed a culprit; augmenting with None would
        # raise inside __exit__ and hide the user's original exception.
        if self.view.offending_expr is not None:
            augment_exception(exc_value, self.view.offending_expr)
162

163

164
class explain:
    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 savefig=None, hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        As the Python virtual machine executes lines of code, generate a
        visualization for tensor-related expressions using numpy, pytorch,
        and tensorflow. The shape of tensors referenced by the code are displayed.

        Visualizations pop up in a separate window unless running from a notebook,
        in which case the visualization appears as part of the cell execution output.

        There is heavy runtime overhead associated with explain() as every line
        is executed twice: once by explain() and then another time by the interpreter
        as part of normal execution.

        Expressions with side effects can easily generate incorrect results. Due to
        this and the overhead, you should limit the use of this to code you're trying
        to debug.  Assignments are not evaluated by explain so code `x = ...` causes
        an assignment to x just once, during normal execution. This explainer
        knows the value of x and will display it but does not assign to it.

        Upon exception, execution will stop as usual but, like clarify(), explain()
        will augment the exception to indicate the offending sub expression. Further,
        the visualization will deemphasize code not associated with the offending
        sub expression. The sizes of relevant tensor values are still visualized.

        Any trace function installed before entering explain() (e.g. a debugger or
        coverage tool) is restored on exit.

        Example:

        import numpy as np
        import tsensor

        b = np.array([9, 10]).reshape(2, 1)
        with tsensor.explain():
            b + b # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code
        :param fontsize: The font size used to display Python code; default is 13.
                         Also use this to increase the size of the generated figure;
                         larger font size means larger image.
        :param dimfontname:  The name of the font used to display the dimensions on the matrix and vector boxes
        :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
        :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                               text is when plotted in matplotlib. In fact there's probably
                               no hope to discover this information accurately in all cases.
                               Certainly, I gave up after spending huge effort. We have a
                               situation here where the font should be constant width, so
                               we can just use a simple scaler times the font size to get
                               a reasonable approximation to the width and height of a
                               character box; the default of 1.8 seems to work reasonably
                               well for a wide range of fonts, but you might have to tweak it
                               when you change the font size.
        :param fontcolor:  The color of the Python code.
        :param underline_color:  The color of the lines that underscore tensor subexpressions; default is grey
        :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
        :param error_op_color: The color to use for characters associated with the erroneous operator
        :param hush_errors: Normally, error messages from true syntax errors but also
                            unhandled code caught by my parser are ignored. Turn this off
                            to see what the error messages are coming from my parser.
        :param savefig: A string indicating where to save the visualization; don't save
                        a file if None. Each visualized statement gets its own file with
                        "-<n>" inserted before the suffix.
        :param dtype_colors: map from dtype w/o precision like 'int' to color
        :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
        :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                                  and the alpha channel is used to show precision; the
                                  smaller the bit size, the lower the alpha channel. You
                                  can play with the range to get better visual dynamic range
                                  depending on how many precisions you want to display.
        """
        self.savefig = savefig
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        """Install the line tracer on the frame wrapped in "with"; return the tracer."""
        self.tracer = ExplainTensorTracer(self)
        # Remember whatever trace function was active (debugger, coverage, ...)
        # so __exit__ can restore it instead of switching tracing off entirely.
        self._saved_trace = sys.gettrace()
        sys.settrace(self.tracer.listener)
        frame = sys._getframe()
        prev = frame.f_back # get block wrapped in "with"
        self._frame = prev
        self._saved_frame_trace = prev.f_trace
        prev.f_trace = self.tracer.listener
        return self.tracer

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Restore the previous tracing state and, on a tensor-related exception,
        append an explanation of the offending sub expression to its message.
        Always returns None so any exception keeps propagating.
        """
        sys.settrace(getattr(self, '_saved_trace', None))
        traced_frame = getattr(self, '_frame', None)
        if traced_frame is not None:
            # Otherwise our listener stays attached to the caller's frame and
            # would keep firing if a restored global tracer re-enables tracing.
            traced_frame.f_trace = getattr(self, '_saved_frame_trace', None)
            self._frame = None  # don't keep the caller's frame alive
        # At this point we have already tried to visualize the statement
        # If there was no error, the visualization will look normal
        # but a matrix operation error will show the erroneous operator highlighted.
        # That was artificial execution of the code. Now the VM has executed
        # the statement for real and has found the same exception. Make sure to
        # augment the message with causal information.
        if exc_type is None:
            return
        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is None and not is_interesting_exception(exc_value):
            return
        _, _, _, _, code = info(exc_frame)
        if code is None:
            return
        # We've already displayed picture so just augment message
        root, tokens = tsensor.parsing.parse(code)
        if root is None: # Could be syntax error in statement or code I can't handle
            return
        offending_expr = None
        try:
            root.eval(exc_frame)
        except tsensor.ast.IncrEvalTrap as e:
            offending_expr = e.offending_expr
        # Without a culprit there is nothing to add, and augmenting with None
        # would raise here and hide the user's original exception.
        if offending_expr is not None:
            augment_exception(exc_value, offending_expr)
290

291

292
class ExplainTensorTracer:
    """
    Trace function installed by explain(): visualizes each distinct statement
    executed in the traced frame that tsensor's parser accepts.
    """
    def __init__(self, explainer):
        self.explainer = explainer
        self.exceptions = set()
        self.linecount = 0  # number of statements visualized; numbers savefig files
        self.views = []
        # set of md5 digests (see hash()) of statements already visualized;
        # each statement's code is visualized only once
        self.done = set()

    def listener(self, frame, event, arg):
        """sys.settrace-compatible callback; only 'line' events are processed."""
        if event != 'line':
            # It seems that we are getting CALL events even for calls in foo() from:
            #   with tsensor.explain(): foo()
            # Must be that we already have a listener and, though we returned None here,
            # somehow the original listener is still getting events. Strange but oh well.
            # We must ignore these.
            return None
        module = frame.f_globals.get('__name__')  # exec'd code may lack __name__
        frame_info = inspect.getframeinfo(frame)
        self.line_listener(module, frame_info.function, frame_info.filename,
                           frame_info.lineno, frame_info, frame)
        # By returning none, we prevent explain()'ing from descending into
        # invoked functions. In principle, we could allow a certain amount
        # of tracing but I'm not sure that would be super useful.
        return None

    def line_listener(self, module, name, filename, line, info, frame):
        """Visualize the current source line of `frame` unless already seen or unparseable."""
        if not info.code_context:
            return  # no source text available (e.g. code compiled from a string)
        code = info.code_context[0].strip()
        if code.startswith("sys.settrace(None)"):
            return

        # Don't generate a statement visualization more than once
        key = self.hash(code)
        if key in self.done:
            return
        self.done.add(key)

        p = tsensor.parsing.PyExprParser(code)
        t = p.parse()
        if t is not None:
            self.linecount += 1
            self.viz_statement(code, frame)

    def viz_statement(self, code, frame):
        """
        Build the visualization for `code` evaluated in `frame`, then save it
        (when the explainer has savefig set) or show it. Returns the view, or
        None if pyviz produced nothing.
        """
        view = tsensor.viz.pyviz(code, frame,
                                 self.explainer.fontname, self.explainer.fontsize,
                                 self.explainer.dimfontname,
                                 self.explainer.dimfontsize,
                                 self.explainer.char_sep_scale, self.explainer.fontcolor,
                                 self.explainer.underline_color, self.explainer.ignored_color,
                                 self.explainer.error_op_color,
                                 hush_errors=self.explainer.hush_errors,
                                 dtype_colors=self.explainer.dtype_colors,
                                 dtype_precisions=self.explainer.dtype_precisions,
                                 dtype_alpha_range=self.explainer.dtype_alpha_range)
        if view is None:
            return None  # pyviz can return None when it can't process the statement
        self.views.append(view)
        if self.explainer.savefig is not None:
            file_path = Path(self.explainer.savefig)
            file_path = file_path.parent / f"{file_path.stem}-{self.linecount}{file_path.suffix}"
            view.savefig(file_path)
            view.filename = file_path
            plt.close()
        else:
            view.show()
        return view

    @staticmethod
    def hash(statement):
        """
        We want to avoid generating a visualization more than once.
        For now, assume that the code for a statement is the unique identifier.
        """
        return hashlib.md5(statement.encode('utf-8')).hexdigest()
375

376

377
def eval(statement:str, frame=None) -> (tsensor.ast.ParseTreeNode, object):
    """
    Parse statement and evaluate the resulting ast in the context of execution frame
    or, if None, the invoking function's frame. Sets the value field of all ast nodes.

    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function.
    :return: A tuple (root, value) where root is the abstract parse tree representing
             the statement (nodes are ParseTreeNode subclasses) and value is
             root.value, the overall result.
    :raises SyntaxError: if tsensor's parser cannot parse the statement.
    """
    if frame is None: # use frame of caller
        frame = sys._getframe().f_back
    p = tsensor.parsing.PyExprParser(statement)
    root = p.parse()
    if root is None:
        # Fail with a clear message rather than "'NoneType' has no attribute 'eval'"
        raise SyntaxError(f"tsensor cannot parse statement: {statement!r}")
    root.eval(frame)
    return root, root.value
394

395

396
def augment_exception(exc_value, subexpr):
    """
    Append the explanation produced by subexpr.clarify() to exc_value's message,
    modifying the exception object in place.

    Exceptions that carry a `_message` attribute (used by some libraries instead
    of `args`) get that attribute extended; otherwise the message in args[0] is
    extended, or set if there are no args. Nothing is changed when subexpr is
    None, when it yields no explanation, or when args[0] is not a string. This is
    best effort: it runs while an exception is propagating, so it must not raise
    and mask the original error.

    :param exc_value: The exception instance to augment.
    :param subexpr: The offending sub expression (has a clarify() method), or None.
    """
    if subexpr is None:
        return
    explanation = subexpr.clarify()
    if explanation is None:
        return
    # Reuse exception but overwrite the message
    if hasattr(exc_value, "_message"):
        current = getattr(exc_value, "message", exc_value._message)
        exc_value._message = current + "\n" + explanation
    elif not exc_value.args:
        exc_value.args = (explanation,)
    elif isinstance(exc_value.args[0], str):
        exc_value.args = (exc_value.args[0] + "\n" + explanation,)
406

407

408
def is_interesting_exception(e):
    """
    Heuristically decide whether exception e looks tensor related.

    Any exception whose class lives in a tensorflow module qualifies; otherwise
    the message (args[0], or a `message` attribute if there are no args) is
    searched for tensor/shape related keywords.

    :param e: An exception instance.
    :return: True if e looks tensor related, else False.
    """
    if e.__class__.__module__.startswith("tensorflow"):
        return True
    sentinels = {'matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension',
                 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix',
                 'call to _th_addmm'}
    if e.args:
        msg = e.args[0]
    else:
        # Python 3 exceptions have no .message; don't crash on e.g. ValueError()
        msg = getattr(e, "message", "")
    if not isinstance(msg, str):
        msg = str(msg)  # e.g. KeyError(3); `in` on a non-str would raise TypeError
    return any(s in msg for s in sentinels)
420

421

422
def tensor_lib_entry_frame(exc_traceback):
    """
    Walk a traceback to find where user code handed off to a tensor library.

    We don't want to trace into the internals of numpy/torch/tensorflow/jax;
    instead we reset the frame to the spot in the user's Python code that asked
    the tensor lib to perform an invalid operation. A frame counts as library
    code when its filename contains "site-packages/<package>" or
    "dist-packages/<package>" (or numpy's "<__array_function__" pseudo-file).

    :param exc_traceback: The traceback of the exception being processed.
    :return: (last-user-frame, first-tensor-lib-frame) if a library frame is
             found, else (last-frame, None).

    Note: Sometimes operators yield exceptions and no tensor lib entry frame. E.g.,
    np.ones(1) @ np.ones(2).
    """
    libraries = ('numpy', 'torch', 'tensorflow', 'jax')
    markers = [os.path.join(root, lib)
               for root in ('site-packages', 'dist-packages')
               for lib in libraries]
    markers.append('<__array_function__')  # numpy seems to not have real filename

    last_user = exc_traceback
    node = exc_traceback
    while node is not None:
        source_file = node.tb_frame.f_code.co_filename
        if any(marker in source_file for marker in markers):
            return last_user.tb_frame, node.tb_frame
        last_user, node = node, node.tb_next
    return last_user.tb_frame, None
453

454

455
def info(frame):
    """
    Summarize where an execution frame currently is.

    :param frame: A Python frame object.
    :return: (module, name, filename, line, code) where module is the frame's
             global __name__ (None if absent), name is the function name,
             filename/line locate the current line, and code is that source
             line stripped of surrounding whitespace (None if unavailable).
    """
    # Frames have no __name__ attribute themselves; the module name lives in
    # the frame's globals. (hasattr(frame, '__name__') was always False.)
    module = frame.f_globals.get('__name__')
    frame_info = inspect.getframeinfo(frame)
    # code_context is None (or empty) when source text is unavailable
    code = frame_info.code_context[0].strip() if frame_info.code_context else None
    return module, frame_info.function, frame_info.filename, frame_info.lineno, code
468

469

470
def smallest_matrix_subexpr(t):
    """
    Collect the smallest sub expressions of tree t that evaluate to a non-scalar.

    During visualization we need the deepest subtrees whose value is a tensor.
    Lacking parent pointers, we can't walk upward from the leaves; instead the
    recursive helper reports back whether a node or any descendant evaluates to
    a non-scalar. Nodes with a tensor value and no tensor below them are the
    ones to visualize.

    :param t: Root of a parse tree whose nodes have already been evaluated.
    :return: List of the nodes to visualize, in the order the helper finds them.
    """
    found = []
    _smallest_matrix_subexpr(t, found)
    return found
483

484

485
def _smallest_matrix_subexpr(t, nodes) -> bool:
    """
    Recursive worker for smallest_matrix_subexpr().

    Appends to `nodes` each node under `t` whose value is a tensor while no
    descendant's is, plus any `<atom>.T` member access (kept as one unit).
    Returns True if `t` or any of its descendants evaluates to a tensor.
    """
    if t is None:
        return False  # prevent buggy code from causing us to fail

    # `<atom>.T` is recorded as a whole without descending into its parts.
    transpose_of_atom = (isinstance(t, tsensor.ast.Member)
                         and isinstance(t.obj, tsensor.ast.Atom)
                         and isinstance(t.member, tsensor.ast.Atom)
                         and str(t.member) == 'T')
    if transpose_of_atom:
        nodes.append(t)
        return True

    if len(t.kids) == 0:  # leaf node
        leaf_is_tensor = istensor(t.value)
        if leaf_is_tensor:
            nodes.append(t)
        return leaf_is_tensor

    # Visit every kid in order (a list, not any(), so nothing is short-circuited
    # and every subtree gets to record its own nodes).
    kid_flags = [_smallest_matrix_subexpr(kid, nodes) for kid in t.kids]
    tensor_below = any(kid_flags)
    here_is_tensor = istensor(t.value)
    # Tensor here but none below: this is a smallest tensor-valued sub expression.
    if here_is_tensor and not tensor_below:
        nodes.append(t)
    # Once true, this is propagated all the way up to the root.
    return here_is_tensor or tensor_below
507

508

509
def istensor(x):
    """Return True when _shape() reports a shape for x, i.e. x counts as a tensor."""
    shape = _shape(x)
    return shape is not None
511

512

513
def _dtype(v) -> str:
    """
    Return a short display name for the element type of v, or None.

    v may be an object with a `dtype` attribute (array/tensor) or a dtype object
    itself, recognized by "dtype" appearing (case-insensitively) in its class
    name. The case-insensitive check also matches class names such as
    "Float64DType", which some NumPy versions use for dtype classes.

    :param v: Any value.
    :return: e.g. "float32"; torch dtypes lose their "torch." prefix; structured
             dtypes give their field types joined by commas; None if v has no
             recognizable dtype.
    """
    if hasattr(v, "dtype"):
        dtype = v.dtype
    elif "dtype" in v.__class__.__name__.lower():
        dtype = v
    else:
        return None
    return _dtype_name(dtype)


def _dtype_name(dtype) -> str:
    """Return the display name of an already-identified dtype object (see _dtype)."""
    if dtype.__class__.__module__ == "torch":
        # ugly but works
        return str(dtype).replace("torch.", "")
    names = getattr(dtype, "names", None)
    if names is not None and hasattr(dtype, "fields"):
        # structured dtype: https://numpy.org/devdocs/user/basics.rec.html
        # Iterate over names, not fields.values(): titled fields appear twice in
        # `fields` (under name and title) as 3-tuples (dtype, offset, title).
        return ",".join(_dtype_name(dtype.fields[field][0]) for field in names)
    return dtype.name
528

529

530
def _shape(v):
    """
    Return the shape of v if it looks like a tensor, else None.

    Anything with a `shape` attribute that supports len() counts. A torch.Size
    is converted to a list, except that an empty torch.Size (0-d tensor) yields
    None. Other shape objects are returned unchanged (so an empty non-torch
    shape such as () is still returned).
    """
    shape = getattr(v, "shape", None)
    if shape is None or not hasattr(shape, "__len__"):
        return None
    shape_cls = shape.__class__
    is_torch_size = shape_cls.__module__ == "torch" and shape_cls.__name__ == "Size"
    if not is_torch_size:
        return shape
    return list(shape) if len(shape) > 0 else None
539

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.