tensor-sensor
538 строк · 26.1 Кб
1"""
2MIT License
3
4Copyright (c) 2021 Terence Parr
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23"""
24import os25import sys26import traceback27import inspect28import hashlib29from pathlib import Path30
31import matplotlib.pyplot as plt32
33import tsensor34
35
class clarify:
    # Depth of currently active clarify() `with` blocks, shared by all
    # instances. Only the outermost block processes an exception, which keeps
    # nested clarify() calls from handling the same exception more than once.
    # See https://github.com/parrt/tensor-sensor/issues/18
    # Because this is a class variable it will probably fail with the Python
    # `threading` package, but only if multiple threads call clarify().
    # Multiprocessing forks new processes, so each VM has its own copy.
    # Bumped in __enter__, dropped in __exit__.
    nesting = 0

    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 show:(None,'viz')='viz',
                 hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        Context manager that augments tensor-related exceptions raised by
        numpy, pytorch, and tensorflow code, and optionally displays a
        visualization of the offending Python line showing the shapes of the
        tensors it references. Wrap the outermost level of your code; the
        manager only does work when an exception leaves the `with` block.

        Visualizations pop up in a separate window unless running from a
        notebook, in which case they appear as part of the cell output.

        The offending line is executed a second time to identify which
        subexpression is to blame, so code with side effects could
        conceivably cause a problem; since an exception has already been
        raised, results are suspicious anyway.

        Example:

            import numpy as np
            import tsensor

            b = np.array([9, 10]).reshape(2, 1)
            with tsensor.clarify():
                np.dot(b,b) # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code.
        :param fontsize: The font size used to display Python code; default
                         is 13. A larger font size means a larger image.
        :param dimfontname: The font used for the dimensions on the matrix
                            and vector boxes.
        :param dimfontsize: The font size used for the dimensions on the
                            matrix and vector boxes.
        :param char_sep_scale: Text extents are hard to measure reliably in
                               matplotlib. The code font is expected to be
                               constant width, so a simple scale factor times
                               the font size approximates the size of a
                               character box. The default of 1.8 works
                               reasonably well for a wide range of fonts but
                               may need tweaking when the font size changes.
        :param fontcolor: The color of the Python code.
        :param underline_color: The color of the lines that underscore tensor
                                subexpressions; default is grey.
        :param ignored_color: The de-highlighted color for code not involved
                              in an erroneous subexpression.
        :param error_op_color: The color for characters associated with the
                               erroneous operator.
        :param show: Show the visualization upon tensor error if show='viz'.
        :param hush_errors: Normally, messages from true syntax errors and
                            from code the parser cannot handle are ignored.
                            Turn this off to see those parser messages.
        :param dtype_colors: Map from dtype without precision, like 'int',
                             to color.
        :param dtype_precisions: List of bit precisions to colorize, such as
                                 [32,64,128].
        :param dtype_alpha_range: All tensors of the same type are drawn in
                                  the same color and the alpha channel shows
                                  precision; the smaller the bit size, the
                                  lower the alpha. Adjust the range to get
                                  better visual dynamic range.
        """
        self.show = show
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        """Record the frame containing the `with` and increase the nesting depth."""
        # Frame of the code that entered the `with`; kept for reference only,
        # nothing in this class reads it back.
        self.frame = sys._getframe().f_back
        clarify.nesting += 1
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        On a tensor-related exception leaving the outermost clarify() block,
        build the visualization, optionally show it, and append an explanation
        to the exception message. Always returns None so the exception
        continues to propagate.
        """
        clarify.nesting -= 1
        # Inner (nested) blocks and clean exits have nothing to do.
        if clarify.nesting > 0 or exc_type is None:
            return

        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is None and not is_interesting_exception(exc_value):
            return

        _, _, _, _, code = info(exc_frame)
        if code is None:
            return

        self.view = tsensor.viz.pyviz(code, exc_frame,
                                      self.fontname, self.fontsize, self.dimfontname,
                                      self.dimfontsize,
                                      self.char_sep_scale, self.fontcolor,
                                      self.underline_color, self.ignored_color,
                                      self.error_op_color,
                                      hush_errors=self.hush_errors,
                                      dtype_colors=self.dtype_colors,
                                      dtype_precisions=self.dtype_precisions,
                                      dtype_alpha_range=self.dtype_alpha_range)
        # pyviz yields None when the code causing the exception can't be processed.
        if self.view is None:
            return
        if self.show == 'viz':
            self.view.show()
        augment_exception(exc_value, self.view.offending_expr)
