cython

Форк
0
/
TestUtils.py 
417 строк · 15.0 Кб
1
import os
2
import re
3
import unittest
4
import shlex
5
import sys
6
import tempfile
7
import textwrap
8
from functools import partial
9

10
from .Compiler import Errors
11
from .CodeWriter import CodeWriter
12
from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext
13
from .Compiler.Visitor import TreeVisitor, VisitorTransform
14
from .Compiler import TreePath
15
from .Compiler.ParseTreeTransforms import PostParse
16

17

18
class NodeTypeWriter(TreeVisitor):
    """Tree visitor that renders a tree as an indented list of node class names.

    After ``visit(root)``, ``self.result`` holds one line per visited node of
    the form ``"<name>: <class name>"``, indented by two spaces per nesting
    level.  ``<name>`` is ``(root)`` for the starting node; otherwise it is
    built from the last entry of ``self.access_path``: ``entry[1]``, or
    ``entry[1][entry[2]]`` when ``entry[2]`` is not None.
    """

    def __init__(self):
        super().__init__()
        self._indents = 0   # current nesting depth while visiting
        self.result = []    # output lines, in visiting order

    def visit_Node(self, node):
        if not self.access_path:
            name = "(root)"
        else:
            tip = self.access_path[-1]
            # NOTE(review): presumably (parent, attribute name, list index or None)
            # as maintained by TreeVisitor -- confirm against TreeVisitor.
            if tip[2] is not None:
                name = "%s[%d]" % tip[1:3]
            else:
                name = tip[1]

        self.result.append("  " * self._indents +
                           "%s: %s" % (name, node.__class__.__name__))
        self._indents += 1
        try:
            self.visitchildren(node)
        finally:
            # Keep the depth counter consistent even if visiting a child raises.
            self._indents -= 1
39

40

41
def treetypes(root):
    """Return a string representing the tree by class names.

    One line per node is produced by ``NodeTypeWriter``.  The result starts
    and ends with a newline so that it can be compared by simple string
    comparison against a triple-quoted literal while still making test cases
    look ok.
    """
    writer = NodeTypeWriter()
    writer.visit(root)
    return "\n".join([""] + writer.result + [""])
49

50

51
class CythonTest(unittest.TestCase):
    """Base class for Cython compiler unit tests.

    Calls ``Errors.init_thread()`` before and after every test (presumably to
    reset the compiler's error state) and provides helpers for comparing code
    line by line, building tree fragments and checking for exceptions.
    """

    def setUp(self):
        Errors.init_thread()

    def tearDown(self):
        Errors.init_thread()

    def assertLines(self, expected, result):
        """Check that the given strings or lists of strings are equal line by line.

        Strings are split into lines; lists are compared as they are.  The
        first differing line is reported with its 0-based index, and a
        difference in the number of lines is reported afterwards.
        """
        if not isinstance(expected, list):
            expected = expected.split("\n")
        if not isinstance(result, list):
            result = result.split("\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        # "Got" is the actual result, "Expected" the expected text.
        self.assertEqual(len(expected), len(result),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result), "\n".join(expected)))

    def codeToLines(self, tree):
        """Serialize `tree` with ``CodeWriter`` and return the list of lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialize `tree` with ``CodeWriter`` and return it as one string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Check that `result_tree` serializes to the code string `expected`.

        The common indentation of `expected` is stripped before comparing.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line,
                             "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Check that the ``TreePath`` expression `path` matches in `result_tree`."""
        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
                            "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        """Create a tree fragment using the name of the test case in parse errors.

        The name is the test id with a leading ``__main__.`` removed and dots
        replaced by underscores.  `pxds` and `pipeline` default to empty.
        """
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Return the class-name rendering of `root` (see module-level ``treetypes``)."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Call `func` and fail unless it raises an `exc_type` exception.

        Any exception is accepted by default.  Returns the caught exception.
        """
        try:
            func()
        except exc_type as e:
            return e
        # Must stay outside the ``try``: the AssertionError raised by
        # ``self.fail`` is itself an Exception and would otherwise be caught
        # and returned as if `func` had raised it.
        self.fail("Expected an exception of type %r" % exc_type)

    def should_not_fail(self, func):
        """Call `func` and return its result; turn any exception into a test failure."""
        try:
            return func()
        except Exception as exc:
            self.fail(str(exc))
128

129

