matplotlib

Форк
0
/
boilerplate.py 
464 строки · 13.3 Кб
"""
Script to autogenerate pyplot wrappers.

When this script is run, the current contents of pyplot are
split into generatable and non-generatable content (via the magic header
:attr:`PYPLOT_MAGIC_HEADER`) and the generatable content is overwritten.
Hence, the non-generatable content should be edited in the pyplot.py file
itself, whereas the generatable content must be edited via templates in
this file.
"""

# Although it is possible to dynamically generate the pyplot functions at
# runtime with the proper signatures, a static pyplot.py is simpler for static
# analysis tools to parse.

import ast
17
from enum import Enum
18
import functools
19
import inspect
20
from inspect import Parameter
21
from pathlib import Path
22
import sys
23
import subprocess
24

25

26
# This line imports the installed copy of matplotlib, and not the local copy.
27
import numpy as np
28
from matplotlib import _api, mlab
29
from matplotlib.axes import Axes
30
from matplotlib.figure import Figure
31

32

33
# This is the magic line that must exist in pyplot, after which the boilerplate
# content will be appended.
PYPLOT_MAGIC_HEADER = (
    "################# REMAINING CONTENT GENERATED BY boilerplate.py "
    "##############\n")

# Banner placed in front of every generated chunk. It starts with two newlines,
# which separate it from whatever was written before it.
AUTOGEN_MSG = """

# Autogenerated by boilerplate.py.  Do not edit as changes will be lost."""

# Wrapper for an Axes method whose result should become the current image.
# ``{sci_command}`` is a statement (or small ``if`` block) that calls ``sci``;
# the result of the Axes call is kept in ``__ret`` and returned afterwards.
AXES_CMAPPABLE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    __ret = gca().{called_name}{call}
    {sci_command}
    return __ret
"""

# Plain forwarding wrapper calling the method on the current Axes (``gca()``).
# ``{return_statement}`` is either ``"return "`` or empty.
AXES_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    {return_statement}gca().{called_name}{call}
"""

# Plain forwarding wrapper calling the method on the current Figure (``gcf()``).
FIGURE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Figure.{called_name})
def {name}{signature}:
    {return_statement}gcf().{called_name}{call}