163
class explain:
    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 savefig=None, hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        As the Python virtual machine executes lines of code, generate a
        visualization for tensor-related expressions using numpy, pytorch,
        and tensorflow. The shapes of tensors referenced by the code are
        displayed.

        Visualizations pop up in a separate window unless running from a
        notebook, in which case they appear as part of the cell output.

        There is heavy runtime overhead associated with explain() as every
        line is executed twice: once by explain() and then again by the
        interpreter as part of normal execution.

        Expressions with side effects can easily generate incorrect results.
        Due to this and the overhead, limit the use of explain() to code you
        are trying to debug. Assignments are not evaluated by explain(), so
        `x = ...` assigns to x just once, during normal execution.

        Upon exception, execution stops as usual but, like clarify(),
        explain() augments the exception to indicate the offending
        subexpression.

        Example:

            import numpy as np
            import tsensor

            b = np.array([9, 10]).reshape(2, 1)
            with tsensor.explain():
                b + b # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code.
        :param fontsize: The font size used to display Python code; default
                         is 13. A larger font size means a larger image.
        :param dimfontname: The font used for the dimensions on the matrix
                            and vector boxes.
        :param dimfontsize: The font size used for the dimensions on the
                            matrix and vector boxes.
        :param char_sep_scale: Scale factor times the font size used to
                               approximate the width and height of a
                               character box for a constant-width font; the
                               default of 1.8 works reasonably well but may
                               need tweaking when the font size changes.
        :param fontcolor: The color of the Python code.
        :param underline_color: The color of the lines that underscore tensor
                                subexpressions; default is grey.
        :param ignored_color: The de-highlighted color for code not involved
                              in an erroneous subexpression.
        :param error_op_color: The color for characters associated with the
                               erroneous operator.
        :param savefig: A string indicating where to save the visualization;
                        don't save a file if None.
        :param hush_errors: Normally, messages from true syntax errors and
                            from code the parser cannot handle are ignored.
                            Turn this off to see those parser messages.
        :param dtype_colors: Map from dtype without precision, like 'int',
                             to color.
        :param dtype_precisions: List of bit precisions to colorize, such as
                                 [32,64,128].
        :param dtype_alpha_range: All tensors of the same type are drawn in
                                  the same color and the alpha channel shows
                                  precision; the smaller the bit size, the
                                  lower the alpha.
        """
        self.savefig = savefig
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        """
        Install a line tracer on the frame containing the `with` statement.

        :return: The ExplainTensorTracer doing the tracing (not `self`).
        """
        self.tracer = ExplainTensorTracer(self)
        sys.settrace(self.tracer.listener)
        frame = sys._getframe()
        prev = frame.f_back  # get block wrapped in "with"
        # The frame is already running, so it needs the listener set directly
        # as its local trace function.
        prev.f_trace = self.tracer.listener
        return self.tracer

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Turn tracing off and, if a tensor-related exception is leaving the
        block, append an explanation of the offending subexpression to its
        message. Always returns None so the exception continues to propagate.
        """
        sys.settrace(None)
        # At this point we have already tried to visualize the statement.
        # That was artificial execution of the code. Now the VM has executed
        # the statement for real and has found the same exception. Make sure
        # to augment the message with causal information.
        if exc_type is None:
            return
        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is None and not is_interesting_exception(exc_value):
            return
        module, name, filename, line, code = info(exc_frame)
        if code is None:
            return
        # We've already displayed the picture so just augment the message.
        root, tokens = tsensor.parsing.parse(code)
        if root is None:  # Could be syntax error in statement or code the parser can't handle
            return
        offending_expr = None
        try:
            root.eval(exc_frame)
        except tsensor.ast.IncrEvalTrap as e:
            offending_expr = e.offending_expr
        # If re-evaluation did not trap there is no subexpression to blame;
        # skip augmentation rather than fail here and mask the user's
        # original exception.
        if offending_expr is not None:
            augment_exception(exc_value, offending_expr)
