# tensor-sensor
# 652 lines · 28.8 KB
"""
MIT License

Copyright (c) 2021 Terence Parr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import html
import os
import sys
import tempfile
import token
from pathlib import Path

import graphviz
import graphviz.backend
import matplotlib.colors as mc
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from IPython import get_ipython
from IPython.display import display, SVG

import tsensor
import tsensor.analysis
import tsensor.ast
import tsensor.parsing
42
43
class DTypeColorInfo:
    """
    Track the colors for various types, the transparency range, and bit precisions.
    By default, green indicates floating-point, blue indicates integer, and orange
    indicates complex numbers. The more saturated the color (lower transparency),
    the higher the precision.
    """
    orangeish = '#FDD66C'
    limeish = '#A8E1B0'
    blueish = '#7FA4D3'
    grey = '#EFEFF0'
    default_dtype_colors = {'float': limeish, 'int': blueish, 'complex': orangeish, 'other': grey}
    default_dtype_precisions = [32, 64, 128]  # hard to see diff if we use [4, 8, 16, 32, 64, 128]
    default_dtype_alpha_range = (0.5, 1.0)  # use (0.1, 1.0) if more precision values

    def __init__(self, dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        :param dtype_colors: dict mapping a dtype name without precision (such as
                             'int') to a color name or '#RRGGBB' hex string.
        :param dtype_precisions: sequence of bit precisions to colorize.
        :param dtype_alpha_range: (low, high) alpha values spread across the precisions.
        :raises TypeError: if dtype_colors is not a dict of strings.
        """
        if dtype_colors is None:
            dtype_colors = DTypeColorInfo.default_dtype_colors
        if dtype_precisions is None:
            dtype_precisions = DTypeColorInfo.default_dtype_precisions
        if dtype_alpha_range is None:
            dtype_alpha_range = DTypeColorInfo.default_dtype_alpha_range

        # Validate every color, not just the first one, so a bad entry fails here
        # instead of much later while drawing.
        if not isinstance(dtype_colors, dict) or \
                not all(isinstance(c, str) for c in dtype_colors.values()):
            raise TypeError(
                "dtype_colors should be a dict mapping type name to color name or color hex RGB values."
            )

        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def color(self, dtype):
        """
        Get color based on type and precision.

        Return a list [r, g, b, alpha] when both the type name and its bit
        precision are known; otherwise return the fallback color registered
        under 'other' (the class-level grey if the map has no such key).
        """
        # Split e.g. 'float64' into 'float' and '64'.
        dtype_name = dtype.rstrip('0123456789')
        precision_text = dtype[len(dtype_name):]
        fallback = self.dtype_colors.get('other', DTypeColorInfo.grey)

        # A dtype without a trailing precision cannot be shaded by precision.
        if dtype_name not in self.dtype_colors or not precision_text:
            return fallback
        precision = int(precision_text)
        if precision not in self.dtype_precisions:
            return fallback

        # to_rgb() understands both '#RRGGBB' strings and named colors and
        # always yields an (r, g, b) tuple of floats.
        rgb = mc.to_rgb(self.dtype_colors[dtype_name])
        alphas = np.linspace(*self.dtype_alpha_range, len(self.dtype_precisions))
        alpha = alphas[self.dtype_precisions.index(precision)]
        return list(rgb) + [alpha]