"""

# One-line helper per colormap that forwards to ``set_cmap``.
CMAP_TEMPLATE = '''
def {name}() -> None:
    """
    Set the colormap to {name!r}.

    This changes the default colormap as well as the colormap of the current
    image if there is one. See ``help(colormaps)`` for more information.
    """
    set_cmap({name!r})
'''  # Colormap functions.
73

74

75
class value_formatter:
    """
    Wrap a parameter default so that its ``repr`` is valid pyplot source.

    ``str(inspect.Signature)`` renders parameter defaults with ``repr``.
    That is fine for most literals. However, a few objects used as defaults
    in the wrapped Axes/Figure methods would not print as usable source.
    Those are mapped to the dotted name used to spell them in the generated
    code.

    Enum members are rendered as ``ClassName.MemberName``. Any other value
    falls back to its plain ``repr``.
    """

    def __init__(self, value):
        # Matched by identity: these are specific objects, not values that
        # merely compare equal.
        known_defaults = (
            (mlab.detrend_none, "mlab.detrend_none"),
            (mlab.window_hanning, "mlab.window_hanning"),
            (np.mean, "np.mean"),
            (_api.deprecation._deprecated_parameter,
             "_api.deprecation._deprecated_parameter"),
        )
        for obj, source in known_defaults:
            if value is obj:
                self._repr = source
                return
        if isinstance(value, Enum):
            # Enum str is Class.Name whereas their repr is <Class.Name: value>.
            self._repr = f'{type(value).__name__}.{value.name}'
        else:
            self._repr = repr(value)

    def __repr__(self):
        return self._repr
99

100

101
class direct_repr:
    """
    Placeholder whose ``repr`` is exactly the text it was built from.

    ``ast.unparse`` returns annotation source code as a plain ``str``.
    ``inspect`` formats non-type annotations with ``repr``, which would add
    quotes. Wrapping the text in this class makes the annotation print
    verbatim instead.
    """

    def __init__(self, value):
        # Kept as-is: no quoting or escaping is applied.
        self._repr = value

    def __repr__(self):
        text = self._repr
        return text
110

111

112
def _format_call_arg(param):
    """Return the source text that forwards *param* to the wrapped method."""
    name = param.name
    kind = param.kind
    if (kind is Parameter.POSITIONAL_OR_KEYWORD
            and param.default is Parameter.empty):
        # Pass "intended-as-positional" parameters positionally to avoid
        # forcing third-party subclasses to reproduce the parameter names.
        return name
    if name == "data":
        # Only pass the data kwarg if it is actually set, to avoid forcing
        # third-party subclasses to support it.
        return '**({"data": data} if data is not None else {})'
    if kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY):
        return f'{name}={name}'
    if kind is Parameter.POSITIONAL_ONLY:
        return name
    if kind is Parameter.VAR_POSITIONAL:
        return f'*{name}'
    if kind is Parameter.VAR_KEYWORD:
        return f'**{name}'
    raise ValueError(f'Unsupported parameter kind {kind!r} for {name!r}')


def generate_function(name, called_fullname, template, **kwargs):
    """
    Create a wrapper function *name* calling the method *called_fullname*.

    Parameters
    ----------
    name : str
        The function to be created.
    called_fullname : str
        The method to be wrapped in the format ``"Class.method"``, where
        ``Class`` is ``Axes`` or ``Figure``.
    template : str
        The template to be used. The template must contain {}-style format
        placeholders. The following placeholders are filled in:

        - name: The function name.
        - signature: The function signature (including parentheses).
        - called_name: The name of the called function.
        - call: Parameters passed to *called_name* (including parentheses).
        - return_statement: ``"return "`` unless the wrapped method is
          annotated as returning ``None``, else the empty string.

    **kwargs
        Additional parameters are passed to ``template.format()``.

    Returns
    -------
    str
        The formatted wrapper source.

    Raises
    ------
    ValueError
        If a parameter of the wrapped method is named ``gca``, ``gci``,
        ``gcf`` or ``__ret``, names the generated wrapper relies on.
    """
    # Get signature of wrapped function.
    class_name, called_name = called_fullname.split('.')
    class_ = {'Axes': Axes, 'Figure': Figure}[class_name]

    meth = getattr(class_, called_name)
    decorator = _api.deprecation.DECORATORS.get(meth)
    # Generate the wrapper with the non-kwonly signature, as it will get
    # redecorated with make_keyword_only by _copy_docstring_and_deprecators.
    if decorator and decorator.func is _api.make_keyword_only:
        meth = meth.__wrapped__

    annotated_trees = get_ast_mro_trees(class_)
    signature = get_matching_signature(meth, annotated_trees)

    # Drop the leading ``self`` parameter.
    params = list(signature.parameters.values())[1:]

    # Bail out in case of name collision. Compare parameter *names*: a
    # Parameter object never compares equal to a plain string.
    param_names = {param.name for param in params}
    for reserved in ('gca', 'gci', 'gcf', '__ret'):
        if reserved in param_names:
            raise ValueError(
                f'Method {called_fullname} has kwarg named {reserved}')

    # An unannotated return stringifies as "<class 'inspect._empty'>", so
    # only an explicit ``None`` annotation suppresses the ``return``.
    has_return_value = str(signature.return_annotation) != 'None'
    signature = str(signature.replace(parameters=[
        param.replace(default=value_formatter(param.default))
        if param.default is not param.empty else param
        for param in params]))
    # How to call the wrapped function.
    call = '(' + ', '.join(_format_call_arg(param) for param in params) + ')'
    return_statement = 'return ' if has_return_value else ''

    return template.format(
        name=name,
        called_name=called_name,
        signature=signature,
        call=call,
        return_statement=return_statement,
        **kwargs)
193

194

195
def boilerplate_gen():
    """
    Yield, in order, the source chunks making up the generated part of pyplot.

    Figure wrappers come first, then Axes wrappers, then one setter per
    colormap (each preceded by the autogeneration banner).

    A command spec is either ``"name"``, where the pyplot name equals the
    method name, or ``"pyplot_name:method_name"``.
    """

    def split_spec(spec):
        # "name" -> ("name", "name"); "alias:target" -> ("alias", "target").
        if ':' not in spec:
            return spec, spec
        pyplot_name, method_name = spec.split(':')
        return pyplot_name, method_name

    figure_specs = (
        'figimage',
        'figtext:text',
        'gca',
        'gci:_gci',
        'ginput',
        'subplots_adjust',
        'suptitle',
        'tight_layout',
        'waitforbuttonpress',
    )

    # These methods are all simple wrappers of Axes methods by the same name,
    # except for the aliased entries at the end.
    axes_specs = (
        'acorr',
        'angle_spectrum',
        'annotate',
        'arrow',
        'autoscale',
        'axhline',
        'axhspan',
        'axis',
        'axline',
        'axvline',
        'axvspan',
        'bar',
        'barbs',
        'barh',
        'bar_label',
        'boxplot',
        'broken_barh',
        'clabel',
        'cohere',
        'contour',
        'contourf',
        'csd',
        'ecdf',
        'errorbar',
        'eventplot',
        'fill',
        'fill_between',
        'fill_betweenx',
        'grid',
        'hexbin',
        'hist',
        'stairs',
        'hist2d',
        'hlines',
        'imshow',
        'legend',
        'locator_params',
        'loglog',
        'magnitude_spectrum',
        'margins',
        'minorticks_off',
        'minorticks_on',
        'pcolor',
        'pcolormesh',
        'phase_spectrum',
        'pie',
        'plot',
        'plot_date',
        'psd',
        'quiver',
        'quiverkey',
        'scatter',
        'semilogx',
        'semilogy',
        'specgram',
        'spy',
        'stackplot',
        'stem',
        'step',
        'streamplot',
        'table',
        'text',
        'tick_params',
        'ticklabel_format',
        'tricontour',
        'tricontourf',
        'tripcolor',
        'triplot',
        'violinplot',
        'vlines',
        'xcorr',
        # pyplot name : real name
        'sci:_sci',
        'title:set_title',
        'xlabel:set_xlabel',
        'ylabel:set_ylabel',
        'xscale:set_xscale',
        'yscale:set_yscale',
    )

    # Statements inserted as ``{sci_command}`` for wrappers whose result
    # should become the current image.
    sci_ret = 'sci(__ret)'
    sci_last = 'sci(__ret[-1])'
    sci_if_a_set = (
        'if __ret._A is not None:  # type: ignore[attr-defined]\n'
        '        sci(__ret)'
    )
    cmappable = {
        'contour': sci_if_a_set,
        'contourf': sci_if_a_set,
        'hexbin': sci_ret,
        'scatter': sci_ret,
        'pcolor': sci_ret,
        'pcolormesh': sci_ret,
        'hist2d': sci_last,
        'imshow': sci_ret,
        'spy': (
            'if isinstance(__ret, cm.ScalarMappable):\n'
            '        sci(__ret)'
        ),
        'quiver': sci_ret,
        'specgram': sci_last,
        'streamplot': 'sci(__ret.lines)',
        'tricontour': sci_if_a_set,
        'tricontourf': sci_if_a_set,
        'tripcolor': sci_ret,
    }

    for spec in figure_specs:
        pyplot_name, method_name = split_spec(spec)
        yield generate_function(pyplot_name, f'Figure.{method_name}',
                                FIGURE_METHOD_TEMPLATE)

    for spec in axes_specs:
        pyplot_name, method_name = split_spec(spec)
        if pyplot_name in cmappable:
            template = AXES_CMAPPABLE_METHOD_TEMPLATE
        else:
            template = AXES_METHOD_TEMPLATE
        yield generate_function(pyplot_name, f'Axes.{method_name}', template,
                                sci_command=cmappable.get(pyplot_name))

    # Add all the colormap setters (autumn, hsv, ...).
    colormap_names = (
        'autumn',
        'bone',
        'cool',
        'copper',
        'flag',
        'gray',
        'hot',
        'hsv',
        'jet',
        'pink',
        'prism',
        'spring',
        'summer',
        'winter',
        'magma',
        'inferno',
        'plasma',
        'viridis',
        'nipy_spectral',
    )
    for cmap_name in colormap_names:
        yield AUTOGEN_MSG
        yield CMAP_TEMPLATE.format(name=cmap_name)
369

370

371
def build_pyplot(pyplot_path):
    """
    Regenerate the autogenerated part of the pyplot module at *pyplot_path*.

    Everything up to and including the ``PYPLOT_MAGIC_HEADER`` line is kept
    verbatim. Everything after it is replaced by the output of
    `boilerplate_gen`, and the file is then reformatted with black.

    Parameters
    ----------
    pyplot_path : pathlib.Path
        Path to pyplot.py; the file is rewritten in place.

    Raises
    ------
    ValueError
        If the file does not contain the exact ``PYPLOT_MAGIC_HEADER`` line.
    subprocess.CalledProcessError
        If running black fails.
    """
    pyplot_orig = pyplot_path.read_text(encoding='utf-8').splitlines(
        keepends=True)
    try:
        header_index = pyplot_orig.index(PYPLOT_MAGIC_HEADER)
    except ValueError as err:  # list.index signals "missing" with ValueError.
        raise ValueError('The pyplot.py file *must* have the exact line: %s'
                         % PYPLOT_MAGIC_HEADER) from err
    pyplot_orig = pyplot_orig[:header_index + 1]

    # Generate everything before opening the file for writing, so an error
    # raised while generating does not leave pyplot.py truncated.
    generated = list(boilerplate_gen())

    with pyplot_path.open('w', encoding='utf-8') as pyplot:
        pyplot.writelines(pyplot_orig)
        pyplot.writelines(generated)

    # Run black to autoformat pyplot
    subprocess.run(
        [sys.executable, "-m", "black", "--line-length=88", pyplot_path],
        check=True
    )
388

389

# Methods for retrieving signatures from .pyi stub files.

def get_ast_tree(cls):
    """
    Return the top-level ``ast.ClassDef`` node that defines *cls*.

    The file defining *cls* is located with `inspect.getfile`. If a ``.pyi``
    stub with the same stem sits next to it, the stub is parsed instead.

    Raises
    ------
    TypeError
        From `inspect.getfile`, e.g. if *cls* is a built-in class.
    ValueError
        If the parsed file has no top-level class named ``cls.__name__``.
    """
    source_path = Path(inspect.getfile(cls))
    stub_path = source_path.with_suffix(".pyi")
    path = stub_path if stub_path.exists() else source_path
    # Python sources and stubs are UTF-8 by default (PEP 3120); don't depend
    # on the platform's locale encoding when reading them.
    tree = ast.parse(path.read_text(encoding="utf-8"))
    for item in tree.body:
        if isinstance(item, ast.ClassDef) and item.name == cls.__name__:
            return item
    raise ValueError(f"Cannot find {cls.__name__} in ast")
401

402

403
@functools.lru_cache
def get_ast_mro_trees(cls):
    """
    Return the class-definition AST of every non-builtin class in *cls*'s MRO.

    The order follows ``cls.__mro__``. Results are memoized per class, so
    repeated calls return the same list object.
    """
    trees = []
    for klass in cls.__mro__:
        if klass.__module__ == "builtins":
            continue
        trees.append(get_ast_tree(klass))
    return trees
406

407

408
def get_matching_signature(method, trees):
    """
    Return the signature of *method*, annotated from the stub ASTs *trees*.

    The trees are searched in order. Only their direct bodies are inspected,
    and only plain ``def`` nodes are considered. The first one named like
    *method* supplies the annotations via `update_sig_from_node`. If none
    matches, the runtime signature is returned unchanged.
    """
    runtime_sig = inspect.signature(method)
    match = next(
        (node
         for tree in trees
         for node in tree.body
         if isinstance(node, ast.FunctionDef)
         and node.name == method.__name__),
        None,
    )
    if match is None:
        # Methods implemented outside the MRO of Axes (e.g. stackplot,
        # streamplot, table, tricontour, tricontourf, tripcolor, triplot)
        # end up here and keep their unannotated runtime signature.
        return runtime_sig
    return update_sig_from_node(match, runtime_sig)
429

430

431
def update_sig_from_node(node, sig):
    """
    Return a copy of *sig* carrying the annotations of the stub def *node*.

    Each annotated argument of *node* overrides the annotation of the
    same-named parameter in *sig*; parameters the stub leaves unannotated
    keep their original annotation. The return annotation is taken only from
    the stub: if *node* has none, the result has no return annotation.

    Raises
    ------
    KeyError
        If the stub annotates an argument that *sig* does not have.
    """
    new_params = dict(sig.parameters)
    arguments = node.args
    candidates = [
        *arguments.posonlyargs,
        *arguments.args,
        arguments.vararg,
        *arguments.kwonlyargs,
        arguments.kwarg,
    ]
    # vararg/kwarg are None when the stub has no *args/**kwargs.
    annotated = [arg for arg in candidates
                 if arg is not None and arg.annotation is not None]
    for arg in annotated:
        text = direct_repr(ast.unparse(arg.annotation))
        new_params[arg.arg] = new_params[arg.arg].replace(annotation=text)

    if node.returns is None:
        return inspect.Signature(new_params.values())
    returns = direct_repr(ast.unparse(node.returns))
    return inspect.Signature(new_params.values(), return_annotation=returns)
456

457

458
if __name__ == '__main__':
    # Write the matplotlib.pyplot file: to the path given as the first
    # command-line argument, otherwise to the in-tree copy next to this script.
    pyplot_path = (
        Path(sys.argv[1]) if len(sys.argv) > 1
        else Path(__file__).parent / "../lib/matplotlib/pyplot.py"
    )
    build_pyplot(pyplot_path)
465

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.