291
class ExplainTensorTracer:
    """
    Trace listener used by explain(): for each traced source line that the
    tsensor expression parser accepts, build a visualization once and either
    show it or save it to a file.
    """

    def __init__(self, explainer):
        """
        :param explainer: The explain() object whose display settings
                          (fonts, colors, savefig, ...) are used for every
                          visualization.
        """
        self.explainer = explainer
        self.exceptions = set()
        self.linecount = 0  # number of statements visualized so far
        self.views = []     # views created by viz_statement(), in order
        # set of hashes for statements already visualized;
        # each distinct statement text is visualized once
        self.done = set()

    def listener(self, frame, event, arg):
        """
        Trace function suitable for sys.settrace(); reacts to 'line' events only.

        Always returns None, which prevents explain()'ing from descending
        into invoked functions.
        """
        if event != 'line':
            # CALL events can still arrive for calls made inside the `with`
            # block; they must be ignored.
            return None
        # .get() rather than [] so globals without '__name__' (e.g. code run
        # via exec with a custom dict) cannot raise from inside the tracer.
        module = frame.f_globals.get('__name__')
        info = inspect.getframeinfo(frame)
        self.line_listener(module, info.function, info.filename, info.lineno, info, frame)
        return None

    def line_listener(self, module, name, filename, line, info, frame):
        """
        Visualize the source line described by `info` unless it has no
        source text, turns tracing off, was already visualized, or cannot be
        parsed as an expression/statement by tsensor.
        """
        # code_context is None when the source text is unavailable
        if not info.code_context:
            return
        code = info.code_context[0].strip()
        if code.startswith("sys.settrace(None)"):
            return

        # Don't generate a statement visualization more than once
        h = self.hash(code)
        if h in self.done:
            return
        self.done.add(h)

        p = tsensor.parsing.PyExprParser(code)
        t = p.parse()
        if t is not None:
            self.linecount += 1
            self.viz_statement(code, frame)

    def viz_statement(self, code, frame):
        """
        Build the visualization for `code` evaluated in `frame`, record it in
        self.views, then save it (when the explainer has a savefig path) or
        show it.

        :return: The view returned by tsensor.viz.pyviz.
        """
        view = tsensor.viz.pyviz(code, frame,
                                 self.explainer.fontname, self.explainer.fontsize,
                                 self.explainer.dimfontname,
                                 self.explainer.dimfontsize,
                                 self.explainer.char_sep_scale, self.explainer.fontcolor,
                                 self.explainer.underline_color, self.explainer.ignored_color,
                                 self.explainer.error_op_color,
                                 hush_errors=self.explainer.hush_errors,
                                 dtype_colors=self.explainer.dtype_colors,
                                 dtype_precisions=self.explainer.dtype_precisions,
                                 dtype_alpha_range=self.explainer.dtype_alpha_range)
        self.views.append(view)
        if self.explainer.savefig is not None:
            # Insert the running statement count before the file suffix so
            # that each statement gets its own file: name-1.svg, name-2.svg, ...
            file_path = Path(self.explainer.savefig)
            file_path = file_path.parent / f"{file_path.stem}-{self.linecount}{file_path.suffix}"
            view.savefig(file_path)
            view.filename = file_path
            plt.close()
        else:
            view.show()
        return view

    @staticmethod
    def hash(statement):
        """
        We want to avoid generating a visualization more than once.
        For now, assume that the code for a statement is the unique identifier.

        :return: Hex MD5 digest of the UTF-8 encoded statement text.
        """
        return hashlib.md5(statement.encode('utf-8')).hexdigest()
376
def eval(statement:str, frame=None) -> (tsensor.ast.ParseTreeNode, object):
    """
    Parse `statement` and evaluate the resulting tree in the context of an
    execution frame, setting the value field of the tree's nodes.

    :param statement: A string representing the line of Python code to
                      evaluate within an execution frame.
    :param frame: The execution frame in which to evaluate the statement.
                  If None, use the execution frame of the invoking function.
    :return: A pair (root, root.value): the root of the parse tree produced
             by tsensor.parsing.PyExprParser and the overall value computed
             for the statement.
    """
    if frame is None:
        # Default to the caller's frame. This must be looked up right here,
        # not in a helper, so that f_back is the invoking function.
        frame = sys._getframe().f_back
    tree = tsensor.parsing.PyExprParser(statement).parse()
    tree.eval(frame)
    return tree, tree.value