130
class TransformTest(CythonTest):
    """
    Utility base class for transform unit tests. It is based around constructing
    test trees (either explicitly or by parsing a Cython code string), running
    the transform, serializing the result using a customized Cython serializer
    (with special markup for nodes that cannot be represented in Cython),
    and doing a line-by-line string comparison of the result.

    To create a test case:
     - Call run_pipeline. The pipeline should at least contain the transform you
       are testing; pyx should be either a string (passed to the parser to
       create a post-parse tree) or a node representing input to the pipeline.
       The optional pxds dictionary is passed on to ``self.fragment``.
       The result will be the transformed tree.

     - Check that the tree is correct. If wanted, assertCode can be used, which
       takes a code string as expected, and a ModuleNode in result_tree
       (it serializes the ModuleNode to a string and compares line-by-line).

    All code strings are first stripped of whitespace-only lines and then of
    their common indentation.
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Build the input tree from `pyx` and pass it through each transform in `pipeline`.

        Every element of `pipeline` is called with the current tree and must
        return the tree to hand on; the final tree is returned.
        """
        if pxds is None:
            pxds = {}
        tree = self.fragment(pyx, pxds).root
        for transform in pipeline:
            tree = transform(tree)
        return tree
162

163

164
# For the test C code validation, we have to take care that the test directives (and thus
# the match strings) do not just appear in (multiline) C code comments containing the original
# Cython source code.  Thus, we discard the comments before matching.
# This seems a prime case for re.VERBOSE, but it seems to match some of the whitespace.
# Instead, all whitespace is deleted from the pattern text below before compiling,
# so its layout is only for readability.  Only comments that contain at least one
# newline are removed; single-line comments are left in place.
# NOTE(review): a comment ending in '**/' is not recognised as closed here
# ('[*][^/]' consumes the '**'), so such a match runs on to a later '*/' or fails;
# confirm whether generated C can contain that.
_strip_c_comments = partial(re.compile(
    re.sub(r'\s+', '', r'''
        /[*] (
            (?: [^*\n] | [*][^/] )*
            [\n]
            (?: [^*] | [*][^/] )*
        ) [*]/
    ''')
).sub, '')
177

178
# Remove <pre class="...cython line..."> blocks and <style> sections from the
# annotated HTML output before pattern matching (presumably so that, as with the
# C comments above, patterns do not match the echoed Cython source).
# Runs of two or more whitespace characters (newline plus indentation) are deleted
# from the pattern text; single spaces such as the one in "<pre class=" are kept.
_strip_cython_code_from_html = partial(re.compile(
    re.sub(r'\s\s+', '', r'''
    (?:
        <pre class=["'][^"']*cython\s+line[^"']*["']\s*>
        (?:[^<]|<(?!/pre))+
        </pre>
    )|(?:
        <style[^>]*>
        (?:[^<]|<(?!/style))+
        </style>
    )
    ''')
).sub, '')
191

192

193
def _parse_pattern(pattern):
    """Split a C code test pattern into ``(start, end, pattern)``.

    Accepted forms (whitespace around the parts is stripped)::

        pattern
        /start/ pattern
        :/end/ pattern
        /start/ :/end/ pattern

    `start` and `end` are None when absent.  Inside a ``/.../`` section a
    slash can be escaped as ``\\/``; only unescaped slashes end the section.

    Raises ValueError if a ``/.../`` section is not terminated.
    """
    original = pattern
    start = end = None

    def split_section(text):
        # `text` starts right after an opening '/'; split at the first unescaped '/'.
        parts = re.split(r"(?<!\\)/", text, maxsplit=1)
        if len(parts) != 2:
            raise ValueError("Unterminated '/.../' section in pattern %r" % original)
        section, rest = parts
        return section, rest.strip()

    if pattern.startswith('/'):
        start, pattern = split_section(pattern[1:])
    if pattern.startswith(':'):
        pattern = pattern[1:].strip()
        if pattern.startswith('/'):
            end, pattern = split_section(pattern[1:])
    return start, end, pattern
204

205