92
93
class PyVizView:
    """
    An object that collects relevant information about viewing Python code
    with visual annotations.
    """
    def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
                 char_sep_scale, dpi,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        self.statement = statement
        self.fontsize = fontsize
        self.fontname = fontname
        self.dimfontsize = dimfontsize
        self.dimfontname = dimfontname
        self.char_sep_scale = char_sep_scale
        self.dpi = dpi
        self.dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)
        self._dtype_encountered = set()  # which types, like 'int64', did we find in one plot?
        self.wchar = self.char_sep_scale * self.fontsize
        self.wchar_small = self.char_sep_scale * (self.fontsize - 2)  # for <int32> typenames
        self.hchar = self.char_sep_scale * self.fontsize
        self.dim_ypadding = 5
        self.dim_xpadding = 0
        self.linewidth = .7
        self.leftedge = 25
        self.bottomedge = 3
        self.filename = None
        self.matrix_size_scaler = 3.5  # How wide or tall as scaled fontsize is matrix?
        self.vector_size_scaler = 3.2 / 4  # How wide or tall as scaled fontsize is vector for skinny part?
        self.shift3D = 6
        self.cause = None  # Did an exception occur during evaluation?
        self.offending_expr = None
        self.fignumber = None

    @staticmethod
    def _split_dtype_precision(s):
        """Split the final integer part from a string: 'int32' -> ('int', '32')."""
        head = s.rstrip('0123456789')
        tail = s[len(head):]
        return head, tail

    @staticmethod
    def _new_temp_svg_path():
        """
        Create an empty temporary '.svg' file and return its path.

        mkstemp() creates the file atomically, unlike the race-prone mktemp().
        The descriptor is closed immediately because the caller reopens the
        file by name.
        """
        fd, path = tempfile.mkstemp(suffix='.svg')
        os.close(fd)
        return path

    def set_locations(self, maxh):
        """
        This function finishes setting up necessary parameters about text
        and graphics locations for the plot. We don't know how to set these
        values until we know what the max height of the drawing will be. We don't know
        what that height is until after we've parsed and so on, which requires that
        we collect and store information in this view object before computing maxh.
        That is why this is a separate function not part of the constructor.
        """
        line2text = self.hchar / 1.7
        box2line = line2text * 2.6
        self.texty = self.bottomedge + maxh + box2line + line2text
        self.liney = self.bottomedge + maxh + box2line
        self.box_topy = self.bottomedge + maxh
        self.maxy = self.texty + 1.4 * self.fontsize

    def _repr_svg_(self):
        "Show an SVG rendition in a notebook"
        return self.svg()

    def svg(self):
        """
        Render as svg and return svg text. The image is saved to a temporary
        file on first use and the name is remembered in self.filename. Returns
        None if the view was previously saved to a non-SVG file.
        """
        if self.filename is None:  # have we saved before? (i.e., is it cached?)
            self.savefig(self._new_temp_svg_path())
        elif not self.filename.endswith(".svg"):
            return None
        with open(self.filename, encoding='UTF-8') as f:
            svg = f.read()
        return svg

    def savefig(self, filename):
        "Save viz in format according to file extension."
        if plt.fignum_exists(self.fignumber):
            # The matplotlib figure is still active, so save it. Address this
            # view's own figure explicitly; pyplot's "current" figure may be a
            # different one if the caller has created others in the meantime.
            self.filename = filename  # Remember the file so we can pull it back
            plt.figure(self.fignumber).savefig(filename, dpi=self.dpi,
                                               bbox_inches='tight', pad_inches=0)
        else:  # we have already closed it so try to copy to new filename from the previous
            if filename != self.filename:
                base, ext = os.path.splitext(filename)
                _, prev_ext = os.path.splitext(self.filename)
                if ext != prev_ext:
                    print(f"File extension {ext} differs from previous {prev_ext}; uses previous.")
                    ext = prev_ext
                # make sure that we don't copy raw bits and change the file
                # extension to be inconsistent
                filename = base + ext
                with open(self.filename, 'rb') as src:
                    img = src.read()
                with open(filename, 'wb') as dst:
                    dst.write(img)
                self.filename = filename  # overwrite the filename with new name

    def show(self):
        "Display an SVG in a notebook or pop up a window if not in notebook"
        if get_ipython() is None:
            svgfilename = self._new_temp_svg_path()
            self.savefig(svgfilename)
            self.filename = svgfilename
            plt.show()
        else:
            svg = self.svg()
            display(SVG(svg))
        # Close this view's figure (falls back to the current figure when no
        # figure number was ever recorded).
        plt.close(self.fignumber)

    def boxsize(self, v):
        """
        How wide and tall should we draw the box representing a vector or matrix.
        Returns None when v has no shape.
        """
        sh = tsensor.analysis._shape(v)
        ty = tsensor.analysis._dtype(v)
        if sh is None:
            return None
        if len(sh) == 1:
            return self.vector_size(sh, ty)
        return self.matrix_size(sh, ty)

    def matrix_size(self, sh, ty):
        """
        How wide and tall should we draw the box representing a matrix.
        """
        if len(sh) == 1 and sh[0] == 1:
            return self.vector_size(sh, ty)

        if len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
            # A special case where we have a 1x1 matrix extending into the screen.
            # Make the 1x1 part a little bit wider than a vector so it's more readable
            w, h = 2 * self.vector_size_scaler * self.wchar, 2 * self.vector_size_scaler * self.wchar
        elif len(sh) > 1 and sh[1] == 1:
            w, h = self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        elif len(sh) > 1 and sh[0] == 1:
            w, h = self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar
        else:
            w, h = self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        return w, h

    def vector_size(self, sh, ty):
        """
        How wide and tall is a vector?  It's not a function of vector length; instead
        we make a row vector with same width as a matrix but height of just one char.
        For consistency with matrix_size(), I pass in shape, though it's ignored.
        """
        return self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar

    def draw(self, ax, sub):
        """Draw the box for subexpression `sub` on `ax` and record its dtype."""
        sh = tsensor.analysis._shape(sub.value)
        ty = tsensor.analysis._dtype(sub.value)
        self._dtype_encountered.add(ty)
        if len(sh) == 1:
            self.draw_vector(ax, sub, sh, ty)
        else:
            self.draw_matrix(ax, sub, sh, ty)

    def draw_vector(self, ax, sub, sh, ty: str):
        """Draw a one-dimensional tensor as a flat rectangle centered under `sub`."""
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.vector_size(sh, ty)
        color = self.dtype_color_info.color(ty)
        rect1 = patches.Rectangle(xy=(mid - w/2, self.box_topy - h),
                                  width=w,
                                  height=h,
                                  linewidth=self.linewidth,
                                  facecolor=color,
                                  edgecolor='grey',
                                  fill=True)
        ax.add_patch(rect1)

        # Text above vector rectangle
        ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]),
                horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)
        # Type info at the bottom of everything (raw string: '\m' is not a valid escape)
        ax.text(mid, self.box_topy - self.hchar, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize - 2)

    def draw_matrix(self, ax, sub, sh, ty):
        """
        Draw a tensor with two or more dimensions as a box centered under `sub`.
        A third dimension adds an offset shadow box; dimensions beyond the third
        are listed as text below the box.
        """
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.matrix_size(sh, ty)
        box_left = mid - w / 2
        color = self.dtype_color_info.color(ty)

        if len(sh) > 2:
            # Shifted copy drawn first so it sits behind the front box
            back_rect = patches.Rectangle(xy=(box_left + self.shift3D, self.box_topy - h + self.shift3D),
                                          width=w,
                                          height=h,
                                          linewidth=self.linewidth,
                                          facecolor=color,
                                          edgecolor='grey',
                                          fill=True)
            ax.add_patch(back_rect)
        rect = patches.Rectangle(xy=(box_left, self.box_topy - h),
                                 width=w,
                                 height=h,
                                 linewidth=self.linewidth,
                                 facecolor=color,
                                 edgecolor='grey',
                                 fill=True)
        ax.add_patch(rect)

        # First dimension: rotated text to the left of the matrix rectangle
        ax.text(box_left, self.box_topy - h/2, self.nabbrev(sh[0]),
                verticalalignment='center', horizontalalignment='right',
                fontname=self.dimfontname, fontsize=self.dimfontsize, rotation=90)

        # Second dimension: text above the rectangle, nudged up and right when
        # the 3D shadow box is present
        textx = mid
        texty = self.box_topy + self.dim_ypadding
        if len(sh) > 2:
            texty += self.dim_ypadding
            textx += self.shift3D
        ax.text(textx, texty, self.nabbrev(sh[1]), horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)

        if len(sh) > 2:
            # Third dimension: slanted text at the right edge
            ax.text(box_left + w, self.box_topy - h/2, self.nabbrev(sh[2]),
                    verticalalignment='center', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize,
                    rotation=45)

        bottom_text_line = self.box_topy - h - self.dim_ypadding
        if len(sh) > 3:
            # Remaining dimensions: text below
            remaining = r"$\cdots\mathsf{x}$" + r"$\mathsf{x}$".join(
                [self.nabbrev(sh[i]) for i in range(3, len(sh))])
            ax.text(mid, bottom_text_line, remaining,
                    verticalalignment='top', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize)
            bottom_text_line -= self.hchar + self.dim_ypadding

        # Type info at the bottom of everything (raw string: '\m' is not a valid escape)
        ax.text(mid, bottom_text_line, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize - 2)

    @staticmethod
    def nabbrev(n: int) -> str:
        """
        Abbreviate exact multiples of a million as '<n>m' and of a thousand as
        '<n>k'; everything else, including 0, is returned as plain digits.
        """
        if n != 0:  # 0 is a multiple of everything but must not become '0m'
            if n % 1_000_000 == 0:
                return str(n // 1_000_000) + 'm'
            if n % 1_000 == 0:
                return str(n // 1_000) + 'k'
        return str(n)
335
336
def pyviz(statement: str, frame=None,
          fontname='Consolas', fontsize=13,
          dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
          underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
          ax=None, dpi=200, hush_errors=True,
          dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> PyVizView:
    """
    Parse and evaluate the Python code in the statement string passed in using
    the indicated execution frame. The execution frame of the invoking function
    is used if frame is None.

    The visualization finds the smallest subexpressions that evaluate to
    tensors then underlines them and shows a box or rectangle representing
    the tensor dimensions. Tensors with shape (n,) are drawn as flat
    rectangles and tensors with two or more dimensions as boxes; the fill
    color is chosen from the tensor's dtype (see dtype_colors).

    Upon tensor-related execution error, the offending subexpression is
    highlighted (by de-highlighting the other code) and the operator is shown
    using error_op_color.

    To adjust the size of the generated visualization to be smaller or bigger,
    decrease or increase the font size.

    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function
    :param fontname: The name of the font used to display Python code
    :param fontsize: The font size used to display Python code; default is 13.
                     Also use this to increase the size of the generated figure;
                     larger font size means larger image.
    :param dimfontname:  The name of the font used to display the dimensions on the matrix and vector boxes
    :param dimfontsize: The  size of the font used to display the dimensions on the matrix and vector boxes
    :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                           text is when plotted in matplotlib. In fact there's probably,
                           no hope to discover this information accurately in all cases.
                           Certainly, I gave up after spending huge effort. We have a
                           situation here where the font should be constant width, so
                           we can just use a simple scalar times the font size to get
                           a reasonable approximation of the width and height of a
                           character box; the default of 1.8 seems to work reasonably
                           well for a wide range of fonts, but you might have to tweak it
                           when you change the font size.
    :param fontcolor: The color of the Python code.
    :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
    :param ignored_color: The de-highlighted color for de-emphasizing code not involved in an erroneous sub expression
    :param error_op_color: The color to use for characters associated with the erroneous operator
    :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization
    :param dpi: This library tries to generate SVG files, which are vector graphics not
                2D arrays of pixels like PNG files. However, it needs to know how to
                compute the exact figure size to remove padding around the visualization.
                Matplotlib uses inches for its figure size and so we must convert
                from pixels or data units to inches, which means we have to know what the
                dots per inch, dpi, is for the image.
    :param hush_errors: Normally, error messages from true syntax errors but also
                        unhandled code caught by my parser are ignored. Turn this off
                        to see what the error messages are coming from my parser.
    :param dtype_colors: map from dtype w/o precision like 'int' to color
    :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
    :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                              and the alpha channel is used to show precision; the
                              smaller the bit size, the lower the alpha channel. You
                              can play with the range to get better visual dynamic range
                              depending on how many precisions you want to display.
    :return: Returns a PyVizView holding info about the visualization; from a notebook
             an SVG image will appear. Return none upon parsing error in statement.
    """
    view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, char_sep_scale, dpi,
                     dtype_colors, dtype_precisions, dtype_alpha_range)

    if frame is None:  # use frame of caller if not passed in
        frame = sys._getframe().f_back
    root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors)
    if root is None:
        print(f"Can't parse {statement}; root is None")
        # likely syntax error in statement or code I can't handle
        return None
    root_to_viz = root
    try:
        root.eval(frame)
    except tsensor.ast.IncrEvalTrap as e:
        root_to_viz = e.offending_expr
        view.offending_expr = e.offending_expr
        view.cause = e.__cause__
        # Don't raise the exception; keep going to visualize code and erroneous
        # subexpressions. If this function is invoked from clarify() or explain(),
        # the statement will be executed and will fail again during normal execution;
        # an exception will be thrown at that time. Then explain/clarify
        # will update the error message
    subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz)
    if ax is None:
        fig, ax = plt.subplots(1, 1, dpi=dpi)
    else:
        fig = ax.figure
    view.fignumber = fig.number  # track this so that we can determine if the figure has been closed

    ax.axis("off")

    # First, we need to figure out how wide the visualization components are
    # for each sub expression. If these are wider than the sub expression text,
    # then we need to leave space around the sub expression text
    nchars = len(statement)
    lpad = np.zeros((nchars,))  # pad for characters
    rpad = np.zeros((nchars,))
    maxh = 0
    for sub in subexprs:
        w, h = view.boxsize(sub.value)
        # update width to include horizontal room for type text like int32
        ty = tsensor.analysis._dtype(sub.value)
        w_typename = len(ty) * view.wchar_small
        w = max(w, w_typename)
        maxh = max(h, maxh)
        first, stop = sub.start.cstart_idx, sub.stop.cstop_idx
        nexpr = stop - first
        # A neighbouring space gives the box one extra character of room. The
        # left neighbour exists whenever first >= 1 (index 0 is a valid neighbour).
        if first > 0 and statement[first - 1] == ' ':  # if char to left is space
            nexpr += 1
        if stop < nchars and statement[stop] == ' ':  # if char to right is space
            nexpr += 1
        if w > view.wchar * nexpr:
            lpad[first] += (w - view.wchar) / 2
            rpad[stop - 1] += (w - view.wchar) / 2

    # Now we know how to place all the elements, since we know what the maximum height is
    view.set_locations(maxh)

    # Find each character's position based upon width of a character and any padding
    charx = np.empty((nchars,))
    x = view.leftedge
    for i in range(nchars):
        x += lpad[i]
        charx[i] = x
        x += view.wchar
        x += rpad[i]

    # Draw text for statement or expression
    if view.offending_expr is not None:  # highlight erroneous subexpr
        highlight = np.full(shape=(nchars,), fill_value=False, dtype=bool)
        for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index+1]:
            highlight[tok.cstart_idx:tok.cstop_idx] = True
        errors = np.full(shape=(nchars,), fill_value=False, dtype=bool)
        for tok in root_to_viz.optokens:
            errors[tok.cstart_idx:tok.cstop_idx] = True
        for i, c in enumerate(statement):
            color = ignored_color
            if highlight[i]:
                color = fontcolor
            if errors[i]:  # override color if operator token
                color = error_op_color
            ax.text(charx[i], view.texty, c, color=color, fontname=fontname, fontsize=fontsize)
    else:
        for i, c in enumerate(statement):
            ax.text(charx[i], view.texty, c, color=fontcolor, fontname=fontname, fontsize=fontsize)

    # Compute the left and right edges of subexpressions (alter nodes with info)
    for sub in subexprs:
        sub.leftx = charx[sub.start.cstart_idx]
        sub.rightx = charx[sub.stop.cstop_idx - 1] + view.wchar

    # Draw grey underlines and draw matrices
    pad = view.wchar * 0.1
    for sub in subexprs:
        a, b = sub.leftx, sub.rightx
        ax.plot([a-pad, b+pad], [view.liney, view.liney], '-', linewidth=.5, c=underline_color)
        view.draw(ax, sub)

    # Size the figure to exactly enclose the drawing (data units -> inches)
    fig_width = charx[-1] + view.wchar + rpad[-1]
    fig_width_inches = fig_width / dpi
    fig_height_inches = view.maxy / dpi
    fig.set_size_inches(fig_width_inches, fig_height_inches)

    ax.set_xlim(0, fig_width)
    ax.set_ylim(0, view.maxy)

    return view