395
def augment_exception(exc_value, subexpr):
    """
    Append an explanation of the offending subexpression to the message of
    an existing exception, reusing the exception object.

    :param exc_value: The exception to modify in place. If it has a
                      `_message` attribute, that attribute is rewritten from
                      `exc_value.message`; otherwise `exc_value.args` is
                      replaced by a single augmented message.
    :param subexpr: Object whose clarify() method returns the explanation
                    text or None. If subexpr itself is None there is nothing
                    to explain and the exception is left untouched.
    """
    if subexpr is None:
        return
    explanation = subexpr.clarify()
    augment = explanation if explanation is not None else ""
    # Reuse exception but overwrite the message
    if hasattr(exc_value, "_message"):
        exc_value._message = exc_value.message + "\n" + augment
    elif exc_value.args:
        # str() so a non-string first argument (e.g. KeyError(5)) cannot make
        # the concatenation itself raise TypeError.
        exc_value.args = (str(exc_value.args[0]) + "\n" + augment,)
    else:
        # Exception was raised without a message; the explanation becomes it.
        exc_value.args = (augment,)
407
def is_interesting_exception(e):
    """
    Heuristically decide whether exception `e` looks tensor-related.

    Any exception whose class is defined in a module starting with
    "tensorflow" qualifies. Otherwise the message (first element of `e.args`,
    or `e.message` when there are no args) is searched for a set of
    tell-tale substrings.

    :param e: The exception instance to inspect.
    :return: True if the exception looks tensor-related, else False.
    """
    if e.__class__.__module__.startswith("tensorflow"):
        return True
    sentinels = {'matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension',
                 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix',
                 'call to _th_addmm'}
    if e.args:
        msg = e.args[0]
    else:
        # Not every exception has a `message` attribute; treat it as empty.
        msg = getattr(e, "message", "")
    # The first arg need not be a string (e.g. KeyError(3), OSError(errno, ...));
    # without str() the substring test below would raise TypeError.
    msg = str(msg)
    return any(s in msg for s in sentinels)
421
def tensor_lib_entry_frame(exc_traceback):
    """
    Walk a traceback from the outermost frame inward and find where control
    first entered numpy/torch/tensorflow/jax. We don't want to trace into the
    internals of those libraries; the interesting frame is where the user's
    Python code asked the tensor library to perform an invalid operation.

    Library code is recognized by a filename containing
    "site-packages/{package}" or "dist-packages/{package}" (joined with the
    platform's path separator), or "<__array_function__", which numpy seems
    to use in place of a real filename.

    Note: sometimes operators raise exceptions with no tensor-library entry
    frame at all, e.g. np.ones(1) @ np.ones(2).

    :param exc_traceback: The traceback object of the exception.
    :return: (last user frame, first tensor-library frame) if a library frame
             is found, else (last frame in the traceback, None).
    """
    lib_names = ['numpy', 'torch', 'tensorflow', 'jax']
    markers = [os.path.join(root, pkg)
               for root in ('site-packages', 'dist-packages')
               for pkg in lib_names]
    markers.append('<__array_function__')  # numpy seems to not have real filename

    last_seen = current = exc_traceback
    while current is not None:
        path = current.tb_frame.f_code.co_filename
        if any(marker in path for marker in markers):
            return last_seen.tb_frame, current.tb_frame
        last_seen, current = current, current.tb_next
    return last_seen.tb_frame, None
454
def info(frame):
    """
    Collect basic location information about an execution frame.

    :param frame: A Python frame object.
    :return: A tuple (module, name, filename, line, code) where module is
             the `__name__` found in the frame's globals (None if absent),
             name is the function name, filename and line locate the current
             line, and code is the stripped source text of that line or None
             when the source is not available.
    """
    # Frame objects have no __name__ attribute of their own; the module name
    # lives in the frame's globals.
    module = frame.f_globals.get('__name__')
    frame_info = inspect.getframeinfo(frame)
    context = frame_info.code_context
    # code_context is None when source text can't be found
    code = context[0].strip() if context else None
    return module, frame_info.function, frame_info.filename, frame_info.lineno, code