206
class TreeAssertVisitor(VisitorTransform):
    """Evaluate the ``test_*`` compiler directives of a module.

    Path assertions (``test_assert_path_exists``, ``test_fail_if_path_exists``)
    are checked against the tree while visiting and reported via
    ``Errors.error``.  C code assertions (``test_assert_c_code_has``,
    ``test_fail_if_c_code_has``) are only collected during the visit; they are
    checked by the callable returned from ``create_c_file_validator``.
    """
    # actually, a TreeVisitor would be enough, but this needs to run
    # as part of the compiler pipeline

    def __init__(self):
        super().__init__()
        self._module_pos = None    # pos of the ModuleNode; used when reporting C code mismatches
        self._c_patterns = []      # patterns that must be found in the generated output
        self._c_antipatterns = []  # patterns that must not be found in the generated output

    def create_c_file_validator(self):
        """Return a callable that checks the collected C code patterns.

        The callable takes an object with a ``c_file`` attribute (path of the
        generated C file) and returns that object unchanged.  A pattern may
        carry ``/start/`` and ``:/end/`` prefixes (parsed by ``_parse_pattern``)
        that limit the search to a section of the file.  Multi-line C comments
        are stripped before matching.  If an ``.html`` file with the same base
        name exists and is not older than the C file, it is checked as well,
        after removing its ``<pre class="...cython line...">`` blocks and
        ``<style>`` sections.
        """
        # Bound by reference: patterns collected by visits after this call are still checked.
        patterns, antipatterns = self._c_patterns, self._c_antipatterns

        def fail(pos, pattern, found, file_path):
            Errors.error(pos, "Pattern '%s' %s found in %s" % (
                pattern,
                'was' if found else 'was not',
                file_path,
            ))

        def extract_section(file_path, content, start, end):
            # Cut `content` down to the text after the first `start` match and
            # before the first following `end` match.  A missing delimiter is
            # reported, and that side of the section is left unrestricted.
            if start:
                split = re.search(start, content)
                if split:
                    content = content[split.end():]
                else:
                    fail(self._module_pos, start, found=False, file_path=file_path)
            if end:
                split = re.search(end, content)
                if split:
                    content = content[:split.start()]
                else:
                    fail(self._module_pos, end, found=False, file_path=file_path)
            return content

        def validate_file_content(file_path, content):
            for pattern in patterns:
                start, end, pattern = _parse_pattern(pattern)
                section = extract_section(file_path, content, start, end)
                if not re.search(pattern, section):
                    fail(self._module_pos, pattern, found=False, file_path=file_path)

            for antipattern in antipatterns:
                start, end, antipattern = _parse_pattern(antipattern)
                section = extract_section(file_path, content, start, end)
                if re.search(antipattern, section):
                    fail(self._module_pos, antipattern, found=True, file_path=file_path)

        def validate_c_file(result):
            c_file = result.c_file
            if not (patterns or antipatterns):
                return result

            with open(c_file, encoding='utf8') as f:
                content = f.read()
            content = _strip_c_comments(content)
            validate_file_content(c_file, content)

            html_file = os.path.splitext(c_file)[0] + ".html"
            if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file):
                with open(html_file, encoding='utf8') as f:
                    content = f.read()
                content = _strip_cython_code_from_html(content)
                validate_file_content(html_file, content)

            # Hand the input on on every path (the early exit above does too),
            # so the validator behaves like a pass-through pipeline stage.
            return result

        return validate_c_file

    def _check_directives(self, node):
        """Check the path assertions of `node` and collect its C code patterns."""
        directives = node.directives
        if 'test_assert_path_exists' in directives:
            for path in directives['test_assert_path_exists']:
                if TreePath.find_first(node, path) is None:
                    Errors.error(
                        node.pos,
                        "Expected path '%s' not found in result tree" % path)
        if 'test_fail_if_path_exists' in directives:
            for path in directives['test_fail_if_path_exists']:
                first_node = TreePath.find_first(node, path)
                if first_node is not None:
                    # Report at the offending node rather than at the directive owner.
                    Errors.error(
                        first_node.pos,
                        "Unexpected path '%s' found in result tree" % path)
        if 'test_assert_c_code_has' in directives:
            self._c_patterns.extend(directives['test_assert_c_code_has'])
        if 'test_fail_if_c_code_has' in directives:
            self._c_antipatterns.extend(directives['test_fail_if_c_code_has'])

    def visit_ModuleNode(self, node):
        # Remember the module position for errors reported by the C file validator.
        self._module_pos = node.pos
        self._check_directives(node)
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        self._check_directives(node)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children
308

309

