8
from functools import partial
10
from .Compiler import Errors
11
from .CodeWriter import CodeWriter
12
from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext
13
from .Compiler.Visitor import TreeVisitor, VisitorTransform
14
from .Compiler import TreePath
15
from .Compiler.ParseTreeTransforms import PostParse
18
class NodeTypeWriter(TreeVisitor):
24
def visit_Node(self, node):
25
if not self.access_path:
28
tip = self.access_path[-1]
29
if tip[2] is not None:
30
name = "%s[%d]" % tip[1:3]
34
self.result.append(" " * self._indents +
35
"%s: %s" % (name, node.__class__.__name__))
37
self.visitchildren(node)
42
"""Returns a string representing the tree by class names.
43
There's a leading and trailing whitespace so that it can be
44
compared by simple string comparison while still making test
48
return "\n".join([""] + w.result + [""])
51
class CythonTest(unittest.TestCase):
59
def assertLines(self, expected, result):
60
"Checks that the given strings or lists of strings are equal line by line"
61
if not isinstance(expected, list):
62
expected = expected.split("\n")
63
if not isinstance(result, list):
64
result = result.split("\n")
65
for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
66
self.assertEqual(expected_line, result_line,
67
"Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
68
self.assertEqual(len(expected), len(result),
69
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), "\n".join(result)))
71
def codeToLines(self, tree):
74
return writer.result.lines
76
def codeToString(self, tree):
77
return "\n".join(self.codeToLines(tree))
79
def assertCode(self, expected, result_tree):
80
result_lines = self.codeToLines(result_tree)
82
expected_lines = strip_common_indent(expected.split("\n"))
84
for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
85
self.assertEqual(expected_line, line,
86
"Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
87
self.assertEqual(len(result_lines), len(expected_lines),
88
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
90
def assertNodeExists(self, path, result_tree):
91
self.assertNotEqual(TreePath.find_first(result_tree, path), None,
92
"Path '%s' not found in result tree" % path)
94
def fragment(self, code, pxds=None, pipeline=None):
95
"Simply create a tree fragment using the name of the test-case in parse errors."
101
if name.startswith("__main__."):
102
name = name[len("__main__."):]
103
name = name.replace(".", "_")
104
return TreeFragment(code, name, pxds, pipeline=pipeline)
106
def treetypes(self, root):
107
return treetypes(root)
109
def should_fail(self, func, exc_type=Exception):
110
"""Calls "func" and fails if it doesn't raise the right exception
111
(any exception by default). Also returns the exception in question.
115
self.fail("Expected an exception of type %r" % exc_type)
116
except exc_type as e:
117
self.assertTrue(isinstance(e, exc_type))
120
def should_not_fail(self, func):
121
"""Calls func and succeeds if and only if no exception is raised
122
(i.e. converts exception raising into a failed testcase). Returns
123
the return value of func."""
126
except Exception as exc:
130
class TransformTest(CythonTest):
132
Utility base class for transform unit tests. It is based around constructing
133
test trees (either explicitly or by parsing a Cython code string); running
134
the transform, serialize it using a customized Cython serializer (with
135
special markup for nodes that cannot be represented in Cython),
136
and do a string-comparison line-by-line of the result.
138
To create a test case:
139
- Call run_pipeline. The pipeline should at least contain the transform you
140
are testing; pyx should be either a string (passed to the parser to
141
create a post-parse tree) or a node representing input to pipeline.
142
The result will be a transformed result.
144
- Check that the tree is correct. If wanted, assertCode can be used, which
145
takes a code string as expected, and a ModuleNode in result_tree
146
(it serializes the ModuleNode to a string and compares line-by-line).
148
All code strings are first stripped for whitespace lines and then common
151
Plans: One could have a pxd dictionary parameter to run_pipeline.
154
def run_pipeline(self, pipeline, pyx, pxds=None):
157
tree = self.fragment(pyx, pxds).root
168
_strip_c_comments = partial(re.compile(
169
re.sub(r'\s+', '', r'''
171
(?: [^*\n] | [*][^/] )*
173
(?: [^*] | [*][^/] )*
178
_strip_cython_code_from_html = partial(re.compile(
179
re.sub(r'\s\s+', '', r'''
181
<pre class=["'][^"']*cython\s+line[^"']*["']\s*>
186
(?:[^<]|<(?!/style))+
193
def _parse_pattern(pattern):
195
if pattern.startswith('/'):
196
start, pattern = re.split(r"(?<!\\)/", pattern[1:], maxsplit=1)
197
pattern = pattern.strip()
198
if pattern.startswith(':'):
199
pattern = pattern[1:].strip()
200
if pattern.startswith("/"):
201
end, pattern = re.split(r"(?<!\\)/", pattern[1:], maxsplit=1)
202
pattern = pattern.strip()
203
return start, end, pattern
206
class TreeAssertVisitor(VisitorTransform):
212
self._module_pos = None
213
self._c_patterns = []
214
self._c_antipatterns = []
216
def create_c_file_validator(self):
217
patterns, antipatterns = self._c_patterns, self._c_antipatterns
219
def fail(pos, pattern, found, file_path):
220
Errors.error(pos, "Pattern '%s' %s found in %s" %(
222
'was' if found else 'was not',
226
def extract_section(file_path, content, start, end):
228
split = re.search(start, content)
230
content = content[split.end():]
232
fail(self._module_pos, start, found=False, file_path=file_path)
234
split = re.search(end, content)
236
content = content[:split.start()]
238
fail(self._module_pos, end, found=False, file_path=file_path)
241
def validate_file_content(file_path, content):
242
for pattern in patterns:
244
start, end, pattern = _parse_pattern(pattern)
245
section = extract_section(file_path, content, start, end)
246
if not re.search(pattern, section):
247
fail(self._module_pos, pattern, found=False, file_path=file_path)
249
for antipattern in antipatterns:
251
start, end, antipattern = _parse_pattern(antipattern)
252
section = extract_section(file_path, content, start, end)
253
if re.search(antipattern, section):
254
fail(self._module_pos, antipattern, found=True, file_path=file_path)
256
def validate_c_file(result):
257
c_file = result.c_file
258
if not (patterns or antipatterns):
262
with open(c_file, encoding='utf8') as f:
264
content = _strip_c_comments(content)
265
validate_file_content(c_file, content)
267
html_file = os.path.splitext(c_file)[0] + ".html"
268
if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file):
269
with open(html_file, encoding='utf8') as f:
271
content = _strip_cython_code_from_html(content)
272
validate_file_content(html_file, content)
274
return validate_c_file
276
def _check_directives(self, node):
277
directives = node.directives
278
if 'test_assert_path_exists' in directives:
279
for path in directives['test_assert_path_exists']:
280
if TreePath.find_first(node, path) is None:
283
"Expected path '%s' not found in result tree" % path)
284
if 'test_fail_if_path_exists' in directives:
285
for path in directives['test_fail_if_path_exists']:
286
first_node = TreePath.find_first(node, path)
287
if first_node is not None:
290
"Unexpected path '%s' found in result tree" % path)
291
if 'test_assert_c_code_has' in directives:
292
self._c_patterns.extend(directives['test_assert_c_code_has'])
293
if 'test_fail_if_c_code_has' in directives:
294
self._c_antipatterns.extend(directives['test_fail_if_c_code_has'])
296
def visit_ModuleNode(self, node):
297
self._module_pos = node.pos
298
self._check_directives(node)
299
self.visitchildren(node)
302
def visit_CompilerDirectivesNode(self, node):
303
self._check_directives(node)
304
self.visitchildren(node)
307
visit_Node = VisitorTransform.recurse_to_children
310
def unpack_source_tree(tree_file, workdir, cython_root):
312
'PYTHON': [sys.executable],
313
'CYTHON': [sys.executable, os.path.join(cython_root, 'cython.py')],
314
'CYTHONIZE': [sys.executable, os.path.join(cython_root, 'cythonize.py')]
318
workdir = tempfile.mkdtemp()
319
header, cur_file = [], None
320
with open(tree_file, 'rb') as f:
323
if line[:5] == b'#####':
324
filename = line.strip().strip(b'#').strip().decode('utf8').replace('/', os.path.sep)
325
path = os.path.join(workdir, filename)
326
if not os.path.exists(os.path.dirname(path)):
327
os.makedirs(os.path.dirname(path))
328
if cur_file is not None:
329
to_close, cur_file = cur_file, None
331
cur_file = open(path, 'wb')
332
elif cur_file is not None:
334
elif line.strip() and not line.lstrip().startswith(b'#'):
335
if line.strip() not in (b'"""', b"'''"):
336
command = shlex.split(line.decode('utf8'))
337
if not command: continue
339
prog, args = command[0], command[1:]
341
header.append(programs[prog]+args)
343
header.append(command)
345
if cur_file is not None:
347
return workdir, header
350
def write_file(file_path, content, dedent=False, encoding=None):
351
r"""Write some content (text or bytes) to the file
352
at `file_path` without translating `'\n'` into `os.linesep`.
354
The default encoding is `'utf-8'`.
356
if isinstance(content, bytes):
361
default_encoding = None
368
default_encoding = "utf-8"
371
encoding = default_encoding
374
content = textwrap.dedent(content)
376
with open(file_path, mode=mode, encoding=encoding, newline=newline) as f:
380
def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None):
382
Write `content` to the file `file_path` without translating `'\n'`
383
into `os.linesep` and make sure it is newer than the file `newer_than`.
385
The default encoding is `'utf-8'` (same as for `write_file`).
387
write_file(file_path, content, dedent=dedent, encoding=encoding)
390
other_time = os.path.getmtime(newer_than)
395
while other_time is None or other_time >= os.path.getmtime(file_path):
396
write_file(file_path, content, dedent=dedent, encoding=encoding)
399
def py_parse_code(code):
401
Compiles code far enough to get errors from the parser and post-parse stage.
403
Is useful for checking for syntax errors, however it doesn't generate runable
406
context = StringParseContext("test")
409
with Errors.local_errors() as errors:
410
result = TreeFragment(code, pipeline=[PostParse(context)])
411
result = result.substitute()
416
except Errors.CompileError as e:
417
raise SyntaxError(e.message_only)