from functools import reduce

from .Errors import error, message
from . import ExprNodes
from . import Nodes
from . import Builtin
from . import PyrexTypes
from .. import Utils
from .PyrexTypes import py_object_type, unspecified_type
from .Visitor import CythonTransform, EnvTransform

class TypedExprNode(ExprNodes.ExprNode):
    """Placeholder expression node that carries nothing but a result type.

    Used for declaring assignments of a specified type without a known
    entry, e.g. the value bound by ``except ... as name``, the names bound
    by ``from ... import``, or a function's ``*args`` / ``**kwargs``.
    """
    # The node has no sub-expressions.
    subexprs = []

    def __init__(self, type, pos=None):
        super().__init__(pos, type=type)
# Shared placeholder value for assignments of an arbitrary Python object
# (used for except-clause targets and from-import targets).
object_expr = TypedExprNode(type=py_object_type)
class MarkParallelAssignments(EnvTransform):
    """Collect assignments made inside parallel blocks (prange, with parallel).

    Each assignment to a name inside a parallel block is recorded on the
    innermost enclosing block node: ``assignments`` maps the entry to the
    position of its first assignment plus the in-place operator of the
    latest one, and ``assigned_nodes`` lists the assigned name nodes.
    Errors are reported for inconsistent reduction operators, illegally
    nested parallel blocks and yield expressions inside parallel sections.

    Perhaps it's better to move this into ControlFlowAnalysis.

    Every ``visit_*`` method returns the node it was given, because the
    visitor's return value replaces the visited node in the tree.
    """

    # tells us whether we're in a normal loop
    in_loop = False

    # Set after "Only prange() may be nested" was reported, to suppress
    # repeated reports while the offending block's children are visited.
    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange());
        # the innermost block is at the end of the list.
        self.parallel_block_stack = []
        super().__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        """Record an assignment of ``rhs`` to ``lhs``.

        ``lhs`` may be a name, an argument declaration or a sequence
        (unpacking) target, in which case every item is marked recursively.
        ``rhs`` may be None when the assigned value is unknown.
        ``inplace_op`` is the operator of an in-place assignment, or None
        for a plain assignment.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # e.g. "x += y" followed by "x *= y"
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for i, arg in enumerate(lhs.args):
                if not rhs or arg.is_starred:
                    # Unknown value, or a starred target collecting many items.
                    item_node = None
                else:
                    item_node = rhs.inferable_item_node(i)
                self.mark_assignment(arg, item_node)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.with_node.enter_call)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
        """Record the assignments implied by a for-in loop target.

        ``reversed(x)`` is looked through. ``enumerate(x)`` with a two-item
        target, where ``x`` is a name of builtin type, assigns a Py_ssize_t
        counter to the first item. For a (non-shadowed) builtin ``range`` /
        ``xrange``, the target is assigned the start/stop arguments and
        start + step. Any other loop is modelled as an item lookup on the
        sequence.
        """
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        iterator_scope = node.iterator.expr_scope or self.current_env()
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = iterator_scope.lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(iterator_scope)
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = iterator_scope.lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        if len(sequence.args) > 2:
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        # The target takes the first bound and, with a step, bound1 + step.
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                                 ExprNodes.binop_node(node.pos,
                                                      '+',
                                                      node.bound1,
                                                      node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        return node  # Can't be assigned to...

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            # "from x import *" binds no explicit target.
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type, node.pos))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type, node.pos))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        """Track a parallel block while its children are visited.

        Sets ``node.parent`` (the enclosing parallel block or None) and
        ``node.is_parallel``, reports illegal nesting, and keeps the block on
        ``parallel_block_stack`` while its body is visited. For a prange, the
        else clause is visited after the block has been popped, i.e. outside
        the parallel section.
        """
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            self.visitchildren(node, attrs=('body', 'target', 'args'))

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "'%s' not allowed in parallel sections" % node.expr_keyword)
        return node

    def visit_ReturnStatNode(self, node):
        node.in_parallel = bool(self.parallel_block_stack)
        return node
class MarkOverflowingArithmetic(CythonTransform):
    """Flag variables that take part in arithmetic that may overflow.

    Sets ``entry.might_overflow`` on names used inside potentially
    overflowing expressions (most binary operators, unary minus, abs(),
    in-place assignments) and on names assigned long integer literals.
    The spanning-type functions receive this flag as ``might_overflow``.

    Every ``visit_*`` method returns the node it was given, because the
    visitor's return value replaces the visited node in the tree.
    """

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        # Enclosing scopes saved while visiting nested functions.
        self.env_stack = []
        self.env = root.scope
        return super().__call__(root)

    def visit_safe_node(self, node):
        # Children of this node are not affected by an outer overflow risk.
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        # Inherit the current overflow state unchanged.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Everything below this node may take part in an overflow.
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            # The name may not resolve to an entry (lookup can fail).
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        # Bitwise operators cannot overflow; all other binary operators might.
        if node.operator in '&|^':
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    def visit_SimpleCallNode(self, node):
        if node.function.is_name and node.function.name == 'abs':
            # Overflows for minimum value of fixed size ints.
            return self.visit_dangerous_node(node)
        else:
            return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        """Flag a name assigned an integer literal that Utils.long_literal() reports as long."""
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node
class PyObjectTypeInferer:
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a dict of entries, map all unspecified types to a specified type.

        Every entry of ``scope.entries`` whose type is still unspecified
        becomes a Python object; declared types are left untouched.
        """
        for entry in scope.entries.values():
            if entry.type is unspecified_type:
                entry.type = py_object_type