310
def unpack_source_tree(tree_file, workdir, cython_root):
    """Unpack a multi-file test tree into `workdir` and parse its command header.

    The tree file is read as bytes.  Each line starting with ``#####`` opens a
    new output file whose relative, ``/``-separated path is the line's text
    with surrounding ``#`` characters and whitespace stripped.  The lines that
    follow are copied to that file verbatim.

    Before the first file marker, blank lines, ``#`` comment lines and lines
    consisting only of a triple quote are skipped.  Every other line is split
    with ``shlex`` into a command.  The program names ``PYTHON``, ``CYTHON`` and
    ``CYTHONIZE`` are replaced by the running interpreter (plus ``cython.py`` /
    ``cythonize.py`` from `cython_root`).

    Returns ``(workdir, header)``.  `workdir` is a new temporary directory
    when None is passed; `header` is a list of argument lists.
    """
    programs = {
        'PYTHON': [sys.executable],
        'CYTHON': [sys.executable, os.path.join(cython_root, 'cython.py')],
        'CYTHONIZE': [sys.executable, os.path.join(cython_root, 'cythonize.py')],
    }

    if workdir is None:
        workdir = tempfile.mkdtemp()
    header, cur_file = [], None
    with open(tree_file, 'rb') as f:
        try:
            for line in f:
                if line.startswith(b'#####'):
                    filename = line.strip().strip(b'#').strip().decode('utf8').replace('/', os.path.sep)
                    path = os.path.join(workdir, filename)
                    dirname = os.path.dirname(path)
                    if dirname:
                        # exist_ok avoids the check-then-create race of exists() + makedirs().
                        os.makedirs(dirname, exist_ok=True)
                    if cur_file is not None:
                        to_close, cur_file = cur_file, None
                        to_close.close()
                    cur_file = open(path, 'wb')
                elif cur_file is not None:
                    cur_file.write(line)
                elif line.strip() and not line.lstrip().startswith(b'#'):
                    if line.strip() not in (b'"""', b"'''"):
                        command = shlex.split(line.decode('utf8'))
                        if not command:
                            continue
                        prog, *args = command
                        # Known program names are expanded; other commands are kept verbatim.
                        header.append(programs[prog] + args if prog in programs else command)
        finally:
            if cur_file is not None:
                cur_file.close()
    return workdir, header
348

349

350
def write_file(file_path, content, dedent=False, encoding=None):
    r"""Write some content (text or bytes) to the file
    at `file_path` without translating `'\n'` into `os.linesep`.

    The default encoding is `'utf-8'` for text.  For bytes content the file
    is opened in binary mode, so `encoding` must be left as None.
    If `dedent` is true, `textwrap.dedent` is applied to `content` first
    (intended for text content).
    """
    if isinstance(content, bytes):
        # binary mode doesn't take an encoding and newline arguments
        mode, newline, default_encoding = "wb", None, None
    else:
        # any "\n" characters written are not translated
        # to the system default line separator, os.linesep
        mode, newline, default_encoding = "w", "\n", "utf-8"

    if encoding is None:
        encoding = default_encoding

    if dedent:
        content = textwrap.dedent(content)

    with open(file_path, mode=mode, encoding=encoding, newline=newline) as f:
        f.write(content)


def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None):
    r"""
    Write `content` to the file `file_path` without translating `'\n'`
    into `os.linesep` and make sure it is newer than the file `newer_than`.

    The default encoding is `'utf-8'` (same as for `write_file`).

    If `newer_than` does not exist, the file is written once and the function
    returns.  Otherwise the file is rewritten until its modification time is
    strictly greater than that of `newer_than`, which can take a while on file
    systems with coarse timestamp resolution.
    """
    import time

    write_file(file_path, content, dedent=dedent, encoding=encoding)

    try:
        other_time = os.path.getmtime(newer_than)
    except OSError:
        # Support writing a fresh file (which is always newer than a non-existent one).
        # Return here: with no reference time the loop below could never terminate.
        return

    while other_time >= os.path.getmtime(file_path):
        # Back off briefly instead of rewriting the file in a tight loop.
        time.sleep(0.01)
        write_file(file_path, content, dedent=dedent, encoding=encoding)
397

398

399
def py_parse_code(code):
    """
    Compiles code far enough to get errors from the parser and post-parse stage.

    Is useful for checking for syntax errors, however it doesn't generate runnable
    code.

    Returns the tree produced by ``TreeFragment(...).substitute()``.  If any
    error is recorded, the first one is raised; when it is an
    ``Errors.CompileError`` it is converted into a ``SyntaxError`` carrying
    the compiler's message, with the original error chained as its cause.
    """
    context = StringParseContext("test")
    # all the errors we care about are in the parsing or postparse stage
    try:
        with Errors.local_errors() as errors:
            result = TreeFragment(code, pipeline=[PostParse(context)])
            result = result.substitute()
        if errors:
            raise errors[0]  # compile error, which should get caught below
        return result
    except Errors.CompileError as e:
        # Chain the cause so the original compiler error stays visible in tracebacks.
        raise SyntaxError(e.message_only) from e
418

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.