tensor-sensor

Форк
0
652 строки · 28.8 Кб
1
"""
2
MIT License
3

4
Copyright (c) 2021 Terence Parr
5

6
Permission is hereby granted, free of charge, to any person obtaining a copy
7
of this software and associated documentation files (the "Software"), to deal
8
in the Software without restriction, including without limitation the rights
9
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
copies of the Software, and to permit persons to whom the Software is
11
furnished to do so, subject to the following conditions:
12

13
The above copyright notice and this permission notice shall be included in all
14
copies or substantial portions of the Software.
15

16
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
SOFTWARE.
23
"""
24
import html
import os
import sys
import tempfile
import token
from pathlib import Path

import graphviz
import graphviz.backend
import matplotlib.colors as mc
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from IPython import get_ipython
from IPython.display import display, SVG

import tsensor
import tsensor.ast
import tsensor.analysis
import tsensor.parsing
42

43

44
class DTypeColorInfo:
    """
    Track the colors for various types, the transparency range, and bit precisions.
    By default, green indicates floating-point, blue indicates integer, and orange
    indicates complex numbers. The more saturated the color (lower transparency),
    the higher the precision.

    Colors may be hex strings ('#RRGGBB') or any other color specification
    matplotlib understands, such as named colors ('red').
    """
    orangeish = '#FDD66C'
    limeish = '#A8E1B0'
    blueish = '#7FA4D3'
    grey = '#EFEFF0'
    default_dtype_colors = {'float': limeish, 'int': blueish, 'complex': orangeish, 'other': grey}
    default_dtype_precisions = [32, 64, 128]  # hard to see diff if we use [4, 8, 16, 32, 64, 128]
    default_dtype_alpha_range = (0.5, 1.0)    # use (0.1, 1.0) if more precision values

    def __init__(self, dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        :param dtype_colors: dict mapping a type name without precision (e.g. 'int')
                             to a color; the 'other' entry colors every type that
                             is not covered. Defaults to default_dtype_colors.
        :param dtype_precisions: list of bit precisions that get distinct shades.
                                 Defaults to default_dtype_precisions.
        :param dtype_alpha_range: (low, high) alpha values spread evenly across
                                  dtype_precisions. Defaults to default_dtype_alpha_range.
        :raises TypeError: if dtype_colors is not a dict, or its first value is not a string.
        """
        if dtype_colors is None:
            dtype_colors = DTypeColorInfo.default_dtype_colors
        if dtype_precisions is None:
            dtype_precisions = DTypeColorInfo.default_dtype_precisions
        if dtype_alpha_range is None:
            dtype_alpha_range = DTypeColorInfo.default_dtype_alpha_range

        if not isinstance(dtype_colors, dict) or (len(dtype_colors) > 0 and
                                                  not isinstance(list(dtype_colors.values())[0], str)):
            raise TypeError(
                "dtype_colors should be a dict mapping type name to color name or color hex RGB values."
            )

        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def color(self, dtype):
        """
        Get the fill color for a dtype string such as 'float32' or 'int64'.

        Returns [r, g, b, alpha], where alpha encodes the bit precision.
        The raw 'other' color specification (e.g. '#EFEFF0') is returned instead
        in these cases:
        - the type name is unknown;
        - the type has no trailing precision digits;
        - the precision is not listed in dtype_precisions.
        matplotlib accepts that string directly as a color. If dtype_colors has
        no 'other' entry, DTypeColorInfo.grey is used.
        """
        other = self.dtype_colors.get('other', DTypeColorInfo.grey)

        # Split the trailing bit width off the name: 'float32' -> ('float', '32')
        dtype_name = dtype.rstrip('0123456789')
        precision_digits = dtype[len(dtype_name):]
        if dtype_name not in self.dtype_colors or not precision_digits:
            # No digits at all (e.g. 'bool') used to crash on int('')
            return other
        dtype_precision = int(precision_digits)
        if dtype_precision not in self.dtype_precisions:
            return other

        # mc.to_rgb handles hex strings and named colors alike; mc.cnames maps a
        # name to a hex *string*, which must not be treated as an RGB sequence.
        rgb = mc.to_rgb(self.dtype_colors[dtype_name])
        precision_idx = self.dtype_precisions.index(dtype_precision)
        alphas = np.linspace(*self.dtype_alpha_range, len(self.dtype_precisions))
        return list(rgb) + [alphas[precision_idx]]
92

93

94
class PyVizView:
    """
    An object that collects relevant information about viewing Python code
    with visual annotations.

    It holds:
    - font and size settings;
    - layout geometry;
    - the dtype color scheme;
    - any evaluation error;
    - the file the rendering was last saved to.
    """
    def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
                 char_sep_scale, dpi,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        self.statement = statement
        self.fontsize = fontsize
        self.fontname = fontname
        self.dimfontsize = dimfontsize
        self.dimfontname = dimfontname
        self.char_sep_scale = char_sep_scale
        self.dpi = dpi
        self.dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)
        self._dtype_encountered = set()  # which types, like 'int64', did we find in one plot?
        self.wchar = self.char_sep_scale * self.fontsize
        self.wchar_small = self.char_sep_scale * (self.fontsize - 2)  # for <int32> typenames
        self.hchar = self.char_sep_scale * self.fontsize
        self.dim_ypadding = 5
        self.dim_xpadding = 0
        self.linewidth = .7
        self.leftedge = 25
        self.bottomedge = 3
        self.filename = None               # file the rendering was last saved to
        self.matrix_size_scaler = 3.5      # How wide or tall as scaled fontsize is matrix?
        self.vector_size_scaler = 3.2 / 4  # How wide or tall as scaled fontsize is vector for skinny part?
        self.shift3D = 6
        self.cause = None  # Did an exception occur during evaluation?
        self.offending_expr = None
        self.fignumber = None  # matplotlib figure number; assigned by pyviz()

    @staticmethod
    def _split_dtype_precision(s):
        """Split the final integer part from a string: 'float32' -> ('float', '32')."""
        head = s.rstrip('0123456789')
        tail = s[len(head):]
        return head, tail

    def set_locations(self, maxh):
        """
        Finish setting up the text and graphics locations for the plot.

        These values depend on the maximum height (maxh) of the drawing. That
        height is only known after parsing, which requires first collecting
        information in this view object. That is why this is a separate
        function rather than part of the constructor.
        """
        line2text = self.hchar / 1.7
        box2line = line2text * 2.6
        self.texty = self.bottomedge + maxh + box2line + line2text
        self.liney = self.bottomedge + maxh + box2line
        self.box_topy = self.bottomedge + maxh
        self.maxy = self.texty + 1.4 * self.fontsize

    def _repr_svg_(self):
        "Show an SVG rendition in a notebook"
        return self.svg()

    def svg(self):
        """
        Render as SVG and return the SVG text.

        On first use the rendering is saved to a new temporary .svg file, and
        its name is stored in the `filename` field so later calls reuse it.
        Returns None if the view was previously saved in a non-SVG format.
        """
        if self.filename is None:  # have we saved before? (i.e., is it cached?)
            self.savefig(self._new_tempfile('.svg'))
        elif not self.filename.endswith(".svg"):
            return None
        with open(self.filename, encoding='UTF-8') as f:
            return f.read()

    @staticmethod
    def _new_tempfile(suffix):
        """Safely create an empty temporary file and return its name (mktemp is racy)."""
        fd, name = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return name

    def savefig(self, filename):
        """
        Save the visualization in the format implied by filename's extension.

        While the matplotlib figure is still open, it is rendered directly.
        Once it has been closed, the previously saved file is copied instead.
        Raw bytes cannot be converted between formats, so the copy keeps the
        previous file's extension; a message is printed if that differs.

        :raises ValueError: if the figure is closed and was never saved.
        """
        if plt.fignum_exists(self.fignumber):
            # If the matplotlib figure is still active, save it
            self.filename = filename  # Remember the file so we can pull it back
            plt.savefig(filename, dpi=self.dpi, bbox_inches='tight', pad_inches=0)
            return

        # We have already closed it, so copy the previously saved file
        if self.filename is None:
            raise ValueError("The figure was closed before it was ever saved; nothing to copy.")
        if filename == self.filename:
            return
        stem, ext = os.path.splitext(filename)
        _, prev_ext = os.path.splitext(self.filename)
        if ext != prev_ext:
            print(f"File extension {ext} differs from previous {prev_ext}; uses previous.")
            ext = prev_ext
        filename = stem + ext  # don't copy raw bits under an inconsistent file extension
        with open(self.filename, 'rb') as src:
            img = src.read()
        with open(filename, 'wb') as dst:
            dst.write(img)
        self.filename = filename  # overwrite the filename with new name

    def show(self):
        "Display an SVG in a notebook or pop up a window if not in notebook"
        if get_ipython() is None:
            # savefig() records the file actually written in self.filename;
            # don't overwrite it, since its extension may have been adjusted.
            self.savefig(self._new_tempfile('.svg'))
            plt.show()
        else:
            display(SVG(self.svg()))
        plt.close()

    def boxsize(self, v):
        """
        How wide and tall should we draw the box representing a vector or matrix?
        Returns None when v has no shape.
        """
        sh = tsensor.analysis._shape(v)
        ty = tsensor.analysis._dtype(v)
        if sh is None:
            return None
        if len(sh) == 1:
            return self.vector_size(sh, ty)
        return self.matrix_size(sh, ty)

    def matrix_size(self, sh, ty):
        """
        How wide and tall should we draw the box representing a matrix?
        Dimensions of size 1 are drawn skinny, like a vector.
        """
        if len(sh) == 1 and sh[0] == 1:
            return self.vector_size(sh, ty)

        if len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
            # A special case where we have a 1x1 matrix extending into the screen.
            # Make the 1x1 part a little bit wider than a vector so it's more readable
            w, h = 2 * self.vector_size_scaler * self.wchar, 2 * self.vector_size_scaler * self.wchar
        elif len(sh) > 1 and sh[1] == 1:
            w, h = self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        elif len(sh) > 1 and sh[0] == 1:
            w, h = self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar
        else:
            w, h = self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        return w, h

    def vector_size(self, sh, ty):
        """
        How wide and tall is a vector? It's not a function of vector length;
        instead we make a row vector with the same width as a matrix but the
        height of just one char. For consistency with matrix_size(), shape is
        passed in even though it's ignored.
        """
        return self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar

    def draw(self, ax, sub):
        """Draw the tensor value of subexpression sub as a vector or matrix box."""
        sh = tsensor.analysis._shape(sub.value)
        ty = tsensor.analysis._dtype(sub.value)
        self._dtype_encountered.add(ty)
        if len(sh) == 1:
            self.draw_vector(ax, sub, sh, ty)
        else:
            self.draw_matrix(ax, sub, sh, ty)

    def draw_vector(self, ax, sub, sh, ty: str):
        """Draw a flat rectangle centered under sub, with its length above and dtype below."""
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.vector_size(sh, ty)
        color = self.dtype_color_info.color(ty)
        rect1 = patches.Rectangle(xy=(mid - w/2, self.box_topy - h),
                                  width=w,
                                  height=h,
                                  linewidth=self.linewidth,
                                  facecolor=color,
                                  edgecolor='grey',
                                  fill=True)
        ax.add_patch(rect1)

        # Text above vector rectangle
        ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]),
                horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)
        # Type info at the bottom of everything (raw string: '\m' is not a valid escape)
        ax.text(mid, self.box_topy - self.hchar, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize-2)

    def draw_matrix(self, ax, sub, sh, ty):
        """
        Draw a box centered under sub for a tensor with two or more dimensions.

        A third dimension adds a second box behind, shifted up and to the right.
        Dimension sizes are written around the box. Any dimensions beyond the
        third are listed below it, followed by the dtype.
        """
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.matrix_size(sh, ty)
        box_left = mid - w / 2
        color = self.dtype_color_info.color(ty)

        if len(sh) > 2:
            back_rect = patches.Rectangle(xy=(box_left + self.shift3D, self.box_topy - h + self.shift3D),
                                          width=w,
                                          height=h,
                                          linewidth=self.linewidth,
                                          facecolor=color,
                                          edgecolor='grey',
                                          fill=True)
            ax.add_patch(back_rect)
        rect = patches.Rectangle(xy=(box_left, self.box_topy - h),
                                 width=w,
                                 height=h,
                                 linewidth=self.linewidth,
                                 facecolor=color,
                                 edgecolor='grey',
                                 fill=True)
        ax.add_patch(rect)

        # First dimension: rotated text to the left of the box
        ax.text(box_left, self.box_topy - h/2, self.nabbrev(sh[0]),
                verticalalignment='center', horizontalalignment='right',
                fontname=self.dimfontname, fontsize=self.dimfontsize, rotation=90)

        # Second dimension: text above the box, nudged onto the back box in 3D
        textx = mid
        texty = self.box_topy + self.dim_ypadding
        if len(sh) > 2:
            texty += self.dim_ypadding
            textx += self.shift3D
        ax.text(textx, texty, self.nabbrev(sh[1]), horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)

        if len(sh) > 2:
            # Third dimension: text rotated 45 degrees at the right edge
            ax.text(box_left + w, self.box_topy - h/2, self.nabbrev(sh[2]),
                    verticalalignment='center', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize,
                    rotation=45)

        bottom_text_line = self.box_topy - h - self.dim_ypadding
        if len(sh) > 3:
            # Remaining dimensions listed below the box
            remaining = r"$\cdots\mathsf{x}$" + r"$\mathsf{x}$".join(
                [self.nabbrev(sh[i]) for i in range(3, len(sh))])
            ax.text(mid, bottom_text_line, remaining,
                    verticalalignment='top', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize)
            bottom_text_line -= self.hchar + self.dim_ypadding

        # Type info at the bottom of everything (raw string: '\m' is not a valid escape)
        ax.text(mid, bottom_text_line, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize-2)

    @staticmethod
    def nabbrev(n: int) -> str:
        """
        Abbreviate a dimension size.

        Exact multiples of a million get an 'm' suffix and exact multiples of a
        thousand get a 'k' suffix: 2000000 -> '2m', 3000 -> '3k', 1234 -> '1234'.
        Zero (or another falsy size) is returned unabbreviated, because 0 is a
        multiple of everything and would otherwise render as '0m'.
        """
        if not n:
            return str(n)
        if n % 1_000_000 == 0:
            return str(n // 1_000_000) + 'm'
        if n % 1_000 == 0:
            return str(n // 1_000) + 'k'
        return str(n)
335

336

337
def pyviz(statement: str, frame=None,
          fontname='Consolas', fontsize=13,
          dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
          underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
          ax=None, dpi=200, hush_errors=True,
          dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> PyVizView:
    """
    Parse and evaluate the Python code in the statement string passed in using
    the indicated execution frame. The execution frame of the invoking function
    is used if frame is None.

    The visualization finds the smallest subexpressions that evaluate to
    tensors, underlines them, and shows a box or rectangle representing the
    tensor dimensions:
    - tensors with two or more dimensions are drawn as boxes;
    - one-dimensional tensors, with shape (n,), are drawn as flat rectangles.
    The fill color reflects the element type (see dtype_colors) and its
    transparency reflects the bit precision.

    Upon a tensor-related execution error, the offending subexpression is
    highlighted (by de-highlighting the other code) and the operator is shown
    using error_op_color.

    To adjust the size of the generated visualization to be smaller or bigger,
    decrease or increase the font size.

    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function.
    :param fontname: The name of the font used to display Python code.
    :param fontsize: The font size used to display Python code; default is 13.
                     Also use this to increase the size of the generated figure;
                     a larger font size means a larger image.
    :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes.
    :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes.
    :param char_sep_scale: Scalar that, times the font size, approximates the width
                           and height of one character box.
                           - It is notoriously difficult to discover how wide and
                             tall text is when plotted in matplotlib, and there is
                             probably no way to do so accurately in all cases.
                           - Code is displayed in a constant-width font, so a
                             simple scalar times the font size is a reasonable
                             approximation.
                           - The default of 1.8 works reasonably well for a wide
                             range of fonts, but you might have to tweak it when
                             you change the font size.
    :param fontcolor: The color of the Python code.
    :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey.
    :param ignored_color: The de-highlighted color for de-emphasizing code not involved in an erroneous subexpression.
    :param error_op_color: The color to use for characters associated with the erroneous operator.
    :param ax: If not None, this is the matplotlib drawing region in which to draw the visualization.
    :param dpi: Dots per inch used to size the figure. This library tries to
                generate SVG files, which are vector graphics, not 2D arrays of
                pixels like PNG files. However, it needs to know how to compute
                the exact figure size to remove padding around the visualization.
                Matplotlib uses inches for its figure size, so we must convert
                from pixels or data units to inches, which requires the dots per
                inch (dpi) of the image.
    :param hush_errors: Normally, error messages from true syntax errors, and also
                        from unhandled code caught by the parser, are ignored.
                        Turn this off to see the parser's error messages.
    :param dtype_colors: Map from dtype without precision, like 'int', to color.
    :param dtype_precisions: List of bit precisions to colorize, such as [32,64,128].
    :param dtype_alpha_range: All tensors of the same type are drawn in the same
                              color, and the alpha channel shows precision: the
                              smaller the bit size, the lower the alpha. Adjust
                              the range to get better visual dynamic range,
                              depending on how many precisions you display.
    :return: A PyVizView holding info about the visualization; from a notebook
             an SVG image will appear. Returns None upon a parsing error in statement.
    """
    view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, char_sep_scale, dpi,
                     dtype_colors, dtype_precisions, dtype_alpha_range)

    if frame is None:  # use frame of caller if not passed in
        frame = sys._getframe().f_back
    root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors)
    if root is None:
        print(f"Can't parse {statement}; root is None")
        # likely syntax error in statement or code I can't handle
        return None
    root_to_viz = root
    try:
        root.eval(frame)
    except tsensor.ast.IncrEvalTrap as e:
        root_to_viz = e.offending_expr
        view.offending_expr = e.offending_expr
        view.cause = e.__cause__
        # Don't raise the exception; keep going to visualize code and erroneous
        # subexpressions. If this function is invoked from clarify() or explain(),
        # the statement will be executed and will fail again during normal execution;
        # an exception will be thrown at that time. Then explain/clarify
        # will update the error message.
    subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz)
    if ax is None:
        fig, ax = plt.subplots(1, 1, dpi=dpi)
    else:
        fig = ax.figure
    view.fignumber = fig.number  # track this so that we can determine if the figure has been closed

    ax.axis("off")

    # First, we need to figure out how wide the visualization components are
    # for each subexpression. If these are wider than the subexpression text,
    # then we need to leave space around the subexpression text.
    lpad = np.zeros((len(statement),))  # pad for characters
    rpad = np.zeros((len(statement),))
    maxh = 0
    for sub in subexprs:
        w, h = view.boxsize(sub.value)
        # update width to include horizontal room for type text like int32
        ty = tsensor.analysis._dtype(sub.value)
        w = max(w, len(ty) * view.wchar_small)
        maxh = max(h, maxh)
        start, stop = sub.start.cstart_idx, sub.stop.cstop_idx
        nexpr = stop - start
        # An adjacent space gives the box extra room. The char to the left may
        # sit at index 0, so test start > 0 (the old test, start-1 > 0, missed it).
        if start > 0 and statement[start - 1] == ' ':  # if char to left is space
            nexpr += 1
        if stop < len(statement) and statement[stop] == ' ':  # if char to right is space
            nexpr += 1
        if w > view.wchar * nexpr:
            lpad[start] += (w - view.wchar) / 2
            rpad[stop - 1] += (w - view.wchar) / 2

    # Now we know how to place all the elements, since we know what the maximum height is
    view.set_locations(maxh)

    # Find each character's position based upon width of a character and any padding
    charx = np.empty((len(statement),))
    x = view.leftedge
    for i in range(len(statement)):
        x += lpad[i]
        charx[i] = x
        x += view.wchar
        x += rpad[i]

    # Draw text for statement or expression
    if view.offending_expr is not None:  # highlight erroneous subexpr
        highlight = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index+1]:
            highlight[tok.cstart_idx:tok.cstop_idx] = True
        errors = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in root_to_viz.optokens:
            errors[tok.cstart_idx:tok.cstop_idx] = True
        for i, c in enumerate(statement):
            color = ignored_color
            if highlight[i]:
                color = fontcolor
            if errors[i]:  # override color if operator token
                color = error_op_color
            ax.text(charx[i], view.texty, c, color=color, fontname=fontname, fontsize=fontsize)
    else:
        for i, c in enumerate(statement):
            ax.text(charx[i], view.texty, c, color=fontcolor, fontname=fontname, fontsize=fontsize)

    # Compute the left and right edges of subexpressions (alter nodes with info)
    for sub in subexprs:
        sub.leftx = charx[sub.start.cstart_idx]
        sub.rightx = charx[sub.stop.cstop_idx - 1] + view.wchar

    # Draw grey underlines and draw matrices
    pad = view.wchar * 0.1
    for sub in subexprs:
        a, b = sub.leftx, sub.rightx
        ax.plot([a - pad, b + pad], [view.liney, view.liney], '-', linewidth=.5, c=underline_color)
        view.draw(ax, sub)

    fig_width = charx[-1] + view.wchar + rpad[-1]
    fig_width_inches = fig_width / dpi
    fig_height_inches = view.maxy / dpi
    fig.set_size_inches(fig_width_inches, fig_height_inches)

    ax.set_xlim(0, fig_width)
    ax.set_ylim(0, view.maxy)

    return view
510

511

512
# ---------------- SHOW AST STUFF ---------------------------
513

514
class QuietGraphvizWrapper(graphviz.Source):
    """
    A graphviz.Source that renders to SVG in notebooks without printing
    graphviz warnings (pipe(..., quiet=True)). It can also save itself in
    any format the `dot` command supports via savefig().
    """
    def __init__(self, dotsrc):
        super().__init__(source=dotsrc)

    def _repr_svg_(self):
        return self.pipe(format='svg', quiet=True).decode(self._encoding)

    def savefig(self, filename):
        """
        Save the graph using the format given by filename's extension
        (e.g. 'tree.svg' or 'tree.pdf').

        The DOT source is first written via save() into the same directory,
        under filename's stem with no extension. `dot` then converts it.
        Missing parent directories are created.
        """
        path = Path(filename)
        path.parent.mkdir(parents=True, exist_ok=True)

        dotfilename = self.save(directory=path.parent.as_posix(), filename=path.stem)
        fmt = path.suffix[1:]  # ".svg" -> "svg" etc...
        cmd = ["dot", f"-T{fmt}", "-o", str(path), dotfilename]
        # graphviz >= 0.18 moved process handling into graphviz.backend.execute.
        # Detect the API instead of comparing version strings, which compare
        # lexicographically ('0.8' > '0.17' as strings).
        execute = getattr(graphviz.backend, 'execute', None)
        if execute is not None and hasattr(execute, 'run_check'):
            execute.run_check(cmd, capture_output=True, check=True, quiet=False)
        else:
            graphviz.backend.run(cmd, capture_output=True, check=True, quiet=False)
533

534

535
def astviz(statement:str, frame='current',
           dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> graphviz.Source:
    """
    Display the abstract syntax tree (AST) for the Python code in `statement`.

    The code is evaluated in `frame` so that tensor nodes can show their shapes
    and dtypes:
    - by default ('current'), the code of whoever called astviz() is used;
    - frame=None skips evaluation and just draws the AST.

    Returns a QuietGraphvizWrapper, which renders as SVG in a notebook. Its
    `savefig()` method saves the tree in any format chosen by the file extension.
    """
    # astviz_dot must be called directly from this function: when frame is
    # 'current' it recognizes this frame by name and skips to our caller.
    dot_source = astviz_dot(statement, frame,
                            dtype_colors, dtype_precisions, dtype_alpha_range)
    wrapper = QuietGraphvizWrapper(dot_source)
    return wrapper
552

553

554
def astviz_dot(statement:str, frame='current',
               dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> str:
    """
    Build Graphviz DOT source text that draws the AST of `statement`.

    Tokens become leaves in a single bottom row. Operators become internal
    nodes; when a node evaluated to a tensor, it shows the shape and dtype and
    is filled with the dtype color.

    :param statement: The Python code to parse.
    :param frame: Execution frame in which to evaluate the code.
                  - 'current' (the default) uses the caller's frame, skipping
                    astviz() when invoked through it;
                  - None skips evaluation and just draws the tree.
    :param dtype_colors: Color scheme options, passed to DTypeColorInfo.
    :param dtype_precisions: Color scheme options, passed to DTypeColorInfo.
    :param dtype_alpha_range: Color scheme options, passed to DTypeColorInfo.
    :return: The DOT source as a string.
    :raises ValueError: if the statement cannot be parsed.
    """
    fontname = "Consolas"
    fontsize = 12
    dimfontsize = 9
    spread = 0
    textcolor = "#444443"
    edge_style = 'dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3'

    def internal_label(node):
        """HTML-like label: operator text, plus shape and dtype for tensor values."""
        sh = tsensor.analysis._shape(node.value)
        ty = tsensor.analysis._dtype(node.value)
        # Escape so operators such as '<' or '&' cannot break the HTML-like label
        text = html.escape(''.join(str(t) for t in node.optokens), quote=False)
        if sh is None:
            return f'<font face="{fontname}" point-size="{fontsize}">{text}</font>'

        sz = 'x'.join([PyVizView.nabbrev(sh[i]) for i in range(len(sh))])
        return (f'<font face="{fontname}" color="{textcolor}" point-size="{fontsize}">{text}</font><br/>'
                f'<font face="Arial" color="{textcolor}" point-size="{dimfontsize}">{sz}</font><br/>'
                f'<font face="Arial" color="{textcolor}" point-size="{dimfontsize}">&lt;{ty}&gt;</font>')

    dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)

    root, tokens = tsensor.parsing.parse(statement)
    if root is None:
        raise ValueError(f"Can't parse {statement}")

    if frame == 'current':  # use frame of caller if nothing passed in
        frame = sys._getframe().f_back
        if frame.f_code.co_name == 'astviz':  # called via astviz(): use its caller
            frame = frame.f_back

    if frame is not None:  # if the caller passed in None, then don't do the evaluation
        root.eval(frame)

    nodes = tsensor.ast.postorder(root)
    atoms = set(tsensor.ast.leaves(root))
    ops = [nd for nd in nodes if nd not in atoms]  # keep order

    parts = ["""digraph G {
        margin=0;
        nodesep=.01;
        ranksep=.3;
        rankdir=BT;
        ordering=out; # keep order of leaves
    """]

    leaves = [t for t in tokens if t.type != token.ENDMARKER]

    # Gen leaf nodes
    for t in leaves:
        # Escape first so code such as a<b or x & y cannot break the HTML-like
        # label. Then protect a lone ']', which is bad for DOT by itself, with
        # &zwnj; (a zero-width non-joiner).
        nodetext = html.escape(t.value, quote=False)
        if nodetext == ']':
            nodetext = '&zwnj;]'
        label = f'<font face="{fontname}" color="{textcolor}" point-size="{fontsize}">{nodetext}</font>'
        _spread = spread
        if t.type == token.DOT:
            _spread = .1
        elif t.type == token.EQUAL:
            _spread = .25
        elif t.type in tsensor.parsing.ADDOP:
            _spread = .4
        elif t.type in tsensor.parsing.MULOP:
            _spread = .2
        parts.append(f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n')

    # Make sure leaves are on same level
    parts.append('{ rank=same; ')
    parts.extend(f' leaf{id(t)}' for t in leaves)
    parts.append('\n}\n')

    # Make sure leaves are left to right by linking consecutive leaves invisibly
    for t, t2 in zip(leaves, leaves[1:]):
        parts.append(f'leaf{id(t)} -> leaf{id(t2)} [style=invis];\n')

    # Draw internal ops nodes
    for nd in ops:
        label = internal_label(nd)
        sh = tsensor.analysis._shape(nd.value)
        if sh is None:
            color = ""
        else:
            ty = tsensor.analysis._dtype(nd.value)
            rgba = mc.rgb2hex(dtype_color_info.color(ty), keep_alpha=True)
            color = f'fillcolor="{rgba}" style=filled'
        parts.append(f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n')

    # Link internal nodes to other nodes or leaves
    for nd in nodes:
        for sub in nd.kids:
            target = f'leaf{id(sub.token)}' if sub in atoms else f'node{id(sub)}'
            parts.append(f'node{id(nd)} -> {target} [{edge_style}];\n')

    parts.append("}\n")
    return ''.join(parts)
653

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.