class SimpleAssignmentTypeInferer:
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type, scope):
        """Set ``entry_type`` on every entry in ``entry.all_entries()``.

        Memoryview slice entries get the type's default value as their
        initial value. C++ class entries are made optional when the
        ``cpp_locals`` directive is enabled; otherwise the type is checked
        for a nullary constructor.
        """
        for e in entry.all_entries():
            e.type = entry_type
            if e.type.is_memoryviewslice:
                # memoryview slices crash if they don't get initialized
                e.init = e.type.default_value
            if e.type.is_cpp_class:
                if scope.directives['cpp_locals']:
                    e.make_cpp_optional()
                else:
                    e.type.check_nullary_constructor(entry.pos)

    def infer_types(self, scope):
        """Replace every unspecified entry type in ``scope``.

        The ``infer_types`` directive selects the mode: True uses
        aggressive_spanning_type(), None ("safe mode") uses
        safe_spanning_type(), and any other value makes every unspecified
        entry a Python object. With ``infer_types.verbose`` a message is
        emitted for each entry whose type was inferred.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        # The directive is tri-state (True / None / other), hence "== True".
        if enabled == True:  # noqa: E712
            spanning_type = aggressive_spanning_type
        elif enabled is None:  # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type, scope)
            return

        # Set of assignments whose type still has to be inferred.
        assignments = set()
        # Assignments whose type is already known.
        assmts_resolved = set()
        # assignment -> set of assignments its RHS type depends on
        dependencies = {}
        # assignment -> name nodes its RHS type depends on
        assmt_to_names = {}

        for entry in scope.entries.values():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                # Declared entries need no inference.
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            # Span the types of all assignments that reach this name node.
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, scope)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            # Like infer_name_node_type(), but only over assignments that
            # are already inferred; returns None if there are none yet.
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return None
            entry = node.entry
            return spanning_type(types, entry.might_overflow, scope)

        def inferred_types(entry):
            # Collect the inferred RHS types of all assignments to ``entry``.
            has_none = False
            has_pyobjects = False
            types = []
            for assmt in entry.cf_assignments:
                if assmt.rhs.is_none:
                    has_none = True
                else:
                    rhs_type = assmt.inferred_type
                    if rhs_type and rhs_type.is_pyobject:
                        has_pyobjects = True
                    types.append(rhs_type)
            # Ignore None assignments as long as there are concrete Python type assignments.
            # but include them if None is the only assigned Python object.
            if has_none and not has_pyobjects:
                types.append(py_object_type)
            return types

        def resolve_assignments(assignments):
            # Infer every assignment whose dependencies are all resolved;
            # return the set of newly resolved assignments.
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All assignments are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Infer ``assmt`` from the partially known types of its
            # dependencies; return False if some dependency has none yet.
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()

        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments until neither full nor partial resolution
        # makes progress.
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break

        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = inferred_types(entry)
                if types and all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, scope)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type, scope)

        def reinfer():
            # Re-infer all inferred entries; return True if any type changed.
            dirty = False
            for entry in inferred:
                for assmt in entry.cf_assignments:
                    assmt.infer_type()
                types = inferred_types(entry)
                new_type = spanning_type(types, entry.might_overflow, scope)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type, scope)
                    dirty = True
            return dirty

        # types propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))
def find_spanning_type(type1, type2):
    """Return a type able to hold values of both ``type1`` and ``type2``.

    Identical types span to themselves. ``bint`` combined with any other
    type spans to a Python object. A C float, C double or Python float
    result is widened to C double.
    """
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)
    if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                       Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type
def simply_type(result_type):
    """Normalise a spanned type into one usable for a variable.

    Removes cv-qualifiers and references (including fake references) via
    PyrexTypes.remove_cv_ref(), and turns a C array type into a pointer to
    its base type.
    """
    result_type = PyrexTypes.remove_cv_ref(result_type, remove_fakeref=True)
    if result_type.is_array:
        result_type = PyrexTypes.c_ptr_type(result_type.base_type)
    return result_type
def aggressive_spanning_type(types, might_overflow, scope):
    """Span all ``types`` and normalise the result, with no safety checks.

    ``might_overflow`` and ``scope`` are accepted so the signature matches
    safe_spanning_type(); they are not used here.
    """
    spanned = reduce(find_spanning_type, types)
    return simply_type(spanned)
def safe_spanning_type(types, might_overflow, scope):
    """Span ``types``, but keep the result only when inferring it is safe.

    Returns the spanned (and normalised) type for Python object types,
    C float/double, bint, pythran expressions, pointers, C++ classes,
    structs, memoryview slices, complex types, non-overflowing ints/enums,
    and types that cannot coerce to a Python object. In every other case
    it falls back to a generic Python object.
    """
    result_type = simply_type(reduce(find_spanning_type, types))
    if result_type.is_pyobject:
        return result_type
    elif (result_type is PyrexTypes.c_double_type or
            result_type is PyrexTypes.c_float_type):
        # Python's float type is just a C double, so it's safe to use
        # the C type instead. Similarly if given a C float, it leads to
        # a small loss of precision vs Python but is otherwise the same
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_pythran_expr:
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    elif result_type.is_memoryviewslice:
        return result_type
    elif result_type is PyrexTypes.soft_complex_type:
        return result_type
    elif result_type == PyrexTypes.c_double_complex_type:
        return result_type
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        # C ints are only safe if no variable involved may overflow.
        return result_type
    elif (not result_type.can_coerce_to_pyobject(scope)
            and not result_type.is_error):
        return result_type
    return py_object_type
def get_type_inferer():
    """Factory returning a fresh SimpleAssignmentTypeInferer."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer