TypeInference.py

from .Errors import error, message
from . import ExprNodes
from . import Nodes
from . import Builtin
from . import PyrexTypes
from .. import Utils
from .PyrexTypes import py_object_type, unspecified_type
from .Visitor import CythonTransform, EnvTransform

from functools import reduce


class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known entry.
    subexprs = []

    def __init__(self, type, pos=None):
        super().__init__(pos, type=type)

object_expr = TypedExprNode(py_object_type)


class MarkParallelAssignments(EnvTransform):
    # Collects assignments inside parallel blocks prange, with parallel.
    # Perhaps it's better to move it to ControlFlowAnalysis.
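
    # Added note (not in the original source): mark_assignment() below records
    # one entry per enclosing parallel construct,
    #     parallel_node.assignments[entry] = (pos, inplace_op),
    # e.g. "s += x[i]" in a prange body is stored with inplace_op '+'.  The
    # parallel node presumably uses this later to decide which variables are
    # private and which are reductions.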

    # tells us whether we're in a normal loop
    in_loop = False

    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super().__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for i, arg in enumerate(lhs.args):
                if not rhs or arg.is_starred:
                    item_node = None
                else:
                    item_node = rhs.inferable_item_node(i)
                self.mark_assignment(arg, item_node)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.with_node.enter_call)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node
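
    # Added note (not in the original source): the following visitor
    # special-cases enumerate() and range()/xrange().  For
    # "for i in range(a, b, c)" the target is marked as assigned from a, b and
    # "a + c", so its inferred type spans the argument types; for any other
    # iterable the target is marked as an IndexNode lookup on the sequence,
    # which lets pointer, C array and string bases contribute an element type.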
    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        iterator_scope = node.iterator.expr_scope or self.current_env()
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = iterator_scope.lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(iterator_scope)
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = iterator_scope.lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        if len(sequence.args) > 2:
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))
        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        return node  # Can't be assigned to...

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type, node.pos))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type, node.pos))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node
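
    # Added note (not in the original source): the parallel-construct visitor
    # below tracks nesting via parallel_block_stack.  Nesting anything other
    # than prange() inside another parallel construct raises
    # "Only prange() may be nested"; yield is rejected inside parallel
    # sections, and return statements are flagged with in_parallel.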
    def visit_ParallelStatNode(self, node):
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            self.visitchildren(node, attrs=('body', 'target', 'args'))

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "'%s' not allowed in parallel sections" % node.expr_keyword)
        return node

    def visit_ReturnStatNode(self, node):
        node.in_parallel = bool(self.parallel_block_stack)
        return node


class MarkOverflowingArithmetic(CythonTransform):

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False
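
    # Added illustration (not in the original source): nodes are classified as
    # "safe" (resets might_overflow), "neutral" (propagates it) or "dangerous"
    # (sets it).  In "c = a + b" the '+' BinopNode is dangerous, so the name
    # nodes for a and b are visited with might_overflow=True and their entries
    # get flagged; in "c = a | b" the bitwise operator is neutral, so nothing
    # is flagged unless an enclosing node was already dangerous.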

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super().__call__(root)

    def visit_safe_node(self, node):
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        if node.operator in '&|^':
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    def visit_SimpleCallNode(self, node):
        if node.function.is_name and node.function.name == 'abs':
            # Overflows for minimum value of fixed size ints.
            return self.visit_dangerous_node(node)
        else:
            return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

class PyObjectTypeInferer:
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a dict of entries, map all unspecified types to a specified type.
        """
        for name, entry in scope.entries.items():
            if entry.type is unspecified_type:
                entry.type = py_object_type

class SimpleAssignmentTypeInferer:
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type, scope):
        for e in entry.all_entries():
            e.type = entry_type
            if e.type.is_memoryviewslice:
                # memoryview slices crash if they don't get initialized
                e.init = e.type.default_value
            if e.type.is_cpp_class:
                if scope.directives['cpp_locals']:
                    e.make_cpp_optional()
                else:
                    e.type.check_nullary_constructor(entry.pos)

    def infer_types(self, scope):
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None:  # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type, scope)
            return

        # Set of assignments
        assignments = set()
        assmts_resolved = set()
        dependencies = {}
        assmt_to_names = {}

        for name, entry in scope.entries.items():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, scope)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, scope)

        def inferred_types(entry):
            has_none = False
            has_pyobjects = False
            types = []
            for assmt in entry.cf_assignments:
                if assmt.rhs.is_none:
                    has_none = True
                else:
                    rhs_type = assmt.inferred_type
                    if rhs_type and rhs_type.is_pyobject:
                        has_pyobjects = True
                    types.append(rhs_type)
            # Ignore None assignments as long as there are concrete Python type assignments,
            # but include them if None is the only assigned Python object.
            if has_none and not has_pyobjects:
                types.append(py_object_type)
            return types

        def resolve_assignments(assignments):
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All dependencies of this assignment are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    inferred_type = assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = inferred_types(entry)
                if types and all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, scope)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type, scope)

        def reinfer():
            dirty = False
            for entry in inferred:
                for assmt in entry.cf_assignments:
                    assmt.infer_type()
                types = inferred_types(entry)
                new_type = spanning_type(types, entry.might_overflow, scope)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type, scope)
                    dirty = True
            return dirty

        # type propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))
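

# Added illustration (not in the original source): find_spanning_type() widens
# pairwise.  Spanning a C int with a C double yields c_double_type (Python's
# float is just a C double, so the C type is used directly), while any pairing
# that involves c_bint_type falls back to py_object_type so that coercion back
# to a Python bool is not broken by an arbitrary integer result.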
def find_spanning_type(type1, type2):
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)
    if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                       Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type

def simply_type(result_type):
    result_type = PyrexTypes.remove_cv_ref(result_type, remove_fakeref=True)
    if result_type.is_array:
        result_type = PyrexTypes.c_ptr_type(result_type.base_type)
    return result_type

def aggressive_spanning_type(types, might_overflow, scope):
    return simply_type(reduce(find_spanning_type, types))
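
# Added note (not in the original source): safe_spanning_type() keeps the
# spanned C type only when later coercion to a Python object is either safe or
# impossible (pointers, C++ classes, memoryview slices, non-overflowing ints,
# etc.); everything else falls back to py_object_type.  aggressive_spanning_type()
# above skips that safety check and always returns the spanned type.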
def safe_spanning_type(types, might_overflow, scope):
    result_type = simply_type(reduce(find_spanning_type, types))
    if result_type.is_pyobject:
        return result_type
    elif (result_type is PyrexTypes.c_double_type or
            result_type is PyrexTypes.c_float_type):
        # Python's float type is just a C double, so it's safe to use
        # the C type instead. Similarly if given a C float, it leads to
        # a small loss of precision vs Python but is otherwise the same
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_pythran_expr:
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    elif result_type.is_memoryviewslice:
        return result_type
    elif result_type is PyrexTypes.soft_complex_type:
        return result_type
    elif result_type == PyrexTypes.c_double_complex_type:
        return result_type
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    elif (not result_type.can_coerce_to_pyobject(scope)
            and not result_type.is_error):
        return result_type
    return py_object_type


def get_type_inferer():
    return SimpleAssignmentTypeInferer()
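

# Usage note (added, not in the original source): callers elsewhere in the
# compiler are expected to obtain an inferer via get_type_inferer() and run
# infer_types(scope) once per scope, with outer scopes processed before nested
# ones, as the SimpleAssignmentTypeInferer docstring above requires for
# cross-closure inference.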