469
def smallest_matrix_subexpr(t):
    """
    Collect the smallest subexpressions of tree `t` that evaluate to a
    non-scalar, i.e. the deepest subtrees with a tensor value. These are the
    nodes to visualize.

    Because nodes have no parent pointers we cannot start at the leaves and
    walk upwards. Instead the recursive helper passes a Boolean back up to
    indicate whether a node or one of its descendants evaluates to a
    non-scalar; nodes that have a matrix value and no matrix below them are
    collected.

    :param t: Root of the (already evaluated) parse tree; may be None.
    :return: List of the collected nodes, in the order the helper found them.
    """
    found = []
    _smallest_matrix_subexpr(t, found)
    return found
484
def _smallest_matrix_subexpr(t, nodes) -> bool:
    """
    Recursive worker for smallest_matrix_subexpr().

    Appends to `nodes` every node that has a tensor value while none of its
    descendants do. A member access of the form `<atom>.T` is always
    collected as a unit without looking at its children.

    :param t: Current parse tree node, or None.
    :param nodes: Output list the selected nodes are appended to.
    :return: True if this node or any descendant evaluates to a tensor.
    """
    if t is None:
        return False  # prevent buggy code from causing us to fail

    is_transpose = (isinstance(t, tsensor.ast.Member)
                    and isinstance(t.obj, tsensor.ast.Atom)
                    and isinstance(t.member, tsensor.ast.Atom)
                    and str(t.member) == 'T')
    if is_transpose:
        nodes.append(t)
        return True

    # Visit every child (a list, not a short-circuiting any(), because each
    # call may append to `nodes`) and count how many subtrees hold a tensor.
    # A leaf has no kids, so its count is simply 0.
    tensors_below = sum([_smallest_matrix_subexpr(kid, nodes) for kid in t.kids])

    value_is_tensor = istensor(t.value)
    # A tensor-valued node with no tensor-valued descendants is a smallest
    # subexpression that evaluates to a tensor; keep track of it.
    if value_is_tensor and tensors_below == 0:
        nodes.append(t)
    # Report to caller that this node or some descendant is a tensor
    return value_is_tensor or tensors_below > 0
508
def istensor(x):
    """Return True if `x` has a usable shape according to _shape(), else False."""
    shape = _shape(x)
    return shape is not None
512
def _dtype(v) -> str:
    """
    Return a short name for the dtype of `v`.

    :param v: Either an object with a `dtype` attribute, or a dtype object
              itself, recognized by "dtype" appearing in its class name.
    :return: The dtype name with any "torch." prefix removed; for a
             structured dtype, the comma-separated names of its field
             dtypes; None if `v` is neither kind of object.
    """
    if hasattr(v, "dtype"):
        dtype = v.dtype
    elif "dtype" in v.__class__.__name__.lower():
        # Case-insensitive so classes named like `DType` / `Float64DType`
        # are recognized as dtypes too, not only ones named `dtype`.
        dtype = v
    else:
        return None

    if dtype.__class__.__module__ == "torch":
        # ugly but works
        return str(dtype).replace("torch.", "")
    if hasattr(dtype, "names") and dtype.names is not None and hasattr(dtype, "fields"):
        # structured dtype: https://numpy.org/devdocs/user/basics.rec.html
        # Each fields value is a tuple whose first element is the field's
        # dtype; index rather than unpack since the tuple may carry an
        # optional title as a third element.
        return ",".join([_dtype(field[0]) for field in dtype.fields.values()])
    return dtype.name
529
def _shape(v):
    """
    Return the shape of `v`, or None if it has no usable shape.

    A shape is usable when `v.shape` exists and answers len(). A shape whose
    class is named `Size` in module `torch` is converted to a plain list,
    and an empty one yields None; any other shape object is returned as-is.
    """
    dims = getattr(v, "shape", None)
    # Do we have a shape and does it answer len()? (None has no __len__.)
    if not hasattr(dims, "__len__"):
        return None
    kind = dims.__class__
    if kind.__module__ == "torch" and kind.__name__ == "Size":
        return list(dims) if len(dims) > 0 else None
    return dims