510
511
512# ---------------- SHOW AST STUFF ---------------------------
513
class QuietGraphvizWrapper(graphviz.Source):
    """
    A graphviz.Source that renders itself quietly as SVG in a notebook and can
    save itself to an image file via savefig().
    """
    def __init__(self, dotsrc):
        super().__init__(source=dotsrc)

    def _repr_svg_(self):
        return self.pipe(format='svg', quiet=True).decode(self._encoding)

    @staticmethod
    def _version_tuple(version):
        """
        Turn a version string such as '0.17' or '0.20.1' into (major, minor)
        integers so versions compare numerically; plain string comparison
        would rank '0.9' above '0.17'. Non-numeric suffixes are ignored.
        """
        numbers = []
        for piece in version.split('.')[:2]:
            ndigits = len(piece) - len(piece.lstrip('0123456789'))
            numbers.append(int(piece[:ndigits]) if ndigits else 0)
        while len(numbers) < 2:
            numbers.append(0)
        return tuple(numbers)

    def savefig(self, filename):
        """
        Write the DOT source next to `filename` and run `dot` to render it in
        the format named by the file extension (".svg" -> "svg", etc.).
        """
        path = Path(filename)
        # Create any missing intermediate directories too
        path.parent.mkdir(parents=True, exist_ok=True)

        dotfilename = self.save(directory=path.parent.as_posix(), filename=path.stem)
        fmt = path.suffix[1:]  # ".svg" -> "svg" etc...
        cmd = ["dot", f"-T{fmt}", "-o", os.fspath(filename), dotfilename]
        # print(' '.join(cmd))
        # The run helper lives in a different place after graphviz 0.17
        if self._version_tuple(graphviz.__version__) <= (0, 17):
            graphviz.backend.run(cmd, capture_output=True, check=True, quiet=False)
        else:
            graphviz.backend.execute.run_check(cmd, capture_output=True, check=True, quiet=False)
533
534
def astviz(statement:str, frame='current',
           dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> graphviz.Source:
    """
    Build a displayable abstract syntax tree (AST) for the Python code in
    `statement`.

    The code is evaluated in the context of `frame`. When `frame` is left at
    its default, the code is evaluated in the context of the invoking code.
    Pass frame=None to skip evaluation and show only the AST.

    The result is a QuietGraphvizWrapper: it renders as SVG in a notebook, and
    its `savefig()` method writes the tree to a file whose format follows the
    file extension.
    """
    dot_source = astviz_dot(statement, frame,
                            dtype_colors, dtype_precisions, dtype_alpha_range)
    return QuietGraphvizWrapper(dot_source)
552
553
def astviz_dot(statement:str, frame='current',
               dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> str:
    """
    Return Graphviz DOT source describing the abstract syntax tree of the
    Python code in `statement`.

    :param statement: The line of Python code to parse.
    :param frame: Execution frame in which to evaluate the code. The default
                  'current' means the frame of the caller (skipping over
                  astviz() when called from there); None disables evaluation.
    :param dtype_colors: map from dtype w/o precision like 'int' to color
    :param dtype_precisions: list of bit precisions to colorize
    :param dtype_alpha_range: alpha range used to show precision
    :return: DOT text with one leaf per token and one box per operator node;
             operator nodes whose value has a shape are filled with the
             dtype color and labelled with shape and dtype.
    """
    fontname = "Consolas"
    fontsize = 12
    dimfontsize = 9
    spread = 0

    def internal_label(node):
        """HTML-like label for an operator node: operator text, then shape and dtype if any."""
        sh = tsensor.analysis._shape(node.value)
        ty = tsensor.analysis._dtype(node.value)
        # Operators such as '<', '>>' or '&' must be escaped inside an HTML-like label
        text = html.escape(''.join(str(t) for t in node.optokens), quote=False)
        if sh is None:
            return f'<font face="{fontname}" point-size="{fontsize}">{text}</font>'

        sz = 'x'.join(PyVizView.nabbrev(dim) for dim in sh)
        # The dtype is shown in angle brackets, which have to be written as
        # entities or DOT would read them as an (unknown) HTML tag.
        return (f'<font face="Consolas" color="#444443" point-size="{fontsize}">{text}</font><br/>'
                f'<font face="Arial" color="#444443" point-size="{dimfontsize}">{sz}</font><br/>'
                f'<font face="Arial" color="#444443" point-size="{dimfontsize}">&lt;{ty}&gt;</font>')

    dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)

    root, tokens = tsensor.parsing.parse(statement)

    if frame == 'current':  # use frame of caller if nothing passed in
        frame = sys._getframe().f_back
        if frame.f_code.co_name == 'astviz':
            frame = frame.f_back

    if frame is not None:  # if the passed in None, then don't do the evaluation
        root.eval(frame)

    nodes = tsensor.ast.postorder(root)
    atoms = tsensor.ast.leaves(root)
    atomsS = set(atoms)
    ops = [nd for nd in nodes if nd not in atomsS]  # keep order

    # Collect the pieces and join once at the end
    parts = ["""digraph G {
    margin=0;
    nodesep=.01;
    ranksep=.3;
    rankdir=BT;
    ordering=out; # keep order of leaves
"""]

    # Gen leaf nodes
    for t in tokens:
        if t.type != token.ENDMARKER:
            nodetext = html.escape(str(t.value), quote=False)
            if nodetext == ']':
                # &zwnj; is 0-width nonjoiner. ']' by itself is bad for DOT
                nodetext = '&zwnj;]'
            label = f'<font face="{fontname}" color="#444443" point-size="{fontsize}">{nodetext}</font>'
            _spread = spread
            if t.type == token.DOT:
                _spread = .1
            elif t.type == token.EQUAL:
                _spread = .25
            elif t.type in tsensor.parsing.ADDOP:
                _spread = .4
            elif t.type in tsensor.parsing.MULOP:
                _spread = .2
            parts.append(f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n')

    # Make sure leaves are on same level
    parts.append('{ rank=same; ')
    parts.extend(f' leaf{id(t)}' for t in tokens if t.type != token.ENDMARKER)
    parts.append('\n}\n')

    # Make sure leaves are left to right by linking; the final token is
    # never the target of a link
    for i in range(len(tokens) - 2):
        t = tokens[i]
        t2 = tokens[i + 1]
        parts.append(f'leaf{id(t)} -> leaf{id(t2)} [style=invis];\n')

    # Draw internal ops nodes
    for nd in ops:
        label = internal_label(nd)
        sh = tsensor.analysis._shape(nd.value)
        if sh is None:
            color = ""
        else:
            ty = tsensor.analysis._dtype(nd.value)
            color = dtype_color_info.color(ty)
            color = mc.rgb2hex(color, keep_alpha=True)
            color = f'fillcolor="{color}" style=filled'
        parts.append(f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n')

    # Link internal nodes to other nodes or leaves
    for nd in nodes:
        for sub in nd.kids:
            if sub in atomsS:
                parts.append(f'node{id(nd)} -> leaf{id(sub.token)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n')
            else:
                parts.append(f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n')

    parts.append("}\n")
    return ''.join(parts)
653