# cython · Forks: 0 · Optimize.py · 5183 lines · 219.5 KB
1
import string
2
import cython
3
cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object,
4
               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
5
               UtilNodes=object, _py_int_types=object,
6
               re=object, copy=object, codecs=object, itertools=object, attrgetter=object)
7

8

9
import re
10
import copy
11
import codecs
12
import itertools
13
from operator import attrgetter
14

15
from . import TypeSlots
16
from .ExprNodes import UnicodeNode, not_a_constant
17

18
_py_string_types = (bytes, str)
19

20

21
from . import Nodes
22
from . import ExprNodes
23
from . import PyrexTypes
24
from . import Visitor
25
from . import Builtin
26
from . import UtilNodes
27
from . import Options
28

29
from .Code import UtilityCode, TempitaUtilityCode
30
from .StringEncoding import EncodedString, bytes_literal, encoded_string
31
from .Errors import error, warning
32
from .ParseTreeTransforms import SkipDeclarations
33
from .. import Utils
34

35
from functools import reduce
36

37

38
def load_c_utility(name):
    """Return the cached utility code called ``name`` from "Optimize.c".

    Thin convenience wrapper around ``UtilityCode.load_cached`` that pins
    the source file to "Optimize.c".
    """
    return UtilityCode.load_cached(name, "Optimize.c")
40

41

42
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from ``node``, if there is one.

    :param node: the expression node to inspect.
    :param coercion_nodes: tuple of node classes that count as coercion
        wrappers; defaults to the to-/from-Python coercion nodes.
    :return: ``node.arg`` when ``node`` is an instance of one of
        ``coercion_nodes``, otherwise ``node`` itself.  Only one layer is
        removed.
    """
    if isinstance(node, coercion_nodes):
        return node.arg
    return node
46

47

48
def unwrap_node(node):
    """Follow ``ResultRefNode`` wrappers down to the wrapped expression.

    Repeatedly replaces ``node`` by ``node.expression`` for as long as it is
    a ``UtilNodes.ResultRefNode`` and returns the first node that is not.
    """
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node
52

53

54
def is_common_value(a, b):
    """Tell whether two expression nodes syntactically name the same value.

    Both nodes are first stripped of ``ResultRefNode`` wrappers.  They are
    considered common when

    - both are ``NameNode`` instances with equal ``name``, or
    - both are ``AttributeNode`` instances, ``a`` is not a Python attribute
      lookup (``a.is_py_attr`` is false), their ``obj`` operands are
      themselves common values, and their ``attribute`` names are equal.

    Any other combination yields ``False``.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False
62

63

64
def filter_none_node(node):
    """Map a node whose constant result is ``None`` to ``None``.

    :param node: an expression node (anything with a ``constant_result``
        attribute) or ``None``.
    :return: ``None`` if ``node`` is ``None`` or ``node.constant_result`` is
        ``None``; otherwise ``node`` unchanged.
    """
    if node is not None and node.constant_result is None:
        return None
    return node
68

69

70
class _YieldNodeCollector(Visitor.TreeVisitor):
    """
    YieldExprNode finder for generator expressions.

    After a traversal:

    - ``yield_nodes`` lists every visited ``YieldExprNode`` in visiting order;
    - ``yield_stat_nodes`` maps a yield node to the ``ExprStatNode`` whose
      ``expr`` it is, for yields that are used as plain statements.

    Nested generator expressions, lambdas and function definitions are not
    descended into.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_stat_nodes = {}
        self.yield_nodes = []

    # Default handler: simply recurse into the children of any other node.
    visit_Node = Visitor.TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        # Record the yield, then keep searching inside its argument.
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def visit_ExprStatNode(self, node):
        # Visit the children first so that a directly contained yield has
        # already been recorded in 'yield_nodes' when we look it up.
        self.visitchildren(node)
        if node.expr in self.yield_nodes:
            self.yield_stat_nodes[node.expr] = node

    # everything below these nodes is out of scope:

    def visit_GeneratorExpressionNode(self, node):
        pass

    def visit_LambdaNode(self, node):
        pass

    def visit_FuncDefNode(self, node):
        pass
100

101

102
def _find_single_yield_expression(node):
    """Return the only yield statement found below ``node``.

    :return: the ``(yield argument, ExprStatNode)`` pair reported by
        ``_find_yield_statements`` when it finds exactly one such pair,
        otherwise ``(None, None)``.
    """
    yield_statements = _find_yield_statements(node)
    if len(yield_statements) != 1:
        return None, None
    return yield_statements[0]
107

108

109
def _find_yield_statements(node):
    """Collect the statement-level yields found below ``node``.

    :return: a list of ``(yield_node.arg, ExprStatNode)`` pairs, one per
        collected yield node.  If any collected yield is not the expression
        of an ``ExprStatNode``, the whole result is an empty list.
    """
    collector = _YieldNodeCollector()
    collector.visitchildren(node)
    try:
        yield_statements = [
            (yield_node.arg, collector.yield_stat_nodes[yield_node])
            for yield_node in collector.yield_nodes
        ]
    except KeyError:
        # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
        yield_statements = []
    return yield_statements
121

122

123
class IterationTransform(Visitor.EnvTransform):
124
    """Transform some common for-in loop patterns into efficient C loops:
125

126
    - for-in-dict loop becomes a while loop calling PyDict_Next()
127
    - for-in-enumerate is replaced by an external counter variable
128
    - for-in-range loop becomes a plain C for loop
129
    """
130
    def visit_PrimaryCmpNode(self, node):
        """Expand an 'in' / 'not_in' test that ``node.is_ptr_contains()`` accepts.

        Comparisons for which ``node.is_ptr_contains()`` is false only have
        their children visited and are returned unchanged.  Otherwise the
        comparison is replaced by a loop equivalent to::

            for t in operand2:
                if operand1 == t:
                    res = True
                    break
            else:
                res = False

        The generated loop is analysed and run through this transform again,
        and its result is wrapped in a ``NotNode`` for the 'not_in' operator.
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        pos = node.pos
        result_ref = UtilNodes.ResultRefNode(node)

        # The loop variable gets the element type of the searched container.
        if node.operand2.is_subscript:
            base_type = node.operand2.base.type.base_type
        else:
            base_type = node.operand2.type.base_type
        target_handle = UtilNodes.TempHandle(base_type)
        target = target_handle.ref(pos)

        # if operand1 == t: res = True; break
        cmp_node = ExprNodes.PrimaryCmpNode(
            pos, operator='==', operand1=node.operand1, operand2=target)
        if_body = Nodes.StatListNode(
            pos,
            stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                     Nodes.BreakStatNode(pos)])
        if_node = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
            else_clause=None)

        # for t in operand2: <if_node>  else: res = False
        for_loop = UtilNodes.TempsBlockNode(
            pos,
            temps = [target_handle],
            body = Nodes.ForInStatNode(
                pos,
                target=target,
                iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                body=if_node,
                else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
        for_loop = for_loop.analyse_expressions(self.current_env())
        # Re-visit so that the new for-in loop is itself optimised.
        for_loop = self.visit(for_loop)
        new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

        if node.operator == 'not_in':
            new_node = ExprNodes.NotNode(pos, operand=new_node)
        return new_node
178

179
    def visit_ForInStatNode(self, node):
        """Visit the loop's children, then try to optimise the loop itself.

        Returns whatever ``_optimise_for_loop`` produces for the loop and its
        iterator's ``sequence``.
        """
        self.visitchildren(node)
        return self._optimise_for_loop(node, node.iterator.sequence)
182

183
    def _optimise_for_loop(self, node, iterable, reversed=False):
        """Dispatch a for-in loop to the matching specialised transformation.

        :param node: the for-in loop statement node.
        :param iterable: the expression node that is being iterated over.
        :param reversed: true when the iteration runs over
            ``reversed(iterable)``.
        :return: the replacement node built by one of the ``_transform_*``
            helpers, or ``node`` unchanged if no optimisation applies.

        Checks are made in this order: dict, set/frozenset, C pointer/array,
        bytes, unicode, bytearray, coerced memoryview slice, and finally
        simple calls (dict methods, enumerate(), reversed(), range()/xrange()).
        """
        annotation_type = None
        if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
            annotation = iterable.entry.annotation.expr
            if annotation.is_subscript:
                annotation = annotation.base  # container base type
            # NOTE(review): 'annotation' is not used after this point and
            # 'annotation_type' is never assigned, so the checks below
            # effectively only look at 'iterable.type'.

        if Builtin.dict_type in (iterable.type, annotation_type):
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterable, method=None, keys=True, values=False)

        if (Builtin.set_type in (iterable.type, annotation_type) or
                Builtin.frozenset_type in (iterable.type, annotation_type)):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_set_iteration(node, iterable)

        # C array (slice) iteration?
        if iterable.type.is_ptr or iterable.type.is_array:
            return self._transform_carray_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.bytes_type:
            return self._transform_bytes_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.unicode_type:
            return self._transform_unicode_iteration(node, iterable, reversed=reversed)
        # in principle _transform_indexable_iteration would work on most of the above, and
        # also tuple and list. However, it probably isn't quite as optimized
        if iterable.type is Builtin.bytearray_type:
            return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed)
        if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice:
            return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterable, ExprNodes.SimpleCallNode):
            return node

        # Count the explicit call arguments (excluding a bound 'self').
        if iterable.args is None:
            arg_count = len(iterable.arg_tuple.args) if iterable.arg_tuple else 0
        else:
            arg_count = len(iterable.args)
            if arg_count and iterable.self is not None:
                arg_count -= 1

        function = iterable.function
        # dict iteration?
        if function.is_attribute and not reversed and not arg_count:
            base_obj = iterable.self or function.obj
            method = function.attribute
            # in Py3, items() is equivalent to Py2's iteritems()
            is_safe_iter = self.global_scope().context.language_level >= 3

            if not is_safe_iter and method in ('keys', 'values', 'items'):
                # try to reduce this to the corresponding .iter*() methods
                if isinstance(base_obj, ExprNodes.CallNode):
                    inner_function = base_obj.function
                    if (inner_function.is_name and inner_function.name == 'dict'
                            and inner_function.entry
                            and inner_function.entry.is_builtin):
                        # e.g. dict(something).items() => safe to use .iter*()
                        is_safe_iter = True

            keys = values = False
            if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_safe_iter and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_safe_iter and method == 'items'):
                keys = values = True

            if keys or values:
                return self._transform_dict_iteration(
                    node, base_obj, method, keys, values)

        # enumerate/reversed ?
        if iterable.self is None and function.is_name and \
               function.entry and function.entry.is_builtin:
            if function.name == 'enumerate':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_enumerate_iteration(node, iterable)
            elif function.name == 'reversed':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_reversed_iteration(node, iterable)

        # range() iteration?
        if Options.convert_range and 1 <= arg_count <= 3 and (
                iterable.self is None and
                function.is_name and function.name in ('range', 'xrange') and
                function.entry and function.entry.is_builtin):
            if node.target.type.is_int or node.target.type.is_enum:
                return self._transform_range_iteration(node, iterable, reversed=reversed)
            if node.target.type.is_pyobject:
                # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
                # The for/else only transforms when *every* argument is an
                # integer literal with a constant value in [-2**30, 2**30).
                for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args):
                    if isinstance(arg, ExprNodes.IntNode):
                        if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30:
                            continue
                    break
                else:
                    return self._transform_range_iteration(node, iterable, reversed=reversed)

        return node
292

293
    def _transform_reversed_iteration(self, node, reversed_function):
        """Optimise ``for ... in reversed(x)``.

        :param node: the for-in loop statement node.
        :param reversed_function: the call node for ``reversed(...)``.
        :return: ``node`` (possibly modified in place) or a replacement node.

        Reports a compile error and returns ``node`` unchanged unless exactly
        one argument is passed.  For tuple and list arguments the loop's
        iterator is pointed at the None-checked argument and flagged as
        reversed; any other argument is handed back to
        ``_optimise_for_loop`` with ``reversed=True``.
        """
        args = reversed_function.arg_tuple.args
        if len(args) == 0:
            error(reversed_function.pos,
                  "reversed() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(reversed_function.pos,
                  "reversed() takes exactly 1 argument")
            return node
        arg = args[0]

        # reversed(list/tuple) ?
        if arg.type in (Builtin.tuple_type, Builtin.list_type):
            node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
            node.iterator.reversed = True
            return node

        return self._optimise_for_loop(node, arg, reversed=True)
312

313
    def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False):
        """In principle can handle any iterable that Cython has a len() for and knows how to index

        :param node: the for-in loop statement node.
        :param slice_node: the expression being iterated over.
        :param is_mutable: if true, ``len()`` is re-evaluated at the end of
            every iteration and assigned to the loop bound.
        :param reversed: iterate from the end towards index 0.
        :return: the analysed replacement: nested ``LetNode``s holding the
            iterable and its length around a counting ``ForFromStatNode``
            that assigns ``iterable[counter]`` to the loop target.
        """
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"),
            may_hold_none=False, is_temp=True
            )

        start_node = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        def make_length_call():
            # helper function since we need to create this node for a couple of places
            builtin_len = ExprNodes.NameNode(node.pos, name="len",
                                             entry=Builtin.builtin_scope.lookup("len"))
            return ExprNodes.SimpleCallNode(node.pos,
                                    function=builtin_len,
                                    args=[unpack_temp_node]
                                    )
        length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True)
        end_node = length_temp

        if reversed:
            # length > counter >= 0
            relation1, relation2 = '>', '>='
            start_node, end_node = end_node, start_node
        else:
            # 0 <= counter < length
            relation1, relation2 = '<=', '<'

        counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type)

        target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node,
                                           index=counter_ref)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        # analyse with boundscheck and wraparound
        # off (because we're confident we know the size)
        env = self.current_env()
        new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False)
        target_assign = Nodes.CompilerDirectivesNode(
            target_assign.pos,
            directives=new_directives,
            body=target_assign,
        )

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign])  # exclude node.body for now to not reanalyse it
        if is_mutable:
            # We need to be slightly careful here that we are actually modifying the loop
            # bounds and not a temp copy of it. Setting is_temp=True on length_temp seems
            # to ensure this.
            # If this starts to fail then we could insert an "if out_of_bounds: break" instead
            loop_length_reassign = Nodes.SingleAssignmentNode(node.pos,
                                                        lhs = length_temp,
                                                        rhs = make_length_call())
            body.stats.append(loop_length_reassign)

        loop_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_node, relation1=relation1,
            target=counter_ref,
            relation2=relation2, bound2=end_node,
            step=None, body=body,
            else_clause=node.else_clause,
            from_range=True)

        ret = UtilNodes.LetNode(
                    unpack_temp_node,
                    UtilNodes.LetNode(
                        length_temp,
                        # TempResultFromStatNode provides the framework where the "counter_ref"
                        # temp is set up and can be assigned to. However, we don't need the
                        # result it returns so wrap it in an ExprStatNode.
                        Nodes.ExprStatNode(node.pos,
                            expr=UtilNodes.TempResultFromStatNode(
                                    counter_ref,
                                    loop_node
                            )
                        )
                    )
                ).analyse_expressions(env)
        # Only now add the original loop body, right after the target
        # assignment (and before the optional length update).
        body.stats.insert(1, node.body)
        return ret
398

399
    # C signature used for the bytes data accessor call:
    # takes a bytes object "s", returns char*, with NULL as exception value.
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ], exception_value="NULL")

    # C signature used for the bytes length call:
    # takes a bytes object "s", returns Py_ssize_t, with -1 as exception value.
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
        ],
        exception_value=-1)
409

410
    def _transform_bytes_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a bytes object into a C array loop.

        :param node: the for-in loop statement node.
        :param slice_node: the bytes expression being iterated over.
        :param reversed: iterate backwards.
        :return: ``node`` unchanged unless the loop target is a C integer or
            typed as bytes; otherwise a ``LetNode`` that holds the
            None-checked bytes object and wraps the result of
            ``_transform_carray_iteration`` applied to the slice
            ``data[:length]`` of its character buffer.
        """
        target_type = node.target.type
        if not target_type.is_int and target_type is not Builtin.bytes_type:
            # bytes iteration returns bytes objects in Py2, but
            # integers in Py3
            return node

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyBytes_AsWritableString",
            self.PyBytes_AS_STRING_func_type,
            args = [unpack_temp_node],
            is_temp = 1,
            # TypeConversions utility code is always included
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyBytes_GET_SIZE",
            self.PyBytes_GET_SIZE_func_type,
            args = [unpack_temp_node],
            is_temp = 1,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    ),
                reversed = reversed))
448

449
    # C signature used for reading one character:
    # (int kind, void* data, Py_ssize_t index) -> Py_UCS4
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

    # C signature used for the iteration setup call:
    # (object s, Py_ssize_t* length, void** data, int* kind) -> int,
    # with -1 as exception value.
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
        exception_value=-1)
464

465
    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a unicode string into an index-based C loop.

        :param node: the for-in loop statement node.
        :param slice_node: the unicode expression being iterated over.
        :param reversed: iterate backwards.
        :return: the replacement node.

        A string literal that can be encoded as Latin-1 is iterated as a
        constant C byte array through ``_transform_carray_iteration``.  All
        other strings get a setup call to "__Pyx_init_unicode_iteration"
        that fills length/data/kind temps, followed by a counting loop that
        reads each character with "__Pyx_PyUnicode_READ".
        """
        if slice_node.is_literal:
            # try to reduce to byte iteration for plain Latin-1 strings
            try:
                bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
            except UnicodeEncodeError:
                pass
            else:
                bytes_slice = ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base=ExprNodes.BytesNode(
                        slice_node.pos, value=bytes_value,
                        constant_result=bytes_value,
                        type=PyrexTypes.c_const_char_ptr_type).coerce_to(
                            PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
                    start=None,
                    stop=ExprNodes.IntNode(
                        slice_node.pos, value=str(len(bytes_value)),
                        constant_result=len(bytes_value),
                        type=PyrexTypes.c_py_ssize_t_type),
                    type=Builtin.unicode_type,  # hint for Python conversion
                )
                return self._transform_carray_iteration(node, bytes_slice, reversed=reversed)

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        start_node = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        end_node = length_temp.ref(node.pos)
        if reversed:
            # length > counter >= 0
            relation1, relation2 = '>', '>='
            start_node, end_node = end_node, start_node
        else:
            # 0 <= counter < length
            relation1, relation2 = '<=', '<'

        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

        # target = __Pyx_PyUnicode_READ(kind, data, counter)
        target_value = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyUnicode_READ",
            self.PyUnicode_READ_func_type,
            args = [kind_temp.ref(slice_node.pos),
                    data_temp.ref(slice_node.pos),
                    counter_temp.ref(node.target.pos)],
            is_temp = False,
            )
        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        loop_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_node, relation1=relation1,
            target=counter_temp.ref(node.target.pos),
            relation2=relation2, bound2=end_node,
            step=None, body=body,
            else_clause=node.else_clause,
            from_range=True)

        # __Pyx_init_unicode_iteration(s, &length, &data, &kind)
        setup_node = Nodes.ExprStatNode(
            node.pos,
            expr = ExprNodes.PythonCapiCallNode(
                slice_node.pos, "__Pyx_init_unicode_iteration",
                self.init_unicode_iteration_func_type,
                args = [unpack_temp_node,
                        ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_py_ssize_t_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_void_ptr_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_int_ptr_type),
                        ],
                is_temp = True,
                result_is_used = False,
                utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
                ))
        return UtilNodes.LetNode(
            unpack_temp_node,
            UtilNodes.TempsBlockNode(
                node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
                body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
556

557
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a C array or pointer slice into a pointer loop.

        :param node: the for-in loop statement node.
        :param slice_node: the iterated expression.  Handled forms are a
            ``SliceIndexNode`` (``base[start:stop]``), a subscript with a
            slice index (``base[start:stop:step]``), and a C array
            expression whose type has a known size.
        :param reversed: iterate backwards.
        :return: a ``TempsBlockNode`` wrapping a ``ForFromStatNode`` whose
            counter is a pointer running between ``base + start`` and
            ``base + stop``, or ``node`` unchanged when the bounds or step
            cannot be determined.  In that case a compile error is reported
            as well, unless the base / iterable is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = filter_none_node(slice_node.start)
            stop = filter_none_node(slice_node.stop)
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif slice_node.is_subscript:
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = filter_none_node(index.start)
            stop = filter_none_node(index.stop)
            step = filter_none_node(index.step)
            if step:
                # The step must be a non-zero integer constant, and the
                # bound that the loop runs towards must be given.
                if not isinstance(step.constant_result, int) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop is None:
            if neg_step:
                # A missing stop with a negative step is replaced by the
                # constant index -1.
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            start, stop = stop, start

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_env())

        # Loop bounds as pointers: base + start and base + stop; the
        # addition is skipped for a missing or constant-zero offset.
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_env())
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes/unicode
            if slice_node.type is Builtin.unicode_type:
                target_value = ExprNodes.CastNode(
                    ExprNodes.DereferenceNode(
                        node.target.pos, operand=counter_temp,
                        type=ptr_type.base_type),
                    PyrexTypes.c_py_ucs4_type).coerce_to(
                        node.target.type, self.current_env())
            else:
                # char* -> bytes coercion requires slicing, not indexing
                target_value = ExprNodes.SliceIndexNode(
                    node.target.pos,
                    start=ExprNodes.IntNode(node.target.pos, value='0',
                                            constant_result=0,
                                            type=PyrexTypes.c_int_type),
                    stop=ExprNodes.IntNode(node.target.pos, value='1',
                                           constant_result=1,
                                           type=PyrexTypes.c_int_type),
                    base=counter_temp,
                    type=Builtin.bytes_type,
                    is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # TODO: can this safely be replaced with DereferenceNode() as above?
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)
720

721
    def _transform_enumerate_iteration(self, node, enumerate_function):
722
        args = enumerate_function.arg_tuple.args
723
        if len(args) == 0:
724
            error(enumerate_function.pos,
725
                  "enumerate() requires an iterable argument")
726
            return node
727
        elif len(args) > 2:
728
            error(enumerate_function.pos,
729
                  "enumerate() takes at most 2 arguments")
730
            return node
731

732
        if not node.target.is_sequence_constructor:
733
            # leave this untouched for now
734
            return node
735
        targets = node.target.args
736
        if len(targets) != 2:
737
            # leave this untouched for now
738
            return node
739

740
        enumerate_target, iterable_target = targets
741
        counter_type = enumerate_target.type
742

743
        if not counter_type.is_pyobject and not counter_type.is_int:
744
            # nothing we can do here, I guess
745
            return node
746

747
        if len(args) == 2:
748
            start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
749
        else:
750
            start = ExprNodes.IntNode(enumerate_function.pos,
751
                                      value='0',
752
                                      type=counter_type,
753
                                      constant_result=0)
754
        temp = UtilNodes.LetRefNode(start)
755

756
        inc_expression = ExprNodes.AddNode(
757
            enumerate_function.pos,
758
            operand1 = temp,
759
            operand2 = ExprNodes.IntNode(node.pos, value='1',
760
                                         type=counter_type,
761
                                         constant_result=1),
762
            operator = '+',
763
            type = counter_type,
764
            #inplace = True,   # not worth using in-place operation for Py ints
765
            is_temp = counter_type.is_pyobject
766
            )
767

768
        loop_body = [
769
            Nodes.SingleAssignmentNode(
770
                pos = enumerate_target.pos,
771
                lhs = enumerate_target,
772
                rhs = temp),
773
            Nodes.SingleAssignmentNode(
774
                pos = enumerate_target.pos,
775
                lhs = temp,
776
                rhs = inc_expression)
777
            ]
778

779
        if isinstance(node.body, Nodes.StatListNode):
780
            node.body.stats = loop_body + node.body.stats
781
        else:
782
            loop_body.append(node.body)
783
            node.body = Nodes.StatListNode(
784
                node.body.pos,
785
                stats = loop_body)
786

787
        node.target = iterable_target
788
        node.item = node.item.coerce_to(iterable_target.type, self.current_env())
789
        node.iterator.sequence = args[0]
790

791
        # recurse into loop to check for further optimisations
792
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
793

794
    def _find_for_from_node_relations(self, neg_step_value, reversed):
795
        if reversed:
796
            if neg_step_value:
797
                return '<', '<='
798
            else:
799
                return '>', '>='
800
        else:
801
            if neg_step_value:
802
                return '>=', '>'
803
            else:
804
                return '<=', '<'
805

806
    def _transform_range_iteration(self, node, range_function, reversed=False):
807
        args = range_function.arg_tuple.args
808
        if len(args) < 3:
809
            step_pos = range_function.pos
810
            step_value = 1
811
            step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
812
        else:
813
            step = args[2]
814
            step_pos = step.pos
815
            if not isinstance(step.constant_result, int):
816
                # cannot determine step direction
817
                return node
818
            step_value = step.constant_result
819
            if step_value == 0:
820
                # will lead to an error elsewhere
821
                return node
822
            step = ExprNodes.IntNode(step_pos, value=str(step_value),
823
                                     constant_result=step_value)
824

825
        if len(args) == 1:
826
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
827
                                       constant_result=0)
828
            bound2 = args[0].coerce_to_index(self.current_env())
829
        else:
830
            bound1 = args[0].coerce_to_index(self.current_env())
831
            bound2 = args[1].coerce_to_index(self.current_env())
832

833
        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
834

835
        bound2_ref_node = None
836
        if reversed:
837
            bound1, bound2 = bound2, bound1
838
            abs_step = abs(step_value)
839
            if abs_step != 1:
840
                if (isinstance(bound1.constant_result, int) and
841
                        isinstance(bound2.constant_result, int)):
842
                    # calculate final bounds now
843
                    if step_value < 0:
844
                        begin_value = bound2.constant_result
845
                        end_value = bound1.constant_result
846
                        bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
847
                    else:
848
                        begin_value = bound1.constant_result
849
                        end_value = bound2.constant_result
850
                        bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1
851

852
                    bound1 = ExprNodes.IntNode(
853
                        bound1.pos, value=str(bound1_value), constant_result=bound1_value,
854
                        type=PyrexTypes.spanning_type(bound1.type, bound2.type))
855
                else:
856
                    # evaluate the same expression as above at runtime
857
                    bound2_ref_node = UtilNodes.LetRefNode(bound2)
858
                    bound1 = self._build_range_step_calculation(
859
                        bound1, bound2_ref_node, step, step_value)
860

861
        if step_value < 0:
862
            step_value = -step_value
863
        step.value = str(step_value)
864
        step.constant_result = step_value
865
        step = step.coerce_to_index(self.current_env())
866

867
        if not bound2.is_literal:
868
            # stop bound must be immutable => keep it in a temp var
869
            bound2_is_temp = True
870
            bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2)
871
        else:
872
            bound2_is_temp = False
873

874
        for_node = Nodes.ForFromStatNode(
875
            node.pos,
876
            target=node.target,
877
            bound1=bound1, relation1=relation1,
878
            relation2=relation2, bound2=bound2,
879
            step=step, body=node.body,
880
            else_clause=node.else_clause,
881
            from_range=True)
882
        for_node.set_up_loop(self.current_env())
883

884
        if bound2_is_temp:
885
            for_node = UtilNodes.LetNode(bound2, for_node)
886

887
        return for_node
888

889
    def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
        """Build the expression tree for the adjusted bound of a stepped range.

        The returned node represents

            bound2 (+/-) abs(step) * ((begin - end - 1) // abs(step)) (+/-) 1

        where ``(+/-)`` is ``-`` and ``begin, end = bound2, bound1`` for a
        negative ``step_value``, and ``+`` with ``begin, end = bound1, bound2``
        otherwise.  ``bound2_ref_node`` is used for every occurrence of
        ``bound2``.  It is the same formula that ``_transform_range_iteration()``
        evaluates directly when both bounds are integer constants.
        """
        abs_step = abs(step_value)
        pos = bound1.pos

        result_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
        if step.type.is_int and abs_step < 0x7FFF:
            # Avoid loss of integer precision warnings.
            calc_type = PyrexTypes.spanning_type(result_type, PyrexTypes.c_int_type)
        else:
            calc_type = PyrexTypes.spanning_type(result_type, step.type)

        if step_value < 0:
            begin_value, end_value, final_op = bound2_ref_node, bound1, '-'
        else:
            begin_value, end_value, final_op = bound1, bound2_ref_node, '+'

        def int_literal(number, **kwargs):
            # A fresh literal node per use, nodes are not shared in the tree.
            return ExprNodes.IntNode(
                pos, value=str(number), constant_result=number, **kwargs)

        # begin - end
        distance = ExprNodes.SubNode(
            pos,
            operand1=begin_value,
            operator='-',
            operand2=end_value,
            type=result_type)
        # (begin - end) - 1
        distance_minus_one = ExprNodes.SubNode(
            pos,
            operand1=distance,
            operator='-',
            operand2=int_literal(1),
            type=calc_type)
        # ((begin - end) - 1) // abs_step
        whole_steps = ExprNodes.DivNode(
            pos,
            operand1=distance_minus_one,
            operator='//',
            operand2=int_literal(abs_step, type=calc_type),
            type=calc_type)
        # abs_step * (...)
        stepped_distance = ExprNodes.MulNode(
            pos,
            operand1=int_literal(abs_step, type=calc_type),
            operator='*',
            operand2=whole_steps,
            type=calc_type)
        # bound2 +/- abs_step * (...)
        shifted_bound = ExprNodes.binop_node(
            pos,
            operand1=bound2_ref_node,
            operator=final_op,
            operand2=stepped_distance,
            type=calc_type)
        # (...) +/- 1
        return ExprNodes.binop_node(
            pos,
            operand1=shifted_bound,
            operator=final_op,
            operand2=int_literal(1),
            type=result_type)
952

953
    def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
        """Replace a for-loop over a dict by a while-loop using C helpers.

        ``node`` is the for-loop and ``dict_obj`` the object iterated over.
        ``method`` is the name of the method the loop iterates through, or a
        false value for plain iteration; it ends up in the None check
        message and as the ``method_name`` argument of
        ``__Pyx_dict_iterator``.  ``keys`` and ``values`` select what the
        loop target receives: both set means key/value pairs (unpacked into
        two targets, or assigned as one tuple), only ``keys`` means keys,
        otherwise values.

        Returns ``node`` unchanged if pairs are requested and the target is a
        sequence constructor that does not have exactly two items.  Otherwise
        returns a ``TempsBlockNode`` that initialises the iteration state and
        runs a condition-less ``WhileStatNode`` whose body starts with a
        ``DictIterationNextNode``.
        """
        iterator_handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        dict_temp = iterator_handle.ref(dict_obj.pos)
        position_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        pos_temp = position_handle.ref(node.pos)
        temps = [iterator_handle, position_handle]

        key_target = value_target = tuple_target = None
        loop_target = node.target
        if keys and values:
            if not loop_target.is_sequence_constructor:
                tuple_target = loop_target
            elif len(loop_target.args) == 2:
                key_target, value_target = loop_target.args
            else:
                # unusual case that may or may not lead to an error
                return node
        elif keys:
            key_target = loop_target
        else:
            value_target = loop_target

        body = node.body
        if not isinstance(body, Nodes.StatListNode):
            body = Nodes.StatListNode(pos=body.pos, stats=[body])

        # keep original length to guard against dict modification
        length_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(length_handle)
        length_addr = ExprNodes.AmpersandNode(
            node.pos,
            operand=length_handle.ref(dict_obj.pos),
            type=PyrexTypes.c_ptr_type(length_handle.type))

        is_dict_handle = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        temps.append(is_dict_handle)
        is_dict_temp = is_dict_handle.ref(node.pos)
        is_dict_addr = ExprNodes.AmpersandNode(
            node.pos,
            operand=is_dict_temp,
            type=PyrexTypes.c_ptr_type(is_dict_handle.type))

        # The first statement of the loop body fetches the next item(s).
        fetch_next = Nodes.DictIterationNextNode(
            dict_temp, length_handle.ref(dict_obj.pos), pos_temp,
            key_target, value_target, tuple_target,
            is_dict_temp)
        fetch_next = fetch_next.analyse_expressions(self.current_env())
        body.stats.insert(0, fetch_next)

        if method:
            method_node = ExprNodes.IdentifierStringNode(dict_obj.pos, value=method)
            precision = '.30' if len(method) <= 30 else ''
            dict_obj = dict_obj.as_none_safe_node(
                "'NoneType' object has no attribute '%{}s'".format(precision),
                error="PyExc_AttributeError",
                format_args=[method])
        else:
            method_node = ExprNodes.NullNode(dict_obj.pos)
            dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

        # Tell the C helper whether the object is statically typed as dict.
        exact_dict = 1 if dict_obj.type is Builtin.dict_type else 0
        is_dict_flag = ExprNodes.IntNode(
            node.pos, value=str(exact_dict), constant_result=exact_dict)

        reset_position = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=pos_temp,
            rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0))

        iterator_call = ExprNodes.PythonCapiCallNode(
            dict_obj.pos,
            "__Pyx_dict_iterator",
            self.PyDict_Iterator_func_type,
            utility_code=UtilityCode.load_cached("dict_iter", "Optimize.c"),
            args=[dict_obj, is_dict_flag, method_node, length_addr, is_dict_addr],
            is_temp=True)
        create_iterator = Nodes.SingleAssignmentNode(
            dict_obj.pos, lhs=dict_temp, rhs=iterator_call)

        iteration_loop = Nodes.WhileStatNode(
            node.pos,
            condition=None,
            body=body,
            else_clause=node.else_clause)

        return UtilNodes.TempsBlockNode(
            node.pos,
            temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats=[reset_position, create_iterator, iteration_loop]))
1050

1051
    # C signature used for the "__Pyx_dict_iterator" call that is generated
    # in _transform_dict_iteration(): returns a Python object.
    PyDict_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg(arg_name, arg_type, None)
         for arg_name, arg_type in (
             ("dict", PyrexTypes.py_object_type),
             ("is_dict", PyrexTypes.c_int_type),
             ("method_name", PyrexTypes.py_object_type),
             ("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type),
             ("p_is_dict", PyrexTypes.c_int_ptr_type),
         )])

    # C signature used for the "__Pyx_set_iterator" call that is generated
    # in _transform_set_iteration(): returns a Python object.
    PySet_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg(arg_name, arg_type, None)
         for arg_name, arg_type in (
             ("set", PyrexTypes.py_object_type),
             ("is_set", PyrexTypes.c_int_type),
             ("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type),
             ("p_is_set", PyrexTypes.c_int_ptr_type),
         )])
1067

1068
    def _transform_set_iteration(self, node, set_obj):
        """Replace a for-loop over a set by a while-loop using C helpers.

        ``node`` is the for-loop and ``set_obj`` the object iterated over.
        Returns a ``TempsBlockNode`` that resets the position temp, stores
        the result of ``__Pyx_set_iterator`` in a temp and then runs a
        condition-less ``WhileStatNode`` whose body starts with a
        ``SetIterationNextNode`` assigning to the original loop target.
        """
        iterator_handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        set_temp = iterator_handle.ref(set_obj.pos)
        position_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        pos_temp = position_handle.ref(node.pos)
        temps = [iterator_handle, position_handle]

        body = node.body
        if not isinstance(body, Nodes.StatListNode):
            body = Nodes.StatListNode(pos=body.pos, stats=[body])

        # keep original length to guard against set modification
        length_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(length_handle)
        length_addr = ExprNodes.AmpersandNode(
            node.pos,
            operand=length_handle.ref(set_obj.pos),
            type=PyrexTypes.c_ptr_type(length_handle.type))

        is_set_handle = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        temps.append(is_set_handle)
        is_set_temp = is_set_handle.ref(node.pos)
        is_set_addr = ExprNodes.AmpersandNode(
            node.pos,
            operand=is_set_temp,
            type=PyrexTypes.c_ptr_type(is_set_handle.type))

        # The first statement of the loop body fetches the next item.
        fetch_next = Nodes.SetIterationNextNode(
            set_temp, length_handle.ref(set_obj.pos), pos_temp,
            node.target, is_set_temp)
        fetch_next = fetch_next.analyse_expressions(self.current_env())
        body.stats.insert(0, fetch_next)

        # Tell the C helper whether the object is statically typed as set.
        exact_set = 1 if set_obj.type is Builtin.set_type else 0
        is_set_flag = ExprNodes.IntNode(
            node.pos, value=str(exact_set), constant_result=exact_set)

        reset_position = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=pos_temp,
            rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0))

        iterator_call = ExprNodes.PythonCapiCallNode(
            set_obj.pos,
            "__Pyx_set_iterator",
            self.PySet_Iterator_func_type,
            utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
            args=[set_obj, is_set_flag, length_addr, is_set_addr],
            is_temp=True)
        create_iterator = Nodes.SingleAssignmentNode(
            set_obj.pos, lhs=set_temp, rhs=iterator_call)

        iteration_loop = Nodes.WhileStatNode(
            node.pos,
            condition=None,
            body=body,
            else_clause=node.else_clause)

        return UtilNodes.TempsBlockNode(
            node.pos,
            temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats=[reset_position, create_iterator, iteration_loop]))
1138

1139

1140
class SwitchTransform(Visitor.EnvTransform):
1141
    """
1142
    This transformation tries to turn long if statements into C switch statements.
1143
    The requirement is that every clause be an (or of) var == value, where the var
1144
    is common among all clauses and both var and value are ints.
1145
    """
1146
    NO_MATCH = (None, None, None)
1147

1148
    def extract_conditions(self, cond, allow_not_in):
1149
        while True:
1150
            if isinstance(cond, (ExprNodes.CoerceToTempNode,
1151
                                 ExprNodes.CoerceToBooleanNode)):
1152
                cond = cond.arg
1153
            elif isinstance(cond, ExprNodes.BoolBinopResultNode):
1154
                cond = cond.arg.arg
1155
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
1156
                # this is what we get from the FlattenInListTransform
1157
                cond = cond.subexpression
1158
            elif isinstance(cond, ExprNodes.TypecastNode):
1159
                cond = cond.operand
1160
            else:
1161
                break
1162

1163
        if isinstance(cond, ExprNodes.PrimaryCmpNode):
1164
            if cond.cascade is not None:
1165
                return self.NO_MATCH
1166
            elif cond.is_c_string_contains() and \
1167
                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
1168
                not_in = cond.operator == 'not_in'
1169
                if not_in and not allow_not_in:
1170
                    return self.NO_MATCH
1171
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
1172
                       cond.operand2.contains_surrogates():
1173
                    # dealing with surrogates leads to different
1174
                    # behaviour on wide and narrow Unicode
1175
                    # platforms => refuse to optimise this case
1176
                    return self.NO_MATCH
1177
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
1178
            elif not cond.is_python_comparison():
1179
                if cond.operator == '==':
1180
                    not_in = False
1181
                elif allow_not_in and cond.operator == '!=':
1182
                    not_in = True
1183
                else:
1184
                    return self.NO_MATCH
1185
                # this looks somewhat silly, but it does the right
1186
                # checks for NameNode and AttributeNode
1187
                if is_common_value(cond.operand1, cond.operand1):
1188
                    if cond.operand2.is_literal:
1189
                        return not_in, cond.operand1, [cond.operand2]
1190
                    elif getattr(cond.operand2, 'entry', None) \
1191
                             and cond.operand2.entry.is_const:
1192
                        return not_in, cond.operand1, [cond.operand2]
1193
                if is_common_value(cond.operand2, cond.operand2):
1194
                    if cond.operand1.is_literal:
1195
                        return not_in, cond.operand2, [cond.operand1]
1196
                    elif getattr(cond.operand1, 'entry', None) \
1197
                             and cond.operand1.entry.is_const:
1198
                        return not_in, cond.operand2, [cond.operand1]
1199
        elif isinstance(cond, ExprNodes.BoolBinopNode):
1200
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
1201
                allow_not_in = (cond.operator == 'and')
1202
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
1203
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
1204
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
1205
                    if (not not_in_1) or allow_not_in:
1206
                        return not_in_1, t1, c1+c2
1207
        return self.NO_MATCH
1208

1209
    def extract_in_string_conditions(self, string_literal):
1210
        if isinstance(string_literal, ExprNodes.UnicodeNode):
1211
            charvals = list(map(ord, set(string_literal.value)))
1212
            charvals.sort()
1213
            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
1214
                                       constant_result=charval)
1215
                     for charval in charvals ]
1216
        else:
1217
            # this is a bit tricky as Py3's bytes type returns
1218
            # integers on iteration, whereas Py2 returns 1-char byte
1219
            # strings
1220
            characters = string_literal.value
1221
            characters = list({ characters[i:i+1] for i in range(len(characters)) })
1222
            characters.sort()
1223
            return [ ExprNodes.CharNode(string_literal.pos, value=charval,
1224
                                        constant_result=charval)
1225
                     for charval in characters ]
1226

1227
    def extract_common_conditions(self, common_var, condition, allow_not_in):
1228
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
1229
        if var is None:
1230
            return self.NO_MATCH
1231
        elif common_var is not None and not is_common_value(var, common_var):
1232
            return self.NO_MATCH
1233
        elif not (var.type.is_int or var.type.is_enum) or any(
1234
                [not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
1235
            return self.NO_MATCH
1236
        return not_in, var, conditions
1237

1238
    def has_duplicate_values(self, condition_values):
1239
        # duplicated values don't work in a switch statement
1240
        seen = set()
1241
        for value in condition_values:
1242
            if value.has_constant_result():
1243
                if value.constant_result in seen:
1244
                    return True
1245
                seen.add(value.constant_result)
1246
            else:
1247
                # this isn't completely safe as we don't know the
1248
                # final C value, but this is about the best we can do
1249
                try:
1250
                    value_entry = value.entry
1251
                    if ((value_entry.type.is_enum or value_entry.type.is_cpp_enum)
1252
                            and value_entry.enum_int_value is not None):
1253
                        value_for_seen = value_entry.enum_int_value
1254
                    else:
1255
                        value_for_seen = value_entry.cname
1256
                except AttributeError:
1257
                    return True  # play safe
1258
                if value_for_seen in seen:
1259
                    return True
1260
                seen.add(value_for_seen)
1261
        return False
1262

1263
    def visit_IfStatNode(self, node):
1264
        if not self.current_directives.get('optimize.use_switch'):
1265
            self.visitchildren(node)
1266
            return node
1267

1268
        common_var = None
1269
        cases = []
1270
        for if_clause in node.if_clauses:
1271
            _, common_var, conditions = self.extract_common_conditions(
1272
                common_var, if_clause.condition, False)
1273
            if common_var is None:
1274
                self.visitchildren(node)
1275
                return node
1276
            cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
1277
                                              conditions=conditions,
1278
                                              body=if_clause.body))
1279

1280
        condition_values = [
1281
            cond for case in cases for cond in case.conditions]
1282
        if len(condition_values) < 2:
1283
            self.visitchildren(node)
1284
            return node
1285
        if self.has_duplicate_values(condition_values):
1286
            self.visitchildren(node)
1287
            return node
1288

1289
        # Recurse into body subtrees that we left untouched so far.
1290
        self.visitchildren(node, 'else_clause')
1291
        for case in cases:
1292
            self.visitchildren(case, 'body')
1293

1294
        common_var = unwrap_node(common_var)
1295
        switch_node = Nodes.SwitchStatNode(pos=node.pos,
1296
                                           test=common_var,
1297
                                           cases=cases,
1298
                                           else_clause=node.else_clause)
1299
        return switch_node
1300

1301
    def visit_CondExprNode(self, node):
1302
        if not self.current_directives.get('optimize.use_switch'):
1303
            self.visitchildren(node)
1304
            return node
1305

1306
        not_in, common_var, conditions = self.extract_common_conditions(
1307
            None, node.test, True)
1308
        if common_var is None \
1309
                or len(conditions) < 2 \
1310
                or self.has_duplicate_values(conditions):
1311
            self.visitchildren(node)
1312
            return node
1313

1314
        return self.build_simple_switch_statement(
1315
            node, common_var, conditions, not_in,
1316
            node.true_val, node.false_val)
1317

1318
    def visit_BoolBinopNode(self, node):
1319
        if not self.current_directives.get('optimize.use_switch'):
1320
            self.visitchildren(node)
1321
            return node
1322

1323
        not_in, common_var, conditions = self.extract_common_conditions(
1324
            None, node, True)
1325
        if common_var is None \
1326
                or len(conditions) < 2 \
1327
                or self.has_duplicate_values(conditions):
1328
            self.visitchildren(node)
1329
            node.wrap_operands(self.current_env())  # in case we changed the operands
1330
            return node
1331

1332
        return self.build_simple_switch_statement(
1333
            node, common_var, conditions, not_in,
1334
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
1335
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
1336

1337
    def visit_PrimaryCmpNode(self, node):
1338
        if not self.current_directives.get('optimize.use_switch'):
1339
            self.visitchildren(node)
1340
            return node
1341

1342
        not_in, common_var, conditions = self.extract_common_conditions(
1343
            None, node, True)
1344
        if common_var is None \
1345
                or len(conditions) < 2 \
1346
                or self.has_duplicate_values(conditions):
1347
            self.visitchildren(node)
1348
            return node
1349

1350
        return self.build_simple_switch_statement(
1351
            node, common_var, conditions, not_in,
1352
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
1353
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
1354

1355
    def build_simple_switch_statement(self, node, common_var, conditions,
1356
                                      not_in, true_val, false_val):
1357
        result_ref = UtilNodes.ResultRefNode(node)
1358
        true_body = Nodes.SingleAssignmentNode(
1359
            node.pos,
1360
            lhs=result_ref,
1361
            rhs=true_val.coerce_to(node.type, self.current_env()),
1362
            first=True)
1363
        false_body = Nodes.SingleAssignmentNode(
1364
            node.pos,
1365
            lhs=result_ref,
1366
            rhs=false_val.coerce_to(node.type, self.current_env()),
1367
            first=True)
1368

1369
        if not_in:
1370
            true_body, false_body = false_body, true_body
1371

1372
        cases = [Nodes.SwitchCaseNode(pos = node.pos,
1373
                                      conditions = conditions,
1374
                                      body = true_body)]
1375

1376
        common_var = unwrap_node(common_var)
1377
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
1378
                                           test = common_var,
1379
                                           cases = cases,
1380
                                           else_clause = false_body)
1381
        replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node)
1382
        return replacement
1383

1384
    def visit_EvalWithTempExprNode(self, node):
1385
        if not self.current_directives.get('optimize.use_switch'):
1386
            self.visitchildren(node)
1387
            return node
1388

1389
        # drop unused expression temp from FlattenInListTransform
1390
        orig_expr = node.subexpression
1391
        temp_ref = node.lazy_temp
1392
        self.visitchildren(node)
1393
        if node.subexpression is not orig_expr:
1394
            # node was restructured => check if temp is still used
1395
            if not Visitor.tree_contains(node.subexpression, temp_ref):
1396
                return node.subexpression
1397
        return node
1398

1399
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1400

1401

1402
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Rewrites membership tests against literal containers, i.e.
    "x in [val1, ..., valn]" and "x not in (val1, ..., valn)", into a chain
    of "=="/"!=" comparisons that is joined by "or"/"and".
    """

    def visit_PrimaryCmpNode(self, node):
        """Return the flattened comparison chain, or 'node' if it does not apply.

        Only non-cascaded 'in'/'not_in' comparisons whose right operand is a
        tuple, list or set node without starred items are rewritten.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        operator = node.operator
        if operator == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif operator == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        container = node.operand2
        literal_types = (ExprNodes.TupleNode, ExprNodes.ListNode, ExprNodes.SetNode)
        if not isinstance(container, literal_types):
            return node

        lhs = node.operand1
        args = container.args
        if not args:
            # note: lhs may have side effects, but ".is_simple()" may not work yet before type analysis.
            if not lhs.try_is_simple():
                return node
            # membership in an empty container: "in" is False, "not in" is True
            constant_result = operator == 'not_in'
            return ExprNodes.BoolNode(
                node.pos, value=constant_result, constant_result=constant_result)

        for arg in args:
            if arg.is_starred:
                # Starred arguments do not directly translate to comparisons or "in" tests.
                return node

        lhs = UtilNodes.ResultRefNode(lhs)

        comparisons = []
        let_refs = []
        for arg in args:
            # Trial optimisation to avoid redundant temp assignments.
            if not arg.try_is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                let_refs.append(arg)
            comparison = ExprNodes.PrimaryCmpNode(
                pos=node.pos,
                operand1=lhs,
                operator=eq_or_neq,
                operand2=arg,
                cascade=None)
            comparisons.append(ExprNodes.TypecastNode(
                pos=node.pos,
                operand=comparison,
                type=PyrexTypes.c_bint_type))

        # Left fold of all comparisons with the boolean conjunction.
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos=node.pos,
                operator=conjunction,
                operand1=condition,
                operand2=comparison)

        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # Wrap from the last temp to the first so that the first one ends up outermost.
        for let_ref in reversed(let_refs):
            new_node = UtilNodes.EvalWithTempExprNode(let_ref, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1473

1474

1475
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Always returns 'node' itself.  When all operands qualify, the
        'use_managed_ref' flag of the involved operand nodes is switched off.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            # FIXME: Nodes.CascadedAssignmentNode is not handled, only plain
            # single assignments are.
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not (self._extract_operand(stat.lhs, left_names, left_indices, temps)
                    and self._extract_operand(stat.rhs, right_names, right_indices, temps)):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lname_set = set([path for path, _ in left_names])
            rname_set = set([path for path, _ in right_names])
            if lname_set != rname_set or len(lname_set) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = self._collect_index_ids(left_indices)
            if lindices is None:
                return node
            rindices = self._collect_index_ids(right_indices)
            if rindices is None:
                return node

            lindex_set, rindex_set = set(lindices), set(rindices)
            if lindex_set != rindex_set or len(lindex_set) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = []
        for temp in temps:
            temp_args.append(temp.arg)
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _collect_index_ids(self, index_nodes):
        """Return the index ids of all 'index_nodes', or None if any has none."""
        index_ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            index_ids.append(index_id)
        return index_ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand and record it in the given lists.

        Appends (dotted name path, node) to 'names' for plain names and
        non-Python attribute chains on a name, or the node to 'indices' for an
        integer subscript of a list-typed name.  CoerceToTempNode wrappers are
        collected in 'temps' and unwrapped first.

        Returns True if the operand was recorded, False if it is unsupported.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # Walk down the attribute chain, collecting the names innermost-last.
        name_path = []
        obj_node = node
        while obj_node.is_attribute:
            if obj_node.is_py_attr:
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj

        if obj_node.is_name:
            name_path.append(obj_node.name)
            names.append(('.'.join(reversed(name_path)), node))
            return True

        if not node.is_subscript:
            return False
        base = node.base
        if (base.type != Builtin.list_type
                or not node.index.type.is_int
                or not base.is_name):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for a name-indexed node, else None."""
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not handled either
            return None
        return (base.name, index.name)
1589

1590

1591
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1592
    """Optimize some common calls to builtin types *before* the type
1593
    analysis phase and *after* the declarations analysis phase.
1594

1595
    This transform cannot make use of any argument types, but it can
1596
    restructure the tree in a way that the type analysis phase can
1597
    respond to.
1598

1599
    Introducing C function calls here may not be a good idea.  Move
1600
    them to the OptimizeBuiltinCalls transform instead, which runs
1601
    after type analysis.
1602
    """
1603
    # only intercept on call nodes
1604
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1605

1606
    def visit_SimpleCallNode(self, node):
        """Visit the children, then pass calls to builtin names on to a handler."""
        self.visitchildren(node)
        called = node.function
        if self._function_is_builtin_name(called):
            return self._dispatch_to_handler(node, called, node.args)
        return node
1612

1613
    def visit_GeneralCallNode(self, node):
        """Visit the children, then pass calls to builtin names on to a handler.

        The call is only dispatched if its positional arguments are given as
        a TupleNode; the keyword arguments are passed along as they are.
        """
        self.visitchildren(node)
        called = node.function
        if not self._function_is_builtin_name(called):
            return node
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, called, positional.args, node.keyword_args)
        return node
1624

1625
    def _function_is_builtin_name(self, function):
        """Return True if 'function' is a name that resolves to the builtin scope.

        The entry found by a regular lookup from the current scope must be
        the very same entry that the builtin scope itself holds for the name.
        """
        if not function.is_name:
            return False
        name = function.name
        env = self.current_env()
        visible_entry = env.lookup(name)
        builtin_entry = env.builtin_scope().lookup_here(name)
        # if both entries are None, it's at least an undeclared name, so likely builtin
        return visible_entry is builtin_entry
1634

1635
    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the handler method named after the called function, if any.

        Handlers are looked up on 'self' as '_handle_simple_function_NAME'
        (no kwargs) or '_handle_general_function_NAME' (with kwargs).
        Returns the handler's result, or 'node' if there is no handler.
        """
        has_kwargs = kwargs is not None
        if has_kwargs:
            name_pattern = '_handle_general_function_%s'
        else:
            name_pattern = '_handle_simple_function_%s'
        handle_call = getattr(self, name_pattern % function.name, None)
        if handle_call is None:
            return node
        if has_kwargs:
            return handle_call(node, args, kwargs)
        return handle_call(node, args)
1647

1648
    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace node.function by a PythonCapiFunctionNode for 'cname'.

        Position and Python-level name are taken over from the old function node.
        """
        old_function = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            old_function.pos, old_function.name, cname, func_type,
            utility_code=utility_code)
1652

1653
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with the wrong number of arguments at node.pos.

        'expected' may be None (unknown), a number or a string; it selects the
        argument placeholder shown in the message and, unless it is None, is
        included in the message as the expected count.
        """
        arg_str = ''  # also covers None, 0 and any other number below 1
        if expected:
            if isinstance(expected, str) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = '' if expected is None else 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)
1668

1669
    # specific handlers for simple call nodes
1670

1671
    def _handle_simple_function_float(self, node, pos_args):
        """Simplify trivial float() calls.

        float() without arguments becomes the float literal '0.0', and
        float(x) is replaced by x itself if the type of x is already known to
        be a C double or a Python float.  A call with more than one argument
        is reported as an error and returned unchanged.
        """
        if not pos_args:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
            # Do not fall through: the call is invalid, so it must not be
            # silently replaced by its first argument below.  This matches
            # what the slice() handler does after reporting the same error.
            return node
        # The argument may not have a type yet at this stage.
        arg_type = getattr(pos_args[0], 'type', None)
        if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
            return pos_args[0]
        return node
1680

1681
    def _handle_simple_function_slice(self, node, pos_args):
        """Replace slice(stop), slice(start, stop) and slice(start, stop, step)
        by a SliceNode, filling in NoneNode for a missing start or step.

        Any other argument count is reported as an error and the call is
        returned unchanged.
        """
        arg_count = len(pos_args)
        if arg_count < 1 or arg_count > 3:
            self._error_wrong_arg_count('slice', node, pos_args)
            return node

        if arg_count == 1:
            start, stop, step = None, pos_args[0], None
        elif arg_count == 2:
            (start, stop), step = pos_args, None
        else:
            start, stop, step = pos_args

        return ExprNodes.SliceNode(
            node.pos,
            start=start or ExprNodes.NoneNode(node.pos),
            stop=stop,
            step=step or ExprNodes.NoneNode(node.pos))
1698

1699
    def _handle_simple_function_ord(self, node, pos_args):
        """Unpack ord('X').

        A call on a unicode or bytes literal of length 1 is replaced by an
        integer literal of C long type; everything else is left unchanged.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        is_string_literal = isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode))
        if not is_string_literal or len(arg.value) != 1:
            return node
        char_code = ord(arg.value)
        return ExprNodes.IntNode(
            arg.pos, type=PyrexTypes.c_long_type,
            value=str(char_code),
            constant_result=char_code
        )
1713

1714
    # sequence processing
1715

1716
    def _handle_simple_function_all(self, node, pos_args):
        """Inline all(genexpr) as a loop that stops at the first false item.

        The transformation turns

        _result = all(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not p(x):
                    return False
        else:
            return True
        """
        return self._transform_any_all(node, pos_args, is_any=False)
1731

1732
    def _handle_simple_function_any(self, node, pos_args):
        """Inline any(genexpr) as a loop that stops at the first true item.

        The transformation turns

        _result = any(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if p(x):
                    return True
        else:
            return False
        """
        return self._transform_any_all(node, pos_args, is_any=True)
1747

1748
    def _transform_any_all(self, node, pos_args, is_any):
        """Shared implementation of the any()/all() generator expression inlining.

        The single yield statement of the generator expression is replaced by
        "if [not] <yielded value>: return <is_any>", and the loop receives an
        else clause that returns the opposite boolean.  Calls that do not take
        exactly one generator expression with a single yield are left unchanged.
        """
        if len(pos_args) != 1:
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            return node

        loop_node = gen_expr_node.def_node.gbody.body
        yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        yield_pos = yield_expression.pos
        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_pos, operand=yield_expression)

        # Early exit from inside the loop: any() found a true item,
        # all() found a false one.
        early_return = Nodes.ReturnStatNode(
            node.pos,
            value=ExprNodes.BoolNode(yield_pos, value=is_any, constant_result=is_any))
        if_clause = Nodes.IfClauseNode(yield_pos, condition=condition, body=early_return)
        test_node = Nodes.IfStatNode(yield_pos, else_clause=None, if_clauses=[if_clause])

        # Result when the loop runs to completion.
        exhausted_result = not is_any
        loop_node.else_clause = Nodes.ReturnStatNode(
            node.pos,
            value=ExprNodes.BoolNode(
                yield_pos, value=exhausted_result, constant_result=exhausted_result))

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)

        orig_func = 'any' if is_any else 'all'
        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, gen=gen_expr_node, orig_func=orig_func)
1783

1784
    # C function signature used for the list-building C-API call in
    # _handle_simple_function_sorted(): one Python object in, a list out.
    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
        ])
1787

1788
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.

        Other single arguments are first converted into a new list, which is
        then sorted in place as well.  Calls with any other number of
        positional arguments are left unchanged.
        """
        if len(pos_args) != 1:
            return node

        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
            list_node = arg
            loop_node = list_node.loop

        elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
            gen_expr_node = arg
            loop_node = gen_expr_node.loop
            yield_statements = _find_yield_statements(loop_node)
            if not yield_statements:
                return node

            list_node = ExprNodes.InlinedGeneratorExpressionNode(
                node.pos, gen_expr_node, orig_func='sorted',
                comprehension_type=Builtin.list_type)

            # Turn each "yield x" into an append of x to the result list.
            for yield_expression, yield_stat_node in yield_statements:
                append_node = ExprNodes.ComprehensionAppendNode(
                    yield_expression.pos,
                    expr=yield_expression,
                    target=list_node.target)
                Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        elif arg.is_sequence_constructor:
            # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
            # so starting off with a fresh one is more efficient.
            list_node = loop_node = arg.as_list()

        else:
            # Interestingly, PySequence_List works on a lot of non-sequence
            # things as well.
            if arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type):
                cfunc_name = "__Pyx_PySequence_ListKeepNew"
            else:
                cfunc_name = "PySequence_List"
            list_node = loop_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PySequence_List_func_type,
                args=pos_args, is_temp=True)

        # "result = <list>"
        result_node = UtilNodes.ResultRefNode(
            pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
        list_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs=result_node, rhs=list_node, first=True)

        # "result.sort()"
        sort_method = ExprNodes.AttributeNode(
            node.pos, obj=result_node, attribute=EncodedString('sort'),
            # entry ? type ?
            needs_none_check=False)
        sort_call = ExprNodes.SimpleCallNode(node.pos, function=sort_method, args=[])
        sort_node = Nodes.ExprStatNode(node.pos, expr=sort_call)
        sort_node.analyse_declarations(self.current_env())

        statements = Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node])
        return UtilNodes.TempResultFromStatNode(result_node, statements)
1855

1856
    def __handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        Accepts sum(x) and sum(x, start).  The generator expression path is
        currently switched off (see the FIXME below), which leaves list
        comprehensions that collect an integer literal as the only input
        that passes the checks.
        """
        if len(pos_args) not in (1,2):
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, (ExprNodes.GeneratorExpressionNode,
                                          ExprNodes.ComprehensionNode)):
            return node
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
            # FIXME: currently nonfunctional
            yield_expression = None
            if yield_expression is None:
                return node
        else:  # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                is_int_literal = yield_expression.is_literal and yield_expression.type.is_int
            except AttributeError:
                return node  # in case we don't have a type yet
            if not is_int_literal:
                return node
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr

        if len(pos_args) == 2:
            start = pos_args[1]
        else:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)

        # "result = result + <item>" replaces the yield/append statement.
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs=result_ref,
            rhs=ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

        # "result = start", followed by the aggregation loop.
        init_node = Nodes.SingleAssignmentNode(
            start.pos,
            lhs=UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
            rhs=start,
            first=True)
        exec_code = Nodes.StatListNode(node.pos, stats=[init_node, loop_node])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop=exec_code, result_node=result_ref,
            expr_scope=gen_expr_node.expr_scope, orig_func='sum',
            has_local_scope=gen_expr_node.has_local_scope)
1913

1914
    def _handle_simple_function_min(self, node, pos_args):
        """Inline min(a, b, ...) as comparisons using the '<' operator."""
        return self._optimise_min_max(node, pos_args, operator='<')
1916

1917
    def _handle_simple_function_max(self, node, pos_args):
        """Inline max(a, b, ...) as comparisons using the '>' operator."""
        return self._optimise_min_max(node, pos_args, operator='>')
1919

1920
    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.

        A single sequence constructor argument, as in min([a, b]), is unpacked
        into its items first.  Calls that end up with fewer than two values
        are returned unchanged.
        """
        if len(args) == 1 and args[0].is_sequence_constructor:
            args = args[0].args
        if len(args) <= 1:
            # leave this to Python
            return node

        cascaded_nodes = [UtilNodes.ResultRefNode(arg) for arg in args[1:]]

        # Fold the arguments from left to right:
        # best = (arg if arg <operator> best else best)
        best_so_far = args[0]
        for arg_node in cascaded_nodes:
            best_ref = UtilNodes.ResultRefNode(best_so_far)
            comparison = ExprNodes.PrimaryCmpNode(
                arg_node.pos,
                operand1=arg_node,
                operator=operator,
                operand2=best_ref,
                )
            selection = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val=arg_node,
                false_val=best_ref,
                test=comparison,
                )
            best_so_far = UtilNodes.EvalWithTempExprNode(best_ref, selection)

        # Wrap from the last argument to the first so that the temp of the
        # first cascaded argument ends up outermost.
        for ref_node in reversed(cascaded_nodes):
            best_so_far = UtilNodes.EvalWithTempExprNode(ref_node, best_so_far)

        return best_so_far
1952

1953
    # builtin type creation
1954

1955
    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple() by an empty tuple literal and tuple(genexpr) by an
        inlined list comprehension that is converted into a tuple.
        """
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        as_list = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        if as_list is node:
            # nothing was transformed
            return node
        return ExprNodes.AsTupleNode(node.pos, arg=as_list)
1969

1970
    def _handle_simple_function_frozenset(self, node, pos_args):
        """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.

        An empty sequence literal argument is dropped entirely.  The argument
        list is modified in place and 'node' itself is always returned.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if arg.is_sequence_constructor and not arg.args:
            del pos_args[:]
        elif isinstance(arg, ExprNodes.ListNode):
            pos_args[0] = arg.as_tuple()
        return node
1980

1981
    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by an empty list literal and list(genexpr) by an
        inlined comprehension.
        """
        if pos_args:
            return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1985

1986
    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set node and set(genexpr) by an inlined
        comprehension.
        """
        if pos_args:
            return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
        return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1990

1991
    def _transform_list_set_genexpr(self, node, pos_args, target_type):
        """Replace set(genexpr) and list(genexpr) by an inlined comprehension.

        'pos_args' must not be empty.  Returns 'node' unchanged unless the
        only argument is a generator expression with yield statements.
        """
        if len(pos_args) > 1:
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            return node

        yield_statements = _find_yield_statements(gen_expr_node.loop)
        if not yield_statements:
            return node

        if target_type is Builtin.set_type:
            orig_func = 'set'
        else:
            orig_func = 'list'
        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node,
            orig_func=orig_func,
            comprehension_type=target_type)

        # Turn each "yield x" into an append of x to the comprehension target.
        for yield_expression, yield_stat_node in yield_statements:
            append_node = ExprNodes.ComprehensionAppendNode(
                yield_expression.pos,
                expr=yield_expression,
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node
2018

2019
    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }

        dict() without arguments becomes an empty dict literal.  The
        generator expression is only inlined if every yielded expression is a
        literal tuple of exactly two items.
        """
        if not pos_args:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            return node

        yield_statements = _find_yield_statements(gen_expr_node.loop)
        if not yield_statements:
            return node

        # Check all yields before modifying anything.
        for yield_expression, _ in yield_statements:
            if (not isinstance(yield_expression, ExprNodes.TupleNode)
                    or len(yield_expression.args) != 2):
                return node

        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node, orig_func='dict',
            comprehension_type=Builtin.dict_type)

        # Turn each "yield (key, value)" into an item assignment on the target.
        for yield_expression, yield_stat_node in yield_statements:
            key_expr, value_expr = yield_expression.args
            append_node = ExprNodes.DictComprehensionAppendNode(
                yield_expression.pos,
                key_expr=key_expr,
                value_expr=value_expr,
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node
2054

2055
    # specific handlers for general call nodes
2056

2057
    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.

        Only applies to calls without positional arguments whose keyword
        arguments are given as a DictNode.
        """
        if not pos_args and isinstance(kwargs, ExprNodes.DictNode):
            return kwargs
        return node
2066

2067

2068
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace simple calls to singly assigned Python functions by
    InlinedDefNodeCallNode, if the 'optimize.inline_defnode_calls' directive
    is set.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the value assigned to the name if it has exactly one
        recorded assignment, otherwise None.
        """
        cf_state = name_node.cf_state
        if cf_state is None or cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Visit the children, then try to replace the call by an inlined call.

        Returns the replacement if the called name is bound to a
        PyCFunctionNode and the inlined call node reports that it can be
        inlined; otherwise returns 'node'.
        """
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        function_name = node.function
        if not function_name.is_name:
            return node
        function = self.get_constant_value_node(function_name)
        if not isinstance(function, ExprNodes.PyCFunctionNode):
            return node
        inlined = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=function_name,
            function=function, args=node.args,
            generator_arg_tag=node.generator_arg_tag)
        if not inlined.can_be_inlined():
            return node
        return self.replace(node, inlined)
2100

2101

2102
class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
2103
                           Visitor.MethodDispatcherTransform):
2104
    """Optimize some common methods calls and instantiation patterns
2105
    for builtin types *after* the type analysis phase.
2106

2107
    Running after type analysis, this transform can only perform
2108
    function replacements that do not alter the function return type
2109
    in a way that was not anticipated by the type analysis.
2110
    """
2111
    ### cleanup to avoid redundant coercions to/from Python types
2112

2113
    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.

        The children are transformed first; the result is whatever the
        node's own ``reanalyse()`` returns afterwards.
        """
        self.visitchildren(node)
        reanalysed = node.reanalyse()
        return reanalysed
2118

2119
    def _visit_TypecastNode(self, node):
2120
        # disabled - the user may have had a reason to put a type
2121
        # cast, even if it looks redundant to Cython
2122
        """
2123
        Drop redundant type casts.
2124
        """
2125
        self.visitchildren(node)
2126
        if node.type == node.operand.type:
2127
            return node.operand
2128
        return node
2129

2130
    def visit_ExprStatNode(self, node):
        """Drop dead code and useless coercions.

        A ``CoerceToPyTypeNode`` directly below the statement is stripped.
        The statement is removed (None is returned) when its expression
        is gone, is None, is a literal, or is a bare reference to a local
        variable or argument.
        """
        self.visitchildren(node)
        expr = node.expr
        if isinstance(expr, ExprNodes.CoerceToPyTypeNode):
            # the Python object is never used, so skip creating it
            expr = node.expr = expr.arg
        if expr is None or expr.is_none or expr.is_literal:
            # Expression was removed or is dead code => remove ExprStatNode as well.
            return None
        if expr.is_name:
            entry = expr.entry
            if entry and (entry.is_local or entry.is_arg):
                # Ignore dead references to local variables etc.
                return None
        return node
2145

2146
    def visit_CoerceToBooleanNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        When the argument (optionally wrapped in a ``PyTypeTestNode``) is
        a coercion to a generic Python object or to ``bool``, the coerced
        value is converted to a boolean directly instead.
        """
        self.visitchildren(node)
        inner = node.arg
        if isinstance(inner, ExprNodes.PyTypeTestNode):
            inner = inner.arg
        if (isinstance(inner, ExprNodes.CoerceToPyTypeNode)
                and inner.type in (PyrexTypes.py_object_type, Builtin.bool_type)):
            return inner.arg.coerce_to_boolean(self.current_env())
        return node
2157

2158
    # Signature of the '__Pyx_PyNumber_Float' call created in
    # visit_CoerceToPyTypeNode(): one Python object in, a Python object out.
    PyNumber_Float_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)],
    )
2162

2163
    def visit_CoerceToPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Handles a single-argument C-API call named 'float' below the
        coercion (optionally behind a ``CoerceFromPyTypeNode``):

        * an argument already typed as Python ``float`` is returned
          directly, guarded by a None check;
        * a Python object argument passed to '__Pyx_PyObject_AsDouble' is
          redirected to '__Pyx_PyNumber_Float' and coerced to the type of
          ``node``.

        Everything else is left unchanged.
        """
        self.visitchildren(node)
        call = node.arg
        if isinstance(call, ExprNodes.CoerceFromPyTypeNode):
            call = call.arg
        if not isinstance(call, ExprNodes.PythonCapiCallNode):
            return node
        if call.function.name != 'float' or len(call.args) != 1:
            return node

        # undo redundant Py->C->Py coercion
        func_arg = call.args[0]
        if func_arg.type is Builtin.float_type:
            return func_arg.as_none_safe_node(
                "float() argument must be a string or a number, not 'NoneType'")
        if func_arg.type.is_pyobject and call.function.cname == "__Pyx_PyObject_AsDouble":
            py_float_call = ExprNodes.PythonCapiCallNode(
                node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
                args=[func_arg],
                py_name='float',
                is_temp=node.is_temp,
                utility_code=UtilityCode.load_cached("pynumber_float", "TypeConversion.c"),
                result_is_used=node.result_is_used,
            )
            return py_float_call.coerce_to(node.type, self.current_env())
        return node
2185

2186
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.

        Exactly one of the following cases is tried, depending on the
        argument (after stripping a ``PyTypeTestNode``): a literal, a
        C->Py coercion, a simple call, or a subscript.  If the selected
        case does not apply, ``node`` is returned unchanged.
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            return arg.coerce_to(node.type, self.current_env()) if node.type != arg.type else arg

        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg

        if arg.is_literal:
            literal_fits = (
                node.type.is_int and isinstance(arg, ExprNodes.IntNode) or
                node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or
                node.type.is_int and isinstance(arg, ExprNodes.BoolNode))
            if literal_fits:
                return arg.coerce_to(node.type, self.current_env())
            return node

        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            c_value = arg.arg
            if arg.type is PyrexTypes.py_object_type:
                # completely redundant C->Py->C coercion
                is_redundant = node.type.assignable_from(c_value.type)
            elif arg.type is Builtin.unicode_type:
                is_redundant = c_value.type.is_unicode_char and node.type.is_unicode_char
            else:
                is_redundant = False
            if is_redundant:
                return c_value.coerce_to(node.type, self.current_env())
            return node

        if isinstance(arg, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
            return node

        if arg.is_subscript:
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node
2225

2226
    # Signature of '__Pyx_PyBytes_GetItemInt' as used by
    # _optimise_int_indexing(): (bytes, Py_ssize_t index, int check_bounds)
    # returning a C char, with "((char)-1)" declared as exception value
    # and the exception check enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type,
        [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
        ],
        exception_value="((char)-1)",
        exception_check=True,
    )
2234

2235
    def _optimise_int_indexing(self, coerce_node, arg, index_node):
        """Optimise ``bytes[index]`` that is coerced to a C char.

        When the subscripted base is typed as ``bytes`` and the coercion
        target is ``char`` or ``unsigned char``, the subscript is replaced
        by a call to '__Pyx_PyBytes_GetItemInt'.  The 'boundscheck'
        directive is passed on as a 0/1 argument.  In all other cases
        ``coerce_node`` is returned unchanged.
        """
        env = self.current_env()
        check_bounds = 1 if env.directives['boundscheck'] else 0
        if arg.base.type is not Builtin.bytes_type:
            return coerce_node
        target_type = coerce_node.type
        if target_type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
            return coerce_node

        # bytes[index] -> char
        check_bounds_node = ExprNodes.IntNode(
            coerce_node.pos, value=str(check_bounds),
            constant_result=check_bounds)
        call_args = [
            arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
            index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
            check_bounds_node,
        ]
        getitem_call = ExprNodes.PythonCapiCallNode(
            coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
            self.PyBytes_GetItemInt_func_type,
            args=call_args,
            is_temp=True,
            utility_code=UtilityCode.load_cached('bytes_index', 'StringTools.c'))
        if target_type is PyrexTypes.c_char_type:
            return getitem_call
        # the helper returns 'char', so convert for 'unsigned char' targets
        return getitem_call.coerce_to(target_type, env)
2259

2260
    # One "T f(T arg)" signature per C floating point type (float, double,
    # long double), keyed by that type.  Looked up by
    # _optimise_numeric_cast_call() for the trunc*() calls.
    float_float_func_types = {
        c_float: PyrexTypes.CFuncType(c_float, [PyrexTypes.CFuncTypeArg("arg", c_float, None)])
        for c_float in (
            PyrexTypes.c_float_type,
            PyrexTypes.c_double_type,
            PyrexTypes.c_longdouble_type,
        )
    }
2267

2268
    def _optimise_numeric_cast_call(self, node, arg):
        """Optimise ``int(x)`` / ``float(x)`` calls whose result is coerced to C.

        ``node`` is the coercion to a C numeric type and ``arg`` the call
        below it.  Only calls with exactly one argument are considered,
        and only when that argument is a C value (possibly wrapped in a
        ``CoerceToPyTypeNode``).  Depending on the source and target
        types, the call becomes the plain argument, a UCS4 digit
        conversion, a type cast or a ``trunc*()`` call.  If nothing
        applies, ``node`` is returned unchanged.
        """
        function = arg.function
        if isinstance(arg, ExprNodes.PythonCapiCallNode):
            call_args = arg.args
        elif (isinstance(function, ExprNodes.NameNode)
                and function.type.is_builtin_type
                and isinstance(arg.arg_tuple, ExprNodes.TupleNode)):
            call_args = arg.arg_tuple.args
        else:
            call_args = None

        if call_args is None or len(call_args) != 1:
            return node
        func_arg = call_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play it safe: Python conversion might work on all sorts of things
            return node

        src_type = func_arg.type
        dst_type = node.type
        ucs4_types = (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type)
        builtin_name = function.name

        if builtin_name == 'int':
            if src_type.is_int or dst_type.is_int:
                if src_type == dst_type:
                    return func_arg
                if src_type in ucs4_types:
                    # need to parse (<Py_UCS4>'1') as digit 1
                    return self._pyucs4_to_number(node, builtin_name, func_arg)
                if dst_type.assignable_from(src_type) or src_type.is_float:
                    return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=dst_type)
            elif src_type.is_float and dst_type.is_numeric:
                modifier = src_type.math_h_modifier
                # Work around missing Cygwin definition for the 'long double' variant.
                trunc_cname = '__Pyx_truncl' if modifier == 'l' else 'trunc' + modifier
                trunc_call = ExprNodes.PythonCapiCallNode(
                    node.pos, trunc_cname,
                    func_type=self.float_float_func_types[src_type],
                    args=[func_arg],
                    py_name='int',
                    is_temp=node.is_temp,
                    result_is_used=node.result_is_used,
                )
                return trunc_call.coerce_to(dst_type, self.current_env())
        elif builtin_name == 'float':
            if src_type.is_float or dst_type.is_float:
                if src_type == dst_type:
                    return func_arg
                if src_type in ucs4_types:
                    # need to parse (<Py_UCS4>'1') as digit 1
                    return self._pyucs4_to_number(node, builtin_name, func_arg)
                if dst_type.assignable_from(src_type) or src_type.is_float:
                    return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=dst_type)
        return node
2320

2321
    # Signatures of the UCS4 character -> number helpers selected by
    # _pyucs4_to_number(); both take a single Py_UCS4 argument.

    # int variant: returns a C int, -1 declared as exception value
    pyucs4_int_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type,
        [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None)],
        exception_value=-1,
    )

    # float variant: returns a C double, -1.0 declared as exception value
    pyucs4_double_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type,
        [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None)],
        exception_value=-1.0,
    )
2332

2333
    def _pyucs4_to_number(self, node, py_type_name, func_arg):
        """Build the C call that converts a UCS4 character into a number.

        ``py_type_name`` must be "int" or "float" and selects the C
        helper, its signature and its utility code.  The call takes
        ``func_arg`` as its only argument and is coerced to the type of
        ``node``.
        """
        assert py_type_name in ("int", "float")
        if py_type_name == "int":
            cname = "__Pyx_int_from_UCS4"
            func_type = self.pyucs4_int_func_type
            utility_name = "int_pyucs4"
        else:
            cname = "__Pyx_double_from_UCS4"
            func_type = self.pyucs4_double_func_type
            utility_name = "float_pyucs4"
        conversion_call = ExprNodes.PythonCapiCallNode(
            node.pos, cname,
            func_type=func_type,
            args=[func_arg],
            py_name=py_type_name,
            is_temp=node.is_temp,
            result_is_used=node.result_is_used,
            utility_code=UtilityCode.load_cached(utility_name, "Builtins.c"),
        )
        return conversion_call.coerce_to(node.type, self.current_env())
2344

2345
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call to ``function_name`` with a wrong number of arguments.

        ``expected`` may be None, an integer count or a descriptive string
        such as '0 or 1'.  It controls both the argument placeholder shown
        in the message ('' for none/zero, 'x' for exactly one, '...'
        otherwise) and the optional "expected ..., " part.  The error is
        reported at ``node.pos`` with the number of ``args`` found.
        """
        arg_str = ''
        if expected:  # neither None nor 0
            if isinstance(expected, str) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'
        expected_str = '' if expected is None else 'expected %s, ' % expected
        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)
2360

2361
    ### generic fallbacks
2362

2363
    def _handle_function(self, node, function_name, function, arg_list, kwargs):
2364
        return node
2365

2366
    def _handle_method(self, node, type_name, attr_name, function,
2367
                       arg_list, is_unbound_method, kwargs):
2368
        """
2369
        Try to inject C-API calls for unbound method calls to builtin types.
2370
        While the method declarations in Builtin.py already handle this, we
2371
        can additionally resolve bound and unbound methods here that were
2372
        assigned to variables ahead of time.
2373
        """
2374
        if kwargs:
2375
            return node
2376
        if not function or not function.is_attribute or not function.obj.is_name:
2377
            # cannot track unbound method calls over more than one indirection as
2378
            # the names might have been reassigned in the meantime
2379
            return node
2380
        type_entry = self.current_env().lookup(type_name)
2381
        if not type_entry:
2382
            return node
2383
        method = ExprNodes.AttributeNode(
2384
            node.function.pos,
2385
            obj=ExprNodes.NameNode(
2386
                function.pos,
2387
                name=type_name,
2388
                entry=type_entry,
2389
                type=type_entry.type),
2390
            attribute=attr_name,
2391
            is_called=True).analyse_as_type_attribute(self.current_env())
2392
        if method is None:
2393
            return self._optimise_generic_builtin_method_call(
2394
                node, attr_name, function, arg_list, is_unbound_method)
2395
        args = node.args
2396
        if args is None and node.arg_tuple:
2397
            args = node.arg_tuple.args
2398
        call_node = ExprNodes.SimpleCallNode(
2399
            node.pos,
2400
            function=method,
2401
            args=args)
2402
        if not is_unbound_method:
2403
            call_node.self = function.obj
2404
        call_node.analyse_c_function_call(self.current_env())
2405
        call_node.analysed = True
2406
        return call_node.coerce_to(node.type, self.current_env())
2407

2408
    ### builtin types
2409

2410
    def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
2411
        """
2412
        Try to inject an unbound method call for a call to a method of a known builtin type.
2413
        This enables caching the underlying C function of the method at runtime.
2414
        """
2415
        arg_count = len(arg_list)
2416
        if is_unbound_method or arg_count >= 3 or not (function.is_attribute and function.is_py_attr):
2417
            return node
2418
        if not function.obj.type.is_builtin_type:
2419
            return node
2420
        if function.obj.type is Builtin.type_type:
2421
            # allows different actual types => unsafe
2422
            return node
2423
        return ExprNodes.CachedBuiltinMethodCallNode(
2424
            node, function.obj, attr_name, arg_list)
2425

2426
    # Signature shared by the '__Pyx_PyObject_Unicode' and
    # '__Pyx_PyUnicode_Unicode' calls built in
    # _handle_simple_function_unicode(): any Python object in, unicode out.
    PyObject_Unicode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    )
2430

2431
    def _handle_simple_function_unicode(self, node, function, pos_args):
2432
        """Optimise single argument calls to unicode().
2433
        """
2434
        if len(pos_args) != 1:
2435
            if len(pos_args) == 0:
2436
                return ExprNodes.UnicodeNode(node.pos, value=EncodedString())
2437
            return node
2438
        arg = pos_args[0]
2439
        if arg.type is Builtin.unicode_type:
2440
            if not arg.may_be_none():
2441
                return arg
2442
            cname = "__Pyx_PyUnicode_Unicode"
2443
            utility_code = UtilityCode.load_cached('PyUnicode_Unicode', 'StringTools.c')
2444
        else:
2445
            cname = "__Pyx_PyObject_Unicode"
2446
            utility_code = UtilityCode.load_cached('PyObject_Unicode', 'StringTools.c')
2447
        return ExprNodes.PythonCapiCallNode(
2448
            node.pos, cname, self.PyObject_Unicode_func_type,
2449
            args=pos_args,
2450
            is_temp=node.is_temp,
2451
            utility_code=utility_code,
2452
            py_name="unicode")
2453

2454
    # Alias: the handler bound under the 'str' name is the very same
    # function object as the unicode() handler.
    _handle_simple_function_str = _handle_simple_function_unicode
2455

2456
    def visit_FormattedValueNode(self, node):
        """Simplify or avoid plain string formatting of a unicode value.

        This seems misplaced here, but plain unicode formatting is essentially
        a call to the unicode() builtin, which is optimised by
        _handle_simple_function_unicode().  It only applies when the value
        is typed as unicode, no format spec is given and the conversion
        character is either absent or 's'.
        """
        self.visitchildren(node)
        unformatted_unicode = (
            node.value.type is Builtin.unicode_type
            and not node.c_format_spec
            and not node.format_spec)
        if unformatted_unicode and (not node.conversion_char or node.conversion_char == 's'):
            # value is definitely a unicode string and we don't format it any special
            return self._handle_simple_function_unicode(node, None, [node.value])
        return node
2467

2468
    # Signature of the 'PyDict_Copy' call built in
    # _handle_simple_function_dict(): dict in, dict out.
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type,
        [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)],
    )
2472

2473
    def _handle_simple_function_dict(self, node, function, pos_args):
2474
        """Replace dict(some_dict) by PyDict_Copy(some_dict).
2475
        """
2476
        if len(pos_args) != 1:
2477
            return node
2478
        arg = pos_args[0]
2479
        if arg.type is Builtin.dict_type:
2480
            arg = arg.as_none_safe_node("'NoneType' is not iterable")
2481
            return ExprNodes.PythonCapiCallNode(
2482
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
2483
                args = [arg],
2484
                is_temp = node.is_temp
2485
                )
2486
        return node
2487

2488
    # Signature of the list construction calls built in
    # _handle_simple_function_list(): any Python object in, list out.
    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
        ])
2491

2492
    def _handle_simple_function_list(self, node, function, pos_args):
2493
        """Turn list(ob) into PySequence_List(ob).
2494
        """
2495
        if len(pos_args) != 1:
2496
            return node
2497
        arg = pos_args[0]
2498
        return ExprNodes.PythonCapiCallNode(
2499
            node.pos,
2500
            "__Pyx_PySequence_ListKeepNew"
2501
                if node.is_temp and arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type)
2502
                else "PySequence_List",
2503
            self.PySequence_List_func_type,
2504
            args=pos_args,
2505
            is_temp=node.is_temp,
2506
        )
2507

2508
    # Signature of the 'PyList_AsTuple' call built in
    # _handle_simple_function_tuple(): list in, tuple out.
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type,
        [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)],
    )
2512

2513
    def _handle_simple_function_tuple(self, node, function, pos_args):
2514
        """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.
2515
        """
2516
        if len(pos_args) != 1 or not node.is_temp:
2517
            return node
2518
        arg = pos_args[0]
2519
        if arg.type is Builtin.tuple_type and not arg.may_be_none():
2520
            return arg
2521
        if arg.type is Builtin.list_type:
2522
            pos_args[0] = arg.as_none_safe_node(
2523
                "'NoneType' object is not iterable")
2524

2525
            return ExprNodes.PythonCapiCallNode(
2526
                node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
2527
                args=pos_args, is_temp=node.is_temp)
2528
        else:
2529
            return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)
2530

2531
    # Signature of the 'PySet_New' call built in
    # _handle_simple_function_set(): any Python object in, set out.
    PySet_New_func_type = PyrexTypes.CFuncType(
        Builtin.set_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)],
    )
2535

2536
    def _handle_simple_function_set(self, node, function, pos_args):
2537
        if len(pos_args) != 1:
2538
            return node
2539
        if pos_args[0].is_sequence_constructor:
2540
            # We can optimise set([x,y,z]) safely into a set literal,
2541
            # but only if we create all items before adding them -
2542
            # adding an item may raise an exception if it is not
2543
            # hashable, but creating the later items may have
2544
            # side-effects.
2545
            args = []
2546
            temps = []
2547
            for arg in pos_args[0].args:
2548
                if not arg.is_simple():
2549
                    arg = UtilNodes.LetRefNode(arg)
2550
                    temps.append(arg)
2551
                args.append(arg)
2552
            result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
2553
            self.replace(node, result)
2554
            for temp in temps[::-1]:
2555
                result = UtilNodes.EvalWithTempExprNode(temp, result)
2556
            return result
2557
        else:
2558
            # PySet_New(it) is better than a generic Python call to set(it)
2559
            return self.replace(node, ExprNodes.PythonCapiCallNode(
2560
                node.pos, "PySet_New",
2561
                self.PySet_New_func_type,
2562
                args=pos_args,
2563
                is_temp=node.is_temp,
2564
                py_name="set"))
2565

2566
    # Signature of the '__Pyx_PyFrozenSet_New' call built in
    # _handle_simple_function_frozenset(): any Python object in, frozenset out.
    PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
        Builtin.frozenset_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)],
    )
2570

2571
    def _handle_simple_function_frozenset(self, node, function, pos_args):
2572
        if not pos_args:
2573
            pos_args = [ExprNodes.NullNode(node.pos)]
2574
        elif len(pos_args) > 1:
2575
            return node
2576
        elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
2577
            return pos_args[0]
2578
        # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
2579
        return ExprNodes.PythonCapiCallNode(
2580
            node.pos, "__Pyx_PyFrozenSet_New",
2581
            self.PyFrozenSet_New_func_type,
2582
            args=pos_args,
2583
            is_temp=node.is_temp,
2584
            utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
2585
            py_name="frozenset")
2586

2587
    # Signature shared by the *_AsDouble calls built in
    # _handle_simple_function_float(): a Python object in, a C double out,
    # with "((double)-1)" declared as exception value and the exception
    # check enabled.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value="((double)-1)",
        exception_check=True,
    )
2593

2594
    def _handle_simple_function_float(self, node, function, pos_args):
2595
        """Transform float() into either a C type cast or a faster C
2596
        function call.
2597
        """
2598
        # Note: this requires the float() function to be typed as
2599
        # returning a C 'double'
2600
        if len(pos_args) == 0:
2601
            return ExprNodes.FloatNode(
2602
                node, value="0.0", constant_result=0.0
2603
                ).coerce_to(Builtin.float_type, self.current_env())
2604
        elif len(pos_args) != 1:
2605
            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
2606
            return node
2607

2608
        func_arg = pos_args[0]
2609
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2610
            func_arg = func_arg.arg
2611
        if func_arg.type is PyrexTypes.c_double_type:
2612
            return func_arg
2613
        elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type):
2614
            # need to parse (<Py_UCS4>'1') as digit 1
2615
            return self._pyucs4_to_number(node, function.name, func_arg)
2616
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
2617
            return ExprNodes.TypecastNode(
2618
                node.pos, operand=func_arg, type=node.type)
2619

2620
        arg = pos_args[0].as_none_safe_node(
2621
            "float() argument must be a string or a number, not 'NoneType'")
2622

2623
        if func_arg.type is Builtin.bytes_type:
2624
            cfunc_name = "__Pyx_PyBytes_AsDouble"
2625
            utility_code_name = 'pybytes_as_double'
2626
        elif func_arg.type is Builtin.bytearray_type:
2627
            cfunc_name = "__Pyx_PyByteArray_AsDouble"
2628
            utility_code_name = 'pybytes_as_double'
2629
        elif func_arg.type is Builtin.unicode_type:
2630
            cfunc_name = "__Pyx_PyUnicode_AsDouble"
2631
            utility_code_name = 'pyunicode_as_double'
2632
        elif func_arg.type is Builtin.int_type:
2633
            cfunc_name = "PyLong_AsDouble"
2634
            utility_code_name = None
2635
        else:
2636
            arg = pos_args[0]  # no need for an additional None check
2637
            cfunc_name = "__Pyx_PyObject_AsDouble"
2638
            utility_code_name = 'pyobject_as_double'
2639

2640
        return ExprNodes.PythonCapiCallNode(
2641
            node.pos, cfunc_name,
2642
            self.PyObject_AsDouble_func_type,
2643
            args = [arg],
2644
            is_temp = node.is_temp,
2645
            utility_code = load_c_utility(utility_code_name) if utility_code_name else None,
2646
            py_name = "float")
2647

2648
    # Signatures of the C calls built in _handle_simple_function_int().

    # '__Pyx_PyNumber_Int': any Python object in, Python int out
    PyNumber_Int_func_type = PyrexTypes.CFuncType(
        Builtin.int_type,
        [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)],
    )

    # 'PyLong_FromDouble': C double in, Python int out
    PyLong_FromDouble_func_type = PyrexTypes.CFuncType(
        Builtin.int_type,
        [PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)],
    )
2657

2658
    def _handle_simple_function_int(self, node, function, pos_args):
2659
        """Transform int() into a faster C function call.
2660
        """
2661
        if len(pos_args) == 0:
2662
            return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
2663
                                     type=Builtin.int_type)
2664
        elif len(pos_args) != 1:
2665
            return node  # int(x, base)
2666
        func_arg = pos_args[0]
2667
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2668
            if func_arg.arg.type.is_float:
2669
                return ExprNodes.PythonCapiCallNode(
2670
                    node.pos, "PyLong_FromDouble", self.PyLong_FromDouble_func_type,
2671
                    args=[func_arg.arg], is_temp=True, py_name='int',
2672
                )
2673
            else:
2674
                return node  # handled in visit_CoerceFromPyTypeNode()
2675
        if func_arg.type.is_pyobject and node.type.is_pyobject:
2676
            return ExprNodes.PythonCapiCallNode(
2677
                node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
2678
                args=pos_args, is_temp=True, py_name='int')
2679
        return node
2680

2681
    def _handle_simple_function_bool(self, node, function, pos_args):
2682
        """Transform bool(x) into a type coercion to a boolean.
2683
        """
2684
        if len(pos_args) == 0:
2685
            return ExprNodes.BoolNode(
2686
                node.pos, value=False, constant_result=False
2687
                ).coerce_to(Builtin.bool_type, self.current_env())
2688
        elif len(pos_args) != 1:
2689
            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2690
            return node
2691
        else:
2692
            # => !!<bint>(x)  to make sure it's exactly 0 or 1
2693
            operand = pos_args[0].coerce_to_boolean(self.current_env())
2694
            operand = ExprNodes.NotNode(node.pos, operand = operand)
2695
            operand = ExprNodes.NotNode(node.pos, operand = operand)
2696
            # coerce back to Python object as that's the result we are expecting
2697
            return operand.coerce_to_pyobject(self.current_env())
2698

2699
    # Signatures of the C calls built in _handle_simple_function_memoryview().

    # 'PyMemoryView_FromObject': any Python object in, memoryview out
    PyMemoryView_FromObject_func_type = PyrexTypes.CFuncType(
        Builtin.memoryview_type,
        [PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None)],
    )

    # 'PyMemoryView_FromBuffer': argument declared with the Py_buffer type, memoryview out
    PyMemoryView_FromBuffer_func_type = PyrexTypes.CFuncType(
        Builtin.memoryview_type,
        [PyrexTypes.CFuncTypeArg("value", Builtin.py_buffer_type, None)],
    )
2708

2709
    def _handle_simple_function_memoryview(self, node, function, pos_args):
2710
        if len(pos_args) != 1:
2711
            self._error_wrong_arg_count('memoryview', node, pos_args, '1')
2712
            return node
2713
        else:
2714
            if pos_args[0].type.is_pyobject:
2715
                return ExprNodes.PythonCapiCallNode(
2716
                    node.pos, "PyMemoryView_FromObject",
2717
                    self.PyMemoryView_FromObject_func_type,
2718
                    args = [pos_args[0]],
2719
                    is_temp = node.is_temp,
2720
                    py_name = "memoryview")
2721
            elif pos_args[0].type.is_ptr and pos_args[0].base_type is Builtin.py_buffer_type:
2722
                # TODO - this currently doesn't work because the buffer fails a
2723
                # "can coerce to python object" test earlier. But it'd be nice to support
2724
                return ExprNodes.PythonCapiCallNode(
2725
                    node.pos, "PyMemoryView_FromBuffer",
2726
                    self.PyMemoryView_FromBuffer_func_type,
2727
                    args = [pos_args[0]],
2728
                    is_temp = node.is_temp,
2729
                    py_name = "memoryview")
2730
        return node
2731

2732

2733
    ### builtin functions
2734

2735
    # Length function signatures: each takes one argument and returns a
    # Py_ssize_t, with -1 declared as exception value.

    # argument: const char*
    Pyx_ssize_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)],
        exception_value=-1,
    )

    # argument: const Py_UNICODE*
    Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)],
        exception_value=-1,
    )

    # argument: any Python object
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value=-1,
    )

    # Lookup function (bound dict.get) from a builtin type to the name of
    # its C-level size function; yields None for types that are not listed.
    _map_to_capi_len_function = {
        Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type: "__Pyx_PyBytes_GET_SIZE",
        Builtin.bytearray_type: "__Pyx_PyByteArray_GET_SIZE",
        Builtin.list_type: "__Pyx_PyList_GET_SIZE",
        Builtin.tuple_type: "__Pyx_PyTuple_GET_SIZE",
        Builtin.set_type: "__Pyx_PySet_GET_SIZE",
        Builtin.frozenset_type: "__Pyx_PySet_GET_SIZE",
        Builtin.dict_type: "PyDict_Size",
    }.get

    # Qualified names of extension types; presumably those whose size can
    # be read directly at the C level - TODO confirm against the users of this set.
    _ext_types_with_pysize = {"cpython.array.array"}
2765

2766
    def _handle_simple_function_len(self, node, function, pos_args):
        """Translate len() on statically typed arguments into a direct C call.

        Depending on the type of the single argument, the call becomes:

        * C string            -> __Pyx_ssize_strlen()
        * Py_UNICODE pointer  -> __Pyx_Py_UNICODE_ssize_strlen()
        * memoryview slice    -> __Pyx_MemoryView_Len()
        * known builtin type  -> the function named by _map_to_capi_len_function
        * type listed in _ext_types_with_pysize -> Py_SIZE
        * single unicode character -> the integer constant 1

        Any other argument leaves the call untouched.  A wrong number of
        arguments is reported through _error_wrong_arg_count().
        """
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node

        operand = pos_args[0]
        if isinstance(operand, ExprNodes.CoerceToPyTypeNode):
            # Look through the coercion to a Python object at the original value.
            operand = operand.arg
        operand_type = operand.type

        if operand_type.is_string:
            len_call = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_ssize_strlen", self.Pyx_ssize_strlen_func_type,
                args=[operand],
                is_temp=node.is_temp)
        elif operand_type.is_pyunicode_ptr:
            len_call = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py_UNICODE_ssize_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
                args=[operand],
                is_temp=node.is_temp,
                utility_code=UtilityCode.load_cached("ssize_pyunicode_strlen", "StringTools.c"))
        elif operand_type.is_memoryviewslice:
            # The signature depends on the concrete slice type, so build it here.
            slice_len_sig = PyrexTypes.CFuncType(
                PyrexTypes.c_py_ssize_t_type,
                [PyrexTypes.CFuncTypeArg("memoryviewslice", operand_type, None)],
                nogil=True)
            len_call = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_MemoryView_Len", slice_len_sig,
                args=[operand], is_temp=node.is_temp)
        elif operand_type.is_pyobject:
            size_func = self._map_to_capi_len_function(operand_type)
            if size_func is None:
                has_py_size = (
                    (operand_type.is_extension_type or operand_type.is_builtin_type)
                    and operand_type.entry.qualified_name in self._ext_types_with_pysize)
                if not has_py_size:
                    return node
                size_func = 'Py_SIZE'
            # The direct C call does not check for None itself.
            operand = operand.as_none_safe_node(
                "object of type 'NoneType' has no len()")
            len_call = ExprNodes.PythonCapiCallNode(
                node.pos, size_func, self.PyObject_Size_func_type,
                args=[operand], is_temp=node.is_temp)
        elif operand_type.is_unicode_char:
            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                     type=node.type)
        else:
            return node

        if node.type in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
            return len_call
        # Convert the C size result to whatever type the original call had.
        return len_call.coerce_to(node.type, self.current_env())
2818

2819
    # Signature used for the "Py_TYPE" call emitted by
    # _handle_simple_function_type(): Python object in, type object out.
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type,
        [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)],
    )
2823

2824
    def _handle_simple_function_type(self, node, function, pos_args):
        """Replace the one-argument form type(o) by the macro call Py_TYPE(o).

        The macro call is created as a non-temp node and wrapped in a cast
        to a generic Python object.  Calls with any other number of
        arguments are returned unchanged.
        """
        if len(pos_args) == 1:
            py_type_call = ExprNodes.PythonCapiCallNode(
                node.pos, "Py_TYPE", self.Pyx_Type_func_type,
                args=pos_args,
                is_temp=False)
            return ExprNodes.CastNode(py_type_call, PyrexTypes.py_object_type)
        return node
2834

2835
    # Signature used for every type check call emitted by
    # _handle_simple_function_isinstance(): returns a C truth value.
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)],
    )
2839

2840
    def _handle_simple_function_isinstance(self, node, function, pos_args):
        """Replace isinstance() checks against builtin types by the
        corresponding C-API call.

        The second argument may be a single type or a tuple literal of
        types.  Each tested type is mapped to one check call:

        * a name that resolves to a builtin type -> that type's own check
          function (duplicates of the same check function are emitted once)
        * any other expression of type 'type'    -> __Pyx_TypeCheck()
        * everything else                        -> PyObject_IsInstance()

        The individual checks are joined with 'or' and coerced to the type
        of the original call.  Expressions that must only be evaluated once
        are bound to temporaries around the combined test.

        Returns the original node when there is nothing to optimise.
        """
        if len(pos_args) != 2:
            return node
        arg, types = pos_args
        temps = []
        if isinstance(types, ExprNodes.TupleNode):
            types = types.args
            if not types:
                # isinstance(x, ()): there is no test to generate, and the
                # reduce() below would fail on an empty list of tests.
                # Leave the generic call in place.
                return node
            if len(types) == 1 and not types[0].type is Builtin.type_type:
                return node  # nothing to improve here
            if arg.is_attribute or not arg.is_simple():
                # The tested object is used by several checks: evaluate it once.
                arg = UtilNodes.ResultRefNode(arg)
                temps.append(arg)
        elif types.type is Builtin.type_type:
            types = [types]
        else:
            return node

        tests = []
        test_nodes = []
        env = self.current_env()
        for test_type_node in types:
            builtin_type = None
            if test_type_node.is_name:
                if test_type_node.entry:
                    entry = env.lookup(test_type_node.entry.name)
                    if entry and entry.type and entry.type.is_builtin_type:
                        builtin_type = entry.type
            if builtin_type is Builtin.type_type:
                # all types have type "type", but there's only one 'type'
                if entry.name != 'type' or not (
                        entry.scope and entry.scope.is_builtin_scope):
                    builtin_type = None
            if builtin_type is not None:
                type_check_function = entry.type.type_check_function(exact=False)
                if type_check_function in tests:
                    # Same check already generated for an earlier tuple item.
                    continue
                tests.append(type_check_function)
                type_check_args = [arg]
            elif test_type_node.type is Builtin.type_type:
                type_check_function = '__Pyx_TypeCheck'
                type_check_args = [arg, test_type_node]
            else:
                if not test_type_node.is_literal:
                    test_type_node = UtilNodes.ResultRefNode(test_type_node)
                    temps.append(test_type_node)
                type_check_function = 'PyObject_IsInstance'
                type_check_args = [arg, test_type_node]
            test_nodes.append(
                ExprNodes.PythonCapiCallNode(
                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                    args=type_check_args,
                    is_temp=True,
                ))

        def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
            # Combine two checks into a C boolean 'or' expression.
            or_node = make_binop_node(node.pos, 'or', a, b)
            or_node.type = PyrexTypes.c_bint_type
            or_node.wrap_operands(env)
            return or_node

        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
        # Wrap in reverse order so that the first collected temp ends up outermost.
        for temp in temps[::-1]:
            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
        return test_node
2907

2908
    def _handle_simple_function_ord(self, node, function, pos_args):
        """Resolve ord() at compile time where the argument allows it.

        A unicode character value that was coerced to a Python object is
        cast to a C long instead.  A unicode or bytes literal of length 1
        is replaced by its integer constant.  Both results are coerced to
        the type of the original call.  Everything else is left unchanged.
        """
        if len(pos_args) != 1:
            return node

        operand = pos_args[0]
        if isinstance(operand, ExprNodes.CoerceToPyTypeNode):
            c_value = operand.arg
            if not c_value.type.is_unicode_char:
                return node
            as_long = ExprNodes.TypecastNode(
                operand.pos, operand=c_value, type=PyrexTypes.c_long_type)
            return as_long.coerce_to(node.type, self.current_env())

        if isinstance(operand, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
            if len(operand.value) != 1:
                return node
            char_code = ord(operand.value)
            int_literal = ExprNodes.IntNode(
                operand.pos, type=PyrexTypes.c_int_type,
                value=str(char_code),
                constant_result=char_code)
            return int_literal.coerce_to(node.type, self.current_env())

        return node
2927

2928
    ### special methods
2929

2930
    # Signature used for the "__Pyx_tp_new" call emitted by
    # _handle_any_slot__new__() when no keyword arguments are passed.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        ],
    )

    # Signature used for the "__Pyx_tp_new_kwargs" call, which additionally
    # receives the keyword argument dict.
    Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
            PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
        ],
    )
2942

2943
    def _handle_any_slot__new__(self, node, function, args,
                                is_unbound_method, kwargs=None):
        """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

        Only the simple case is handled: an unbound call where both the
        object that __new__ is looked up on and the first argument are
        plain names of type 'type' that refer to the same type.

        For an extension type of the current module that has a tp_new slot
        function, that function is called directly.  Otherwise, the call
        goes through the __Pyx_tp_new / __Pyx_tp_new_kwargs helpers.
        """
        receiver = function.obj
        if not (is_unbound_method and len(args) >= 1):
            return node
        type_arg = args[0]
        if not (receiver.is_name and type_arg.is_name):
            return node  # not a simple case
        if receiver.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            return node  # not a known type
        if type_arg.type_entry and receiver.type_entry:
            if type_arg.type_entry != receiver.type_entry:
                # different types - may or may not lead to an error at runtime
                return node
        elif receiver.name != type_arg.name:
            return node
        # At this point both sides are known to be the same type
        # (same type entry, or at least the same name).

        args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:]).analyse_types(
            self.current_env(), skip_children=True)

        if type_arg.type_entry:
            ext_type = type_arg.type_entry.type
            if (ext_type.is_extension_type and ext_type.typeobj_cname and
                    ext_type.scope.global_scope() == self.current_env().global_scope()):
                # known type in current module
                new_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
                slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, new_slot)
                if slot_func_cname:
                    cython_scope = self.context.cython_scope
                    PyTypeObjectPtr = PyrexTypes.CPtrType(
                        cython_scope.lookup('PyTypeObject').type)
                    direct_new_sig = PyrexTypes.CFuncType(
                        ext_type, [
                            PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
                            PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
                            PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                        ])

                    type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                    if not kwargs:
                        # The slot function always takes a kwargs argument: pass NULL.
                        kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, slot_func_cname,
                        direct_new_sig,
                        args=[type_arg, args_tuple, kwargs],
                        may_return_none=False,
                        is_temp=True)
        else:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        # Generic path through the helper functions.
        utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
        if kwargs:
            helper_name = "__Pyx_tp_new_kwargs"
            helper_sig = self.Pyx_tp_new_kwargs_func_type
            helper_args = [type_arg, args_tuple, kwargs]
        else:
            helper_name = "__Pyx_tp_new"
            helper_sig = self.Pyx_tp_new_func_type
            helper_args = [type_arg, args_tuple]
        return ExprNodes.PythonCapiCallNode(
            node.pos, helper_name, helper_sig,
            args=helper_args,
            utility_code=utility_code,
            is_temp=node.is_temp)
3015

3016
    def _handle_any_slot__class__(self, node, function, args,
                                  is_unbound_method, kwargs=None):
        """Leave calls to instance.__class__() unchanged.

        This handler exists so that such calls are not handled by the
        __Pyx_CallUnboundCMethod0 mechanism.
        """
        # TODO: optimizations of the instance.__class__() call might be possible in future.
        return node
3022

3023
    ### methods of builtin types
3024

3025
    # Signature shared by the append helpers ("__Pyx_PyObject_Append",
    # "__Pyx_PyList_Append", "__Pyx_ListComp_Append"): they return a C
    # return code with -1 declared as the exception value.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ],
        exception_value=-1,
    )
3031

3032
    def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
        """Optimistically route X.append(item) through __Pyx_PyObject_Append(),
        since X.append() almost always refers to a list.

        Applies only to calls with exactly one item whose result is unused
        and whose function node has no entry.
        """
        applicable = (
            len(args) == 2
            and not node.result_is_used
            and not node.function.entry)
        if not applicable:
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args=args,
            may_return_none=False,
            is_temp=node.is_temp,
            result_is_used=False,
            utility_code=load_c_utility('append'),
        )
3047

3048
    def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
        """Unroll list.extend([...]) with a short sequence literal into appends.

        This avoids building the intermediate sequence argument.  Literals
        with a multiplication factor or more than 8 items keep the normal
        call.  Extending by an empty literal reduces to the (wrapped) list
        object itself.
        """
        if len(args) != 2:
            return node
        obj, value = args
        if not value.is_sequence_constructor:
            return node
        items = list(value.args)
        if value.mult_factor is not None or len(items) > 8:
            # Appending wins for short sequences but slows down when multiple resize operations are needed.
            # This seems to be a good enough limit that avoids repeated resizing.
            if False and isinstance(value, ExprNodes.ListNode):
                # One would expect that tuples are more efficient here, but benchmarking with
                # Py3.5 and Py3.7 suggests that they are not. Probably worth revisiting at some point.
                # Might be related to the usage of PySequence_FAST() in CPython's list.extend(),
                # which is probably tuned more towards lists than tuples (and rightly so).
                tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True)
                Visitor.recursively_replace_node(node, args[1], tuple_node)
            return node

        target = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
        if not items:
            # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
            target.result_is_used = node.result_is_used
            return target

        # The list is referenced once per item, so bind it to a temp if needed.
        list_ref = target
        if len(items) > 1 and not target.is_simple():
            list_ref = UtilNodes.LetRefNode(target)

        temps = []

        def simple_or_temp(item):
            # Non-simple item expressions get bound to a temp that is collected in 'temps'.
            if item.is_simple():
                return item
            item_ref = UtilNodes.LetRefNode(item)
            temps.append(item_ref)
            return item_ref

        # Use ListComp_Append() for all but the last item and finish with PyList_Append()
        # to shrink the list storage size at the very end if necessary.
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
            args=[list_ref, simple_or_temp(items[-1])],
            is_temp=True,
            utility_code=load_c_utility("ListAppend"))
        # Walk the remaining items back to front, putting each append call
        # in front of the calls built so far.
        for item in reversed(items[:-1]):
            item = simple_or_temp(item)
            new_node = ExprNodes.binop_node(
                node.pos, '|',
                ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
                    args=[list_ref, item], py_name="extend",
                    is_temp=True,
                    utility_code=load_c_utility("ListCompAppend")),
                new_node,
                type=PyrexTypes.c_returncode_type,
            )
        new_node.result_is_used = node.result_is_used

        if list_ref is not target:
            temps.append(list_ref)
        # Each temp wraps the expression built so far, so the temp collected
        # last (the list itself, if it needed one) becomes the outermost node.
        for temp in temps:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
            new_node.result_is_used = node.result_is_used
        return new_node
3110

3111
    # Signature used for "__Pyx_PyByteArray_Append": the appended value is a C int.
    PyByteArray_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
        ],
        exception_value=-1,
    )

    # Signature used for "__Pyx_PyByteArray_AppendObject": the appended value
    # is a Python object.
    PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
        ],
        exception_value=-1,
    )
3124

3125
    def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
        """Replace bytearray.append(x) by a direct helper call.

        Integer values and string literals that can be coerced to a char
        literal go through __Pyx_PyByteArray_Append() as C values.  Other
        Python objects go through __Pyx_PyByteArray_AppendObject().  All
        remaining cases keep the normal method call.
        """
        if len(args) != 2:
            return node

        value = unwrap_coerced_node(args[1])
        if value.type.is_int or isinstance(value, ExprNodes.IntNode):
            value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
            helper_name = "__Pyx_PyByteArray_Append"
            helper_sig = self.PyByteArray_Append_func_type
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.is_string_literal:
            if not value.can_coerce_to_char_literal():
                return node
            value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
            helper_name = "__Pyx_PyByteArray_Append"
            helper_sig = self.PyByteArray_Append_func_type
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.type.is_pyobject:
            helper_name = "__Pyx_PyByteArray_AppendObject"
            helper_sig = self.PyByteArray_AppendObject_func_type
            utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
        else:
            return node

        append_call = ExprNodes.PythonCapiCallNode(
            node.pos, helper_name, helper_sig,
            args=[args[0], value],
            may_return_none=False,
            is_temp=node.is_temp,
            utility_code=utility_code,
        )
        if not node.result_is_used:
            return append_call
        # The result is used, so it must have the type of the original call.
        return append_call.coerce_to(node.type, self.current_env())
3157

3158
    # Signature used for the "__Pyx_Py{List,Object}_Pop" calls (pop without index).
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)],
    )

    # Signature used for the "__Pyx_Py{List,Object}_PopIndex" calls.
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
        ],
        has_varargs=True,  # to fake the additional macro args that lack a proper C type
    )
3171

3172
    def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
        """Handle list.pop() through the generic pop handler with is_list=True."""
        return self._handle_simple_method_object_pop(
            node, function, args, is_unbound_method,
            is_list=True)
3175

3176
    def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
        """Optimistically translate X.pop() and X.pop(n) into helper calls,
        since X.pop([n]) almost always refers to a list.

        With is_list=True the object is known to be a list: the
        '__Pyx_PyList_*' helpers are used and the object gets a None check.
        Otherwise the '__Pyx_PyObject_*' helpers are used.

        Returns the original node when the call cannot be translated,
        e.g. for an index that does not fit into Py_ssize_t or for more
        than one argument to pop().
        """
        if not args:
            return node

        container = args[0]
        if is_list:
            helper_kind = 'List'
            container = container.as_none_safe_node(
                "'NoneType' object has no attribute '%.30s'",
                error="PyExc_AttributeError",
                format_args=['pop'])
        else:
            helper_kind = 'Object'

        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py%s_Pop" % helper_kind,
                self.PyObject_Pop_func_type,
                args=[container],
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility('pop'),
            )
        if len(args) != 2:
            return node

        # pop(index): pass the index both as Python object (or None) and as C value.
        index = unwrap_coerced_node(args[1])
        py_index = ExprNodes.NoneNode(index.pos)
        declared_type = index.type
        is_int_literal = isinstance(index, ExprNodes.IntNode)

        if declared_type.is_int:
            if not PyrexTypes.numeric_type_fits(declared_type, PyrexTypes.c_py_ssize_t_type):
                return node
            if is_int_literal:
                py_index = index.coerce_to_pyobject(self.current_env())
        elif is_int_literal:
            py_index = index.coerce_to_pyobject(self.current_env())
            index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif is_list:
            if declared_type.is_pyobject:
                py_index = index.coerce_to_simple(self.current_env())
                index = ExprNodes.CloneNode(py_index)
            index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        else:
            return node

        # real type might still be larger at runtime
        c_index_type = declared_type if declared_type.is_int else index.type
        if not c_index_type.create_to_py_utility_code(self.current_env()):
            return node

        to_py_func = c_index_type.to_py_function
        to_py_sig = PyrexTypes.CFuncType(
            PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", c_index_type, None)])
        signed_flag = 1 if c_index_type.signed else 0

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_PopIndex" % helper_kind,
            self.PyObject_PopIndex_func_type,
            args=[
                container, py_index, index,
                ExprNodes.IntNode(index.pos, value=str(signed_flag),
                                  constant_result=signed_flag,
                                  type=PyrexTypes.c_int_type),
                # Extra arguments: the C type of the index and its conversion function.
                ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
                                           c_index_type.empty_declaration_code()),
                ExprNodes.RawCNameExprNode(index.pos, to_py_sig, to_py_func),
            ],
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility("pop_index"),
        )
3243

3244
    # Signature for C functions that take a single Python object and return
    # a C return code, with -1 declared as the exception value.
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value=-1,
    )
3249

3250
    def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
3251
        """Call PyList_Sort() instead of the 0-argument l.sort().
3252
        """
3253
        if len(args) != 1:
3254
            return node
3255
        return self._substitute_method_call(
3256
            node, function, "PyList_Sort", self.single_param_func_type,
3257
            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
3258

3259
    # Signature used for the "__Pyx_PyDict_GetItemDefault" call emitted by
    # _handle_simple_method_dict_get().
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ],
    )
3265

3266
    def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
        """Replace dict.get() by a call to __Pyx_PyDict_GetItemDefault().

        A missing default is filled in as None (appended to 'args' in
        place).  A wrong number of arguments is reported and leaves the
        call unchanged.
        """
        arg_count = len(args)
        if arg_count not in (2, 3):
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node
        if arg_count == 2:
            # dict.get(key) is the same as dict.get(key, None)
            args.append(ExprNodes.NoneNode(node.pos))

        return self._substitute_method_call(
            node, function,
            "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            may_return_none=True,
            utility_code=load_c_utility("dict_getitem_default"))
3281

3282
    # Signature used for the "__Pyx_PyDict_SetDefault" call emitted by
    # _handle_simple_method_dict_setdefault().  The trailing C int carries
    # the "is_safe_type" flag computed there.
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
        ],
    )
3289

3290
    def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
        """Replace dict.setdefault() by a call to __Pyx_PyDict_SetDefault().

        A missing default is filled in as None.  An additional integer
        argument tells the helper what is known about the key type:
        1 for the builtin types named below, -1 if the key is a generic
        Python object (unknown), 0 otherwise.  Both additions are appended
        to 'args' in place.

        A wrong number of arguments is reported and leaves the call unchanged.
        """
        if len(args) == 2:
            args.append(ExprNodes.NoneNode(node.pos))
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
            return node
        key_type = args[1].type
        if key_type.is_builtin_type:
            # Match whole type names only.  Testing against one space separated
            # string would be a substring test and also accept names such as
            # '', 's' or 'in'.
            is_safe_type = int(key_type.name in (
                'str', 'bytes', 'unicode', 'float', 'int', 'long', 'bool'))
        elif key_type is PyrexTypes.py_object_type:
            is_safe_type = -1  # don't know
        else:
            is_safe_type = 0   # definitely not
        args.append(ExprNodes.IntNode(
            node.pos, value=str(is_safe_type), constant_result=is_safe_type))

        return self._substitute_method_call(
            node, function,
            "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
            'setdefault', is_unbound_method, args,
            may_return_none=True,
            utility_code=load_c_utility('dict_setdefault'))
3315

3316
    # Signature used for the "__Pyx_PyDict_Pop" call, which returns the popped value.
    PyDict_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ],
    )

    # Signature used for the "__Pyx_PyDict_Pop_ignore" call, which is used when
    # the result of the pop is not needed and returns a C int instead.
    PyDict_Pop_ignore_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ],
        exception_value=PyrexTypes.c_int_type.exception_value,
    )
3331

3332
    def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
        """Replace dict.pop() by a call to __Pyx_PyDict_Pop().

        Without a default, NULL is passed as default (appended to 'args' in
        place).  With a default and an unused result, the call goes to
        __Pyx_PyDict_Pop_ignore() instead.  A wrong number of arguments is
        reported and leaves the call unchanged.
        """
        arg_count = len(args)
        if arg_count not in (2, 3):
            self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
            return node

        helper_name = "__Pyx_PyDict_Pop"
        helper_utility = 'py_dict_pop'
        helper_sig = self.PyDict_Pop_func_type

        if arg_count == 2:
            args.append(ExprNodes.NullNode(node.pos))
        elif not node.result_is_used:
            # special case: we can ignore the default value
            helper_name = "__Pyx_PyDict_Pop_ignore"
            helper_utility = 'py_dict_pop_ignore'
            helper_sig = self.PyDict_Pop_ignore_func_type

        return self._substitute_method_call(
            node, function,
            helper_name, helper_sig,
            'pop', is_unbound_method, args,
            may_return_none=True,
            utility_code=load_c_utility(helper_utility))
3357

3358
    # Table of C signatures keyed by (type of the C constant, return type).
    # One entry per combination of {C long, C double} constant and
    # {Python object, C boolean} result.  Signatures that do not return a
    # Python object use the exception value of their return type.
    Pyx_BinopInt_func_types = {
        (ctype, ret_type): PyrexTypes.CFuncType(
            ret_type,
            [
                PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
                PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
                PyrexTypes.CFuncTypeArg("cval", ctype, None),
                PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
                PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None),
            ],
            exception_value=None if ret_type.is_pyobject else ret_type.exception_value,
        )
        for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type)
        for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type)
    }
3370

3371
    def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() with the operator name 'Add'."""
        return self._optimise_num_binop(
            'Add', node, function, args, is_unbound_method)
3373

3374
    def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() with the operator name 'Subtract'."""
        return self._optimise_num_binop(
            'Subtract', node, function, args, is_unbound_method)
3376

3377
    def _handle_simple_method_object___mul__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Multiply' operator."""
        optimise = self._optimise_num_binop
        return optimise('Multiply', node, function, args, is_unbound_method)
3379

3380
    def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Eq' operator."""
        optimise = self._optimise_num_binop
        return optimise('Eq', node, function, args, is_unbound_method)
3382

3383
    def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Ne' operator."""
        optimise = self._optimise_num_binop
        return optimise('Ne', node, function, args, is_unbound_method)
3385

3386
    def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'And' operator."""
        optimise = self._optimise_num_binop
        return optimise('And', node, function, args, is_unbound_method)
3388

3389
    def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Or' operator."""
        optimise = self._optimise_num_binop
        return optimise('Or', node, function, args, is_unbound_method)
3391

3392
    def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Xor' operator."""
        optimise = self._optimise_num_binop
        return optimise('Xor', node, function, args, is_unbound_method)
3394

3395
    def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
3396
        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
3397
            return node
3398
        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
3399
            return node
3400
        return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)
3401

3402
    def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
3403
        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
3404
            return node
3405
        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
3406
            return node
3407
        return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)
3408

3409
    def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_div() using the 'Remainder' operator."""
        optimise = self._optimise_num_div
        return optimise('Remainder', node, function, args, is_unbound_method)
3411

3412
    def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_div() using the 'FloorDivide' operator."""
        optimise = self._optimise_num_div
        return optimise('FloorDivide', node, function, args, is_unbound_method)
3414

3415
    def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_div() using the 'TrueDivide' operator."""
        optimise = self._optimise_num_div
        return optimise('TrueDivide', node, function, args, is_unbound_method)
3417

3418
    def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_div() using the 'Divide' operator."""
        optimise = self._optimise_num_div
        return optimise('Divide', node, function, args, is_unbound_method)
3420

3421
    # The 'int' handlers are the very same functions as the corresponding
    # 'object' handlers.
    # -- arithmetic
    _handle_simple_method_int___add__ = _handle_simple_method_object___add__
    _handle_simple_method_int___sub__ = _handle_simple_method_object___sub__
    _handle_simple_method_int___mul__ = _handle_simple_method_object___mul__
    # -- comparison
    _handle_simple_method_int___eq__ = _handle_simple_method_object___eq__
    _handle_simple_method_int___ne__ = _handle_simple_method_object___ne__
    # -- bitwise operators and shifts
    _handle_simple_method_int___and__ = _handle_simple_method_object___and__
    _handle_simple_method_int___or__ = _handle_simple_method_object___or__
    _handle_simple_method_int___xor__ = _handle_simple_method_object___xor__
    _handle_simple_method_int___rshift__ = _handle_simple_method_object___rshift__
    _handle_simple_method_int___lshift__ = _handle_simple_method_object___lshift__
    # -- division
    _handle_simple_method_int___mod__ = _handle_simple_method_object___mod__
    _handle_simple_method_int___floordiv__ = _handle_simple_method_object___floordiv__
    _handle_simple_method_int___truediv__ = _handle_simple_method_object___truediv__
3434

3435
    def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
3436
        if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
3437
            return node
3438
        if isinstance(args[1], ExprNodes.IntNode):
3439
            if not (-2**30 <= args[1].constant_result <= 2**30):
3440
                return node
3441
        elif isinstance(args[1], ExprNodes.FloatNode):
3442
            if not (-2**53 <= args[1].constant_result <= 2**53):
3443
                return node
3444
        else:
3445
            return node
3446
        return self._optimise_num_binop(operator, node, function, args, is_unbound_method)
3447

3448
    def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Add' operator."""
        optimise = self._optimise_num_binop
        return optimise('Add', node, function, args, is_unbound_method)
3450

3451
    def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Subtract' operator."""
        optimise = self._optimise_num_binop
        return optimise('Subtract', node, function, args, is_unbound_method)
3453

3454
    def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'TrueDivide' operator."""
        optimise = self._optimise_num_binop
        return optimise('TrueDivide', node, function, args, is_unbound_method)
3456

3457
    def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Divide' operator."""
        optimise = self._optimise_num_binop
        return optimise('Divide', node, function, args, is_unbound_method)
3459

3460
    def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Remainder' operator."""
        optimise = self._optimise_num_binop
        return optimise('Remainder', node, function, args, is_unbound_method)
3462

3463
    def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Eq' operator."""
        optimise = self._optimise_num_binop
        return optimise('Eq', node, function, args, is_unbound_method)
3465

3466
    def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method):
        """Delegate to _optimise_num_binop() using the 'Ne' operator."""
        optimise = self._optimise_num_binop
        return optimise('Ne', node, function, args, is_unbound_method)
3468

3469
    def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
3470
        """
3471
        Optimise math operators for (likely) float or small integer operations.
3472
        """
3473
        if getattr(node, "special_bool_cmp_function", None):
3474
            return node  # already optimized
3475

3476
        if len(args) != 2:
3477
            return node
3478

3479
        if node.type.is_pyobject:
3480
            ret_type = PyrexTypes.py_object_type
3481
        elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'):
3482
            ret_type = PyrexTypes.c_bint_type
3483
        else:
3484
            return node
3485

3486
        result = optimise_numeric_binop(operator, node, ret_type, args[0], args[1])
3487
        if not result:
3488
            return node
3489
        func_cname, utility_code, extra_args, num_type = result
3490
        assert all([arg.type.is_pyobject for arg in args])
3491
        args = list(args) + extra_args
3492

3493
        call_node = self._substitute_method_call(
3494
            node, function,
3495
            func_cname,
3496
            self.Pyx_BinopInt_func_types[(num_type, ret_type)],
3497
            '__%s__' % operator[:3].lower(), is_unbound_method, args,
3498
            may_return_none=True,
3499
            with_none_check=False,
3500
            utility_code=utility_code)
3501

3502
        if node.type.is_pyobject and not ret_type.is_pyobject:
3503
            call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type)
3504
        return call_node
3505

3506
    ### unicode type methods
3507

3508
    # C signature used by _inject_unicode_predicate(): takes one argument of
    # type c_py_ucs4_type and returns a C bint.
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)],
    )
3512

3513
    def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
3514
        if is_unbound_method or len(args) != 1:
3515
            return node
3516
        ustring = args[0]
3517
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
3518
               not ustring.arg.type.is_unicode_char:
3519
            return node
3520
        uchar = ustring.arg
3521
        method_name = function.attribute
3522
        if method_name in ('istitle', 'isprintable'):
3523
            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
3524
            # isprintable() is lacking C-API support in PyPy
3525
            utility_code = UtilityCode.load_cached(
3526
                "py_unicode_%s" % method_name, "StringTools.c")
3527
            function_name = '__Pyx_Py_UNICODE_%s' % method_name.upper()
3528
        else:
3529
            utility_code = None
3530
            function_name = 'Py_UNICODE_%s' % method_name.upper()
3531
        func_call = self._substitute_method_call(
3532
            node, function,
3533
            function_name, self.PyUnicode_uchar_predicate_func_type,
3534
            method_name, is_unbound_method, [uchar],
3535
            utility_code = utility_code)
3536
        if node.type.is_pyobject:
3537
            func_call = func_call.coerce_to_pyobject(self.current_env)
3538
        return func_call
3539

3540
    # All of these predicate methods are handled by _inject_unicode_predicate().
    _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
    _handle_simple_method_unicode_islower = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper = _inject_unicode_predicate
    _handle_simple_method_unicode_isprintable = _inject_unicode_predicate
3550

3551
    # C signature mapping one c_py_ucs4_type value to another.  Within this part
    # of the class it is only referenced by the disabled character conversion code.
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)],
    )
3555

3556
    # DISABLED: Return value can only be one character, which is not correct.
3557
    '''
3558
    def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
3559
        if is_unbound_method or len(args) != 1:
3560
            return node
3561
        ustring = args[0]
3562
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
3563
               not ustring.arg.type.is_unicode_char:
3564
            return node
3565
        uchar = ustring.arg
3566
        method_name = function.attribute
3567
        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
3568
        func_call = self._substitute_method_call(
3569
            node, function,
3570
            function_name, self.PyUnicode_uchar_conversion_func_type,
3571
            method_name, is_unbound_method, [uchar])
3572
        if node.type.is_pyobject:
3573
            func_call = func_call.coerce_to_pyobject(self.current_env)
3574
        return func_call
3575

3576
    #_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
3577
    #_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
3578
    #_handle_simple_method_unicode_title = _inject_unicode_character_conversion
3579
    '''
3580

3581
    # Signature used for the "PyUnicode_Splitlines" call: (str, keepends) -> list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
        ],
    )
3586

3587
    def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
        """Turn unicode.splitlines(...) into a direct "PyUnicode_Splitlines" call.

        Reports a wrong argument count and returns 'node' unchanged unless one
        or two arguments are given.  A missing 'keepends' argument defaults to False.
        """
        if not 1 <= len(args) <= 2:
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node

        # Fill in the default for 'keepends' if it was not passed.
        self._inject_bint_default_argument(node, args, 1, False)

        func_type = self.PyUnicode_Splitlines_func_type
        return self._substitute_method_call(
            node, function, "PyUnicode_Splitlines", func_type,
            'splitlines', is_unbound_method, args)
3600

3601
    # Signature used for the "PyUnicode_Split" call: (str, sep, maxsplit) -> list.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
        ],
    )
3608

3609
    def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
        """Turn unicode.split(...) into a direct "PyUnicode_Split" call.

        Reports a wrong argument count and returns 'node' unchanged unless one
        to three arguments are given.  A missing separator becomes a NullNode
        and a missing 'maxsplit' defaults to -1.
        """
        arg_count = len(args)
        if arg_count not in (1, 2, 3):
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node

        if arg_count == 1:
            # no separator given
            args.append(ExprNodes.NullNode(node.pos))
        else:
            self._inject_null_for_none(args, 1)
        self._inject_int_default_argument(
            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")

        func_type = self.PyUnicode_Split_func_type
        return self._substitute_method_call(
            node, function, "PyUnicode_Split", func_type,
            'split', is_unbound_method, args)
3627

3628
    # Signature used for the "PyUnicode_Join" call: (str, seq) -> unicode.
    PyUnicode_Join_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
        ],
    )
3633

3634
    def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
        """
        Turn unicode.join(...) into a direct "PyUnicode_Join" call.

        If the joined sequence is a generator expression with yield statements,
        it is first rewritten into an inlined list comprehension by replacing
        each yield statement with an append to the comprehension target.
        """
        if len(args) != 2:
            self._error_wrong_arg_count('unicode.join', node, args, "2")
            return node

        gen_expr_node = args[1]
        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_statements = _find_yield_statements(gen_expr_node.loop)
            if yield_statements:
                inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
                    node.pos, gen_expr_node, orig_func='list',
                    comprehension_type=Builtin.list_type)

                # Swap every yield statement for an append to the list being built.
                for yield_expression, yield_stat_node in yield_statements:
                    Visitor.recursively_replace_node(
                        gen_expr_node,
                        yield_stat_node,
                        ExprNodes.ComprehensionAppendNode(
                            yield_expression.pos,
                            expr=yield_expression,
                            target=inlined_genexpr.target),
                    )

                args[1] = inlined_genexpr

        func_type = self.PyUnicode_Join_func_type
        return self._substitute_method_call(
            node, function, "PyUnicode_Join", func_type,
            'join', is_unbound_method, args)
3665

3666
    # Signature used by _inject_tailmatch():
    # (str, substring, start, end, direction) -> C bint, with -1 as exception value.
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value=-1,
    )
3675

3676
    def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
        """Delegate to _inject_tailmatch() for 'str'/'endswith' with direction +1."""
        direction = +1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'str', 'endswith', unicode_tailmatch_utility_code, direction)
3680

3681
    def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
        """Delegate to _inject_tailmatch() for 'str'/'startswith' with direction -1."""
        direction = -1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'str', 'startswith', unicode_tailmatch_utility_code, direction)
3685

3686
    def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                          method_name, utility_code, direction):
        """Turn a startswith()/endswith() call into a direct C helper call.

        Reports a wrong argument count and returns 'node' unchanged unless two
        to four arguments are given.  Missing 'start'/'end' arguments default to
        0 and PY_SSIZE_T_MAX, and 'direction' is appended as a C int constant.
        The result of the C call is coerced to the builtin bool type.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count(f"{type_name}.{method_name}", node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 2, ssize_t_type, "0")
        self._inject_int_default_argument(node, args, 3, ssize_t_type, "PY_SSIZE_T_MAX")
        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        # 'str' has its own helper name, other types derive it from the type name.
        func_name = (
            "__Pyx_PyUnicode_Tailmatch" if type_name == 'str'
            else f"__Pyx_Py{type_name.capitalize()}_Tailmatch"
        )

        method_call = self._substitute_method_call(
            node, function,
            func_name, self.PyString_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code=utility_code)
        env = self.current_env()
        return method_call.coerce_to(Builtin.bool_type, env)
3712

3713
    # Signature used for the "PyUnicode_Find" call:
    # (str, substring, start, end, direction) -> Py_ssize_t, with -2 as exception value.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value=-2,
    )
3722

3723
    def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
        """Delegate to _inject_unicode_find() for 'find' with direction +1."""
        direction = +1
        return self._inject_unicode_find(
            node, function, args, is_unbound_method, 'find', direction)
3726

3727
    def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
        """Delegate to _inject_unicode_find() for 'rfind' with direction -1."""
        direction = -1
        return self._inject_unicode_find(
            node, function, args, is_unbound_method, 'rfind', direction)
3730

3731
    def _inject_unicode_find(self, node, function, args, is_unbound_method,
                             method_name, direction):
        """Turn unicode.find(...) / unicode.rfind(...) into a "PyUnicode_Find" call.

        Reports a wrong argument count and returns 'node' unchanged unless two
        to four arguments are given.  Missing 'start'/'end' arguments default to
        0 and PY_SSIZE_T_MAX, and 'direction' is appended as a C int constant.
        The C result is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 2, ssize_t_type, "0")
        self._inject_int_default_argument(node, args, 3, ssize_t_type, "PY_SSIZE_T_MAX")
        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        method_call = self._substitute_method_call(
            node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        env = self.current_env()
        return method_call.coerce_to_pyobject(env)
3750

3751
    # Signature used for the "PyUnicode_Count" call:
    # (str, substring, start, end) -> Py_ssize_t, with -1 as exception value.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        ],
        exception_value=-1,
    )
3759

3760
    def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
        """Turn unicode.count(...) into a direct "PyUnicode_Count" call.

        Reports a wrong argument count and returns 'node' unchanged unless two
        to four arguments are given.  Missing 'start'/'end' arguments default to
        0 and PY_SSIZE_T_MAX.  The C result is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 2, ssize_t_type, "0")
        self._inject_int_default_argument(node, args, 3, ssize_t_type, "PY_SSIZE_T_MAX")

        method_call = self._substitute_method_call(
            node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        env = self.current_env()
        return method_call.coerce_to_pyobject(env)
3776

3777
    # Signature used for the "PyUnicode_Replace" call:
    # (str, substring, replstr, maxcount) -> unicode.
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
        ],
    )
3784

3785
    def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
        """Turn unicode.replace(...) into a direct "PyUnicode_Replace" call.

        Reports a wrong argument count and returns 'node' unchanged unless three
        or four arguments are given.  A missing 'maxcount' defaults to -1.
        """
        if not 3 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node

        self._inject_int_default_argument(
            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")

        func_type = self.PyUnicode_Replace_func_type
        return self._substitute_method_call(
            node, function, "PyUnicode_Replace", func_type,
            'replace', is_unbound_method, args)
3798

3799
    # Signature used for the "PyUnicode_AsEncodedString" call:
    # (obj, encoding, errors) -> bytes.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
        ],
    )

    # Signature used for the codec specific "PyUnicode_As...String" calls:
    # (obj) -> bytes.
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)],
    )
3810

3811
    _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
3812
                          'unicode_escape', 'raw_unicode_escape']
3813

3814
    _special_codecs = [ (name, codecs.getencoder(name))
3815
                        for name in _special_encodings ]
3816

3817
    def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        Reports a wrong argument count and returns 'node' unchanged unless one
        to three arguments are given, and also returns 'node' unchanged when
        the encoding / error mode arguments cannot be unpacked.  A constant
        string is encoded at compile time when possible and replaced by a
        BytesNode.  Otherwise the call becomes either a codec specific
        "PyUnicode_As...String" call (strict error handling, known codec) or a
        generic "PyUnicode_AsEncodedString" call.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if string_node.has_constant_result():
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.constant_result.encode(encoding, error_handling)
            except Exception:
                # Best effort only: fall back to encoding at runtime.  Catching
                # 'Exception' rather than everything lets KeyboardInterrupt and
                # SystemExit propagate.
                pass
            else:
                value = bytes_literal(value, encoding or 'UTF-8')
                return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)

        if len(args) == 1:
            # No encoding and no error mode given: pass NULL for both.
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, function, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        if encoding and error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None and '-' not in codec_name:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, function, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])
3865

3866
    # Pointer to a C decoder function: (string, size, errors) -> unicode.
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(
        PyrexTypes.CFuncType(
            Builtin.unicode_type,
            [
                PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
                PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
                PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ],
        )
    )

    # Decode helper signature whose 'string' argument is a C 'const char*'.
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ],
    )

    # Decode helper signature whose 'string' argument is a Python object.
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ],
    )

    # Created on first use by _handle_simple_method_bytes_decode().
    _decode_cpp_string_func_type = None  # lazy init
3894

3895
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
3896
        """Replace char*.decode() by a direct C-API call to the
3897
        corresponding codec, possibly resolving a slice on the char*.
3898
        """
3899
        if not (1 <= len(args) <= 3):
3900
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
3901
            return node
3902

3903
        # Try to extract encoding parameters and attempt constant decode.
3904
        string_node = args[0]
3905
        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
3906
        if parameters is None:
3907
            return node
3908
        encoding, encoding_node, error_handling, error_handling_node = parameters
3909

3910
        if string_node.has_constant_result():
3911
            try:
3912
                constant_result = string_node.constant_result.decode(encoding, error_handling)
3913
            except (AttributeError, ValueError, UnicodeDecodeError):
3914
                pass
3915
            else:
3916
                return UnicodeNode(
3917
                    string_node.pos,
3918
                    value=EncodedString(constant_result),
3919
                    bytes_value=string_node.constant_result,
3920
                )
3921

3922
        # normalise input nodes
3923
        start = stop = None
3924
        if isinstance(string_node, ExprNodes.SliceIndexNode):
3925
            index_node = string_node
3926
            string_node = index_node.base
3927
            start, stop = index_node.start, index_node.stop
3928
            if not start or start.constant_result == 0:
3929
                start = None
3930
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
3931
            string_node = string_node.arg
3932

3933
        string_type = string_node.type
3934
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
3935
            if is_unbound_method:
3936
                string_node = string_node.as_none_safe_node(
3937
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
3938
                    format_args=['decode', string_type.name])
3939
            else:
3940
                string_node = string_node.as_none_safe_node(
3941
                    "'NoneType' object has no attribute '%.30s'",
3942
                    error="PyExc_AttributeError",
3943
                    format_args=['decode'])
3944
        elif not string_type.is_string and not string_type.is_cpp_string:
3945
            # nothing to optimise here
3946
            return node
3947

3948
        if not start:
3949
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
3950
        elif not start.type.is_int:
3951
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
3952
        if stop and not stop.type.is_int:
3953
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
3954

3955
        # try to find a specific encoder function
3956
        codec_name = None
3957
        if encoding is not None:
3958
            codec_name = self._find_special_codec_name(encoding)
3959
        if codec_name is not None:
3960
            if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
3961
                codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
3962
            else:
3963
                codec_cname = "PyUnicode_Decode%s" % codec_name
3964
            decode_function = ExprNodes.RawCNameExprNode(
3965
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
3966
            encoding_node = ExprNodes.NullNode(node.pos)
3967
        else:
3968
            decode_function = ExprNodes.NullNode(node.pos)
3969

3970
        # build the helper function call
3971
        temps = []
3972
        if string_type.is_string:
3973
            # C string
3974
            if not stop:
3975
                # use strlen() to find the string length, just as CPython would
3976
                if not string_node.is_name:
3977
                    string_node = UtilNodes.LetRefNode(string_node)  # used twice
3978
                    temps.append(string_node)
3979
                stop = ExprNodes.PythonCapiCallNode(
3980
                    string_node.pos, "__Pyx_ssize_strlen", self.Pyx_ssize_strlen_func_type,
3981
                    args=[string_node],
3982
                    is_temp=True,
3983
                )
3984
            helper_func_type = self._decode_c_string_func_type
3985
            utility_code_name = 'decode_c_string'
3986
        elif string_type.is_cpp_string:
3987
            # C++ std::string
3988
            if not stop:
3989
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
3990
                                         constant_result=ExprNodes.not_a_constant)
3991
            if self._decode_cpp_string_func_type is None:
3992
                # lazy init to reuse the C++ string type
3993
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
3994
                    Builtin.unicode_type, [
3995
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
3996
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
3997
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
3998
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
3999
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
4000
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
4001
                    ])
4002
            helper_func_type = self._decode_cpp_string_func_type
4003
            utility_code_name = 'decode_cpp_string'
4004
        else:
4005
            # Python bytes/bytearray object
4006
            if not stop:
4007
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
4008
                                         constant_result=ExprNodes.not_a_constant)
4009
            helper_func_type = self._decode_bytes_func_type
4010
            if string_type is Builtin.bytes_type:
4011
                utility_code_name = 'decode_bytes'
4012
            else:
4013
                utility_code_name = 'decode_bytearray'
4014

4015
        node = ExprNodes.PythonCapiCallNode(
4016
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
4017
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
4018
            is_temp=node.is_temp,
4019
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
4020
        )
4021

4022
        for temp in temps[::-1]:
4023
            node = UtilNodes.EvalWithTempExprNode(temp, node)
4024
        return node
4025

4026
    # Calls to bytearray.decode() are handled by the same method as bytes.decode().
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
4027

4028
    def _find_special_codec_name(self, encoding):
4029
        try:
4030
            requested_codec = codecs.getencoder(encoding)
4031
        except LookupError:
4032
            return None
4033
        for name, codec in self._special_codecs:
4034
            if codec == requested_codec:
4035
                if '_' in name:
4036
                    name = ''.join([s.capitalize()
4037
                                    for s in name.split('_')])
4038
                return name
4039
        return None
4040

4041
    def _unpack_encoding_and_error_mode(self, pos, args):
        """
        Unpack the optional 'encoding' (args[1]) and error handling (args[2])
        arguments of a call.

        Returns the tuple (encoding, encoding_node, error_handling,
        error_handling_node), or None if one of the given arguments could not
        be unpacked by _unpack_string_and_cstring_node().  Missing arguments
        and an explicit 'strict' error mode are represented by a NULL node.
        """
        null_node = ExprNodes.NullNode(pos)
        unpack = self._unpack_string_and_cstring_node

        encoding, encoding_node = None, null_node
        if len(args) >= 2:
            encoding, encoding_node = unpack(args[1])
            if encoding_node is None:
                return None

        error_handling, error_handling_node = 'strict', null_node
        if len(args) == 3:
            error_handling, error_handling_node = unpack(args[2])
            if error_handling_node is None:
                return None
            if error_handling == 'strict':
                # 'strict' is passed on as NULL, like an omitted argument
                error_handling_node = null_node

        return (encoding, encoding_node, error_handling, error_handling_node)
4063

4064
    def _unpack_string_and_cstring_node(self, node):
        """
        Return a pair (string value, C string node) for a string argument.

        The string value is only known for string literals and is None
        otherwise.  Both items are None if the node is neither a literal,
        nor of type bytes, nor of a C string type.
        """
        c_string_type = PyrexTypes.c_const_char_ptr_type
        if isinstance(node, ExprNodes.CoerceToPyTypeNode):
            node = node.arg

        if isinstance(node, ExprNodes.UnicodeNode):
            text = node.value
            c_node = ExprNodes.BytesNode(
                node.pos, value=text.as_utf8_string(), type=c_string_type)
            return text, c_node

        if isinstance(node, ExprNodes.BytesNode):
            text = node.value.decode('ISO-8859-1')
            c_node = ExprNodes.BytesNode(
                node.pos, value=node.value, type=c_string_type)
            return text, c_node

        if node.type is Builtin.bytes_type:
            return None, node.coerce_to(c_string_type, self.current_env())

        if node.type.is_string:
            return None, node

        return None, None
4083

4084
    def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
        """
        Handle bytes.endswith() by delegating to _inject_tailmatch() with the
        bytes tailmatch utility code and direction +1.
        """
        direction = +1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'bytes', 'endswith', bytes_tailmatch_utility_code, direction)
4088

4089
    def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
        """
        Handle bytes.startswith() by delegating to _inject_tailmatch() with the
        bytes tailmatch utility code and direction -1.
        """
        direction = -1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'bytes', 'startswith', bytes_tailmatch_utility_code, direction)
4093

4094
    '''   # disabled for now, enable when we consider it worth it (see StringTools.c)
4095
    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
4096
        return self._inject_tailmatch(
4097
            node, function, args, is_unbound_method, 'bytearray', 'endswith',
4098
            bytes_tailmatch_utility_code, +1)
4099

4100
    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
4101
        return self._inject_tailmatch(
4102
            node, function, args, is_unbound_method, 'bytearray', 'startswith',
4103
            bytes_tailmatch_utility_code, -1)
4104
    '''
4105

4106
    ### helpers
4107

4108
    def _substitute_method_call(self, node, function, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None, is_temp=None,
                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
                                with_none_check=True):
        """
        Build a PythonCapiCallNode that calls the C function 'name' of type
        'func_type' at the position of 'node'.

        With 'with_none_check' set and at least one argument given, the first
        argument is passed through _wrap_self_arg() first.  If 'is_temp' is
        None, the value of 'node.is_temp' is used.
        """
        call_args = list(args)
        if with_none_check and call_args:
            call_args[0] = self._wrap_self_arg(
                call_args[0], function, is_unbound_method, attr_name)
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args=call_args,
            is_temp=node.is_temp if is_temp is None else is_temp,
            utility_code=utility_code,
            may_return_none=may_return_none,
            result_is_used=node.result_is_used,
        )
4126

4127
    def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
4128
        if self_arg.is_literal:
4129
            return self_arg
4130
        if is_unbound_method:
4131
            self_arg = self_arg.as_none_safe_node(
4132
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
4133
                format_args=[attr_name, self_arg.type.name])
4134
        else:
4135
            self_arg = self_arg.as_none_safe_node(
4136
                "'NoneType' object has no attribute '%{}s'".format('.30' if len(attr_name) <= 30 else ''),
4137
                error="PyExc_AttributeError",
4138
                format_args=[attr_name])
4139
        return self_arg
4140

4141
    # C function signature: takes one Python object ("obj"), returns a Python object.
    obj_to_obj_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    )
4145

4146
    def _inject_null_for_none(self, args, index):
        """
        Replace args[index] in place so that None is passed on as NULL.

        An argument with 'is_none' set becomes a NullNode directly, any other
        argument is wrapped in a call to "__Pyx_NoneAsNull".  Does nothing if
        'args' has no item at 'index'.
        """
        if index >= len(args):
            return
        arg = args[index]
        if arg.is_none:
            replacement = ExprNodes.NullNode(arg.pos)
        else:
            replacement = ExprNodes.PythonCapiCallNode(
                arg.pos, "__Pyx_NoneAsNull",
                self.obj_to_obj_func_type,
                args=[arg.coerce_to_simple(self.current_env())],
                is_temp=0,
            )
        args[index] = replacement
4156

4157
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        """
        Make sure that args[arg_index] is an integer argument of C type 'type'.

        'args' is modified in place:
        - a missing argument is appended as an IntNode with 'default_value',
        - an argument with 'is_none' set is replaced by that default IntNode,
        - any other argument is coerced to 'type'; if that requires a
          conversion from a Python object, None is mapped to the default
          value at runtime.

        'args' must already contain all arguments before 'arg_index'.
        """
        # Python usually allows passing None for range bounds,
        # so we treat that as requesting the default.
        assert len(args) >= arg_index
        is_missing = len(args) == arg_index
        if is_missing or args[arg_index].is_none:
            default_node = ExprNodes.IntNode(
                node.pos, value=str(default_value),
                type=type, constant_result=default_value)
            if is_missing:
                args.append(default_node)
            else:
                # Replace the None in its own slot.  Appending here would keep
                # None as the integer argument and shift the following ones.
                args[arg_index] = default_node
        else:
            arg = args[arg_index].coerce_to(type, self.current_env())
            if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
                # Add a runtime check for None and map it to the default value.
                arg.special_none_cvalue = str(default_value)
            args[arg_index] = arg
4170

4171
    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        """
        Make sure that args[arg_index] is a boolean argument.

        An existing argument is coerced to a boolean in place, a missing one
        is appended as a BoolNode holding bool(default_value).
        """
        assert len(args) >= arg_index
        if len(args) != arg_index:
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
            return
        flag = bool(default_value)
        args.append(ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag))
4179

4180

4181
def optimise_numeric_binop(operator, node, ret_type, arg0, arg1):
    """
    Optimise math operators for (likely) float or small integer operations.

    Exactly one operand is expected to be an IntNode/FloatNode with a known
    constant value, the other one a Python object or Python int.

    Returns the tuple (func_cname, utility_code, extra_args, num_type) that
    describes the C helper function to call, or None if the operation
    cannot be optimised.
    """
    # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
    # Prefer constants on RHS as they allow better size control for some operators.
    num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
    if isinstance(arg1, num_nodes):
        numval, other, arg_order = arg1, arg0, 'ObjC'
    elif isinstance(arg0, num_nodes):
        numval, other, arg_order = arg0, arg1, 'CObj'
    else:
        return None
    if other.type is not PyrexTypes.py_object_type and other.type is not Builtin.int_type:
        return None

    if not numval.has_constant_result():
        return None

    # is_float is an instance check rather than numval.type.is_float because
    # it will often be a Python float type rather than a C float type
    is_float = isinstance(numval, ExprNodes.FloatNode)
    num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
    if is_float:
        if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
            return None
    elif operator == 'Divide':
        # mixed old-/new-style division is not currently optimised for integers
        return None
    elif abs(numval.constant_result) > 2**30:
        # Cut off at an integer border that is still safe for all operations.
        return None

    if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder') and arg1.constant_result == 0:
        # Don't optimise division by 0. :)
        return None

    const_node_class = ExprNodes.FloatNode if is_float else ExprNodes.IntNode
    inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
    extra_args = [
        const_node_class(
            numval.pos, value=numval.value, constant_result=numval.constant_result,
            type=num_type),
        ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace),
    ]

    is_comparison = operator in ('Eq', 'Ne')
    if is_float or not is_comparison:
        # "PyFloatBinop" and "PyLongBinop" take an additional "check for zero division" argument.
        zerodivision_check = arg_order == 'CObj' and (
            not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
        extra_args.append(ExprNodes.BoolNode(
            node.pos, value=zerodivision_check, constant_result=zerodivision_check))

    if is_float:
        utility_name = "PyFloatBinop"
    elif is_comparison:
        utility_name = "PyLongCompare"
    else:
        utility_name = "PyLongBinop"
    utility_code = TempitaUtilityCode.load_cached(
        utility_name,
        "Optimize.c",
        context=dict(op=operator, order=arg_order, ret_type=ret_type))

    func_cname = "__Pyx_Py%s_%s%s%s" % (
        'Float' if is_float else 'Long',
        '' if ret_type.is_pyobject else 'Bool',
        operator,
        arg_order)

    return func_cname, utility_code, extra_args, num_type
4248

4249

4250
# Cached 'unicode_tailmatch' utility code, loaded from StringTools.c.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
4251
# Cached 'bytes_tailmatch' utility code, loaded from StringTools.c.
# Passed to _inject_tailmatch() by the bytes.startswith()/endswith() handlers.
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
4252

4253

4254
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
4255
    """Calculate the result of constant expressions to store it in
4256
    ``expr_node.constant_result``, and replace trivial cases by their
4257
    constant result.
4258

4259
    General rules:
4260

4261
    - We calculate float constants to make them available to the
4262
      compiler, but we do not aggregate them into a single literal
4263
      node to prevent any loss of precision.
4264

4265
    - We recursively calculate constants from non-literal nodes to
4266
      make them available to the compiler, but we only aggregate
4267
      literal nodes at each step.  Non-literal nodes are never merged
4268
      into a single node.
4269
    """
4270

4271
    def __init__(self, reevaluate=False):
        """
        Set up the transform.

        'reevaluate' controls whether constant values that were computed
        before are calculated again.
        """
        super().__init__()
        self.reevaluate = reevaluate
4278

4279
    def _calculate_const(self, node):
        """
        Visit the children of 'node' and try to calculate its constant result.

        Nodes that already have a result are skipped unless 'reevaluate' is
        set.  The result is 'not_a_constant' if any child is not constant or
        if the calculation fails with one of the expected exceptions.
        """
        if (not self.reevaluate and
                node.constant_result is not ExprNodes.constant_value_not_set):
            return

        # make sure we always set the value, even if we bail out below
        non_const = ExprNodes.not_a_constant
        node.constant_result = non_const

        # all children (including those in child lists) must be constant
        for child_result in self.visitchildren(node).values():
            group = child_result if type(child_result) is list else [child_result]
            for child in group:
                if getattr(child, 'constant_result', non_const) is non_const:
                    return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback
            import sys
            traceback.print_exc(file=sys.stdout)
4310

4311
    # Literal node classes; _widest_node_class() picks the one that comes last in this list.
    NODE_TYPE_ORDER = [
        ExprNodes.BoolNode,
        ExprNodes.CharNode,
        ExprNodes.IntNode,
        ExprNodes.FloatNode,
    ]
4313

4314
    def _widest_node_class(self, *nodes):
4315
        try:
4316
            return self.NODE_TYPE_ORDER[
4317
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
4318
        except ValueError:
4319
            return None
4320

4321
    def _bool_node(self, node, value):
        """Create a BoolNode for the truth value of 'value' at the position of 'node'."""
        truth = bool(value)
        return ExprNodes.BoolNode(node.pos, value=truth, constant_result=truth)
4324

4325
    def visit_ExprNode(self, node):
        """Calculate the constant result of the node and return the node itself."""
        self._calculate_const(node)
        return node
4328

4329
    def visit_UnopNode(self, node):
        """
        Fold unary operations on literals into a literal node.

        A non-constant '!' operation is passed on to _handle_NotNode().
        Other nodes without a constant result or with a non-literal operand
        are returned unchanged.
        """
        self._calculate_const(node)
        is_not = node.operator == '!'
        if not node.has_constant_result():
            return self._handle_NotNode(node) if is_not else node
        if not node.operand.is_literal:
            return node

        if is_not:
            return self._bool_node(node, node.constant_result)
        if isinstance(node.operand, ExprNodes.BoolNode):
            as_int = int(node.constant_result)
            return ExprNodes.IntNode(node.pos, value=str(as_int),
                                     type=PyrexTypes.c_int_type,
                                     constant_result=as_int)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        if node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node
4348

4349
    _negate_operator = {
4350
        'in': 'not_in',
4351
        'not_in': 'in',
4352
        'is': 'is_not',
4353
        'is_not': 'is'
4354
    }.get
4355

4356
    def _handle_NotNode(self, node):
        """
        Replace 'not (a <op> b)' by a copy of the comparison with the negated
        operator if _negate_operator knows one, and visit that copy.
        Otherwise, return the node unchanged.
        """
        operand = node.operand
        if not isinstance(operand, ExprNodes.PrimaryCmpNode):
            return node
        negated = self._negate_operator(operand.operator)
        if not negated:
            return node
        flipped = copy.copy(operand)
        flipped.operator = negated
        return self.visit_PrimaryCmpNode(flipped)
4365

4366
    def _handle_UnaryMinusNode(self, node):
        """
        Fold a unary minus on a float or integer literal into a new literal
        node with the sign of its value string flipped.  Other operands leave
        the node unchanged.
        """
        def _negate(value):
            # toggle a leading '-' on the literal's string value
            return value[1:] if value.startswith('-') else '-' + value

        operand = node.operand
        node_type = operand.type
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value=_negate(operand.value),
                                       type=node_type,
                                       constant_result=node.constant_result)

        is_signed_c_int = node_type.is_int and node_type.signed
        if is_signed_c_int or (isinstance(operand, ExprNodes.IntNode) and node_type.is_pyobject):
            return ExprNodes.IntNode(node.pos, value=_negate(operand.value),
                                     type=node_type,
                                     longness=operand.longness,
                                     constant_result=node.constant_result)
        return node
4387

4388
    def _handle_UnaryPlusNode(self, node):
4389
        if (node.operand.has_constant_result() and
4390
                    node.constant_result == node.operand.constant_result):
4391
            return node.operand
4392
        return node
4393

4394
    def visit_BoolBinopNode(self, node):
        """
        Short-circuit 'and'/'or' when the first operand has a constant result:
        return whichever operand determines the result.  Nodes with a
        non-constant first operand are returned unchanged.
        """
        self._calculate_const(node)
        first = node.operand1
        if not first.has_constant_result():
            return node
        is_and = node.operator == 'and'
        if first.constant_result:
            # true value: 'and' yields the second operand, 'or' the first
            return node.operand2 if is_and else first
        # false value: 'and' yields the first operand, 'or' the second
        return first if is_and else node.operand2
4408

4409
    def visit_BinopNode(self, node):
        """
        Replace a binary operation on two literals by a single literal node.

        The node is returned unchanged if its result is not constant or is a
        float, if an operand is not a literal or has no type, or if the
        operand node classes are not listed in NODE_TYPE_ORDER.
        """
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant or isinstance(result, float):
            return node
        left, right = node.operand1, node.operand2
        if not (left.is_literal and right.is_literal):
            return node

        # now inject a new constant node with the calculated value
        try:
            left_type, right_type = left.type, right.type
        except AttributeError:
            return node
        if left_type is None or right_type is None:
            return node

        if left_type.is_numeric and right_type.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(left_type, right_type)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(left, right)
        if target_class is None:
            return node
        if ((target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>') or
                (target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^')):
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is not ExprNodes.IntNode:
            node_value = result if target_class is ExprNodes.BoolNode else str(result)
            return target_class(pos=node.pos, type=widest_type,
                                value=node_value,
                                constant_result=result)

        unsigned = getattr(left, 'unsigned', '') and getattr(right, 'unsigned', '')
        longness_size = max(len(getattr(left, 'longness', '')),
                            len(getattr(right, 'longness', '')))
        longness = "LL"[:longness_size]
        value = Utils.strip_py2_long_suffix(hex(int(result)))
        new_node = ExprNodes.IntNode(pos=node.pos,
                                     unsigned=unsigned, longness=longness,
                                     value=value,
                                     constant_result=int(result))
        # IntNode is smart about the type it chooses, so we just
        # make sure we were not smarter this time
        if widest_type.is_pyobject or new_node.type.is_pyobject:
            new_node.type = PyrexTypes.py_object_type
        else:
            new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        return new_node
4468

4469
    def visit_AddNode(self, node):
        """
        Merge the concatenation of two string literals of the same kind into
        one literal node.  Everything else is left to visit_BinopNode().
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node

        str1, str2 = node.operand1, node.operand2
        if str1.is_string_literal and str2.is_string_literal:
            # some people combine string literals with a '+'
            if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
                bytes1, bytes2 = str1.bytes_value, str2.bytes_value
                bytes_value = None
                if (bytes1 is not None and bytes2 is not None
                        and bytes1.encoding == bytes2.encoding):
                    bytes_value = bytes_literal(bytes1 + bytes2, bytes1.encoding)
                return ExprNodes.UnicodeNode(
                    str1.pos, value=EncodedString(node.constant_result), bytes_value=bytes_value)

            if isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
                if str1.value.encoding == str2.value.encoding:
                    merged = bytes_literal(node.constant_result, str1.value.encoding)
                    return ExprNodes.BytesNode(
                        str1.pos, value=merged, constant_result=node.constant_result)
            # all other combinations are rather complicated
            # to get right in Py2/3: encodings, unicode escapes, ...
        return self.visit_BinopNode(node)
4492

4493
    def visit_MulNode(self, node):
        """
        Dispatch multiplications of sequence constructors and string literals
        to their specialised handlers, everything else to visit_BinopNode().
        """
        self._calculate_const(node)
        left, right = node.operand1, node.operand2
        if left.is_sequence_constructor:
            return self._calculate_constant_seq(node, left, right)
        if isinstance(left, ExprNodes.IntNode) and right.is_sequence_constructor:
            return self._calculate_constant_seq(node, right, left)
        if left.is_string_literal:
            return self._multiply_string(node, left, right)
        if right.is_string_literal:
            return self._multiply_string(node, right, left)
        return self.visit_BinopNode(node)
4505

4506
    def _multiply_string(self, node, string_node, multiplier_node):
        """
        Fold 'string literal * int constant' into the string literal node.

        The string node is updated in place and returned.  The original node
        is returned if the multiplier is not an int, if the node has no
        string constant result, or if the result exceeds 256 items.
        """
        multiplier = multiplier_node.constant_result
        if not isinstance(multiplier, int):
            return node
        if not node.has_constant_result() or not isinstance(node.constant_result, _py_string_types):
            return node
        if len(node.constant_result) > 256:
            # Too long for static creation, leave it to runtime.  (-> arbitrary limit)
            return node

        if isinstance(string_node, ExprNodes.BytesNode):
            build_string = bytes_literal
        elif isinstance(string_node, ExprNodes.UnicodeNode):
            build_string = encoded_string
            current_bytes = string_node.bytes_value
            if current_bytes is not None:
                string_node.bytes_value = bytes_literal(
                    current_bytes * multiplier, current_bytes.encoding)
        else:
            assert False, "unknown string node type: %s" % type(string_node)

        current_value = string_node.value
        string_node.value = build_string(
            current_value * multiplier, current_value.encoding)
        # follow constant-folding and use unicode_value in preference
        string_node.constant_result = string_node.value
        return string_node
4532

4533
    def _calculate_constant_seq(self, node, sequence_node, factor):
4534
        if factor.constant_result != 1 and sequence_node.args:
4535
            if isinstance(factor.constant_result, int) and factor.constant_result <= 0:
4536
                del sequence_node.args[:]
4537
                sequence_node.mult_factor = None
4538
            elif sequence_node.mult_factor is not None:
4539
                if (isinstance(factor.constant_result, int) and
4540
                        isinstance(sequence_node.mult_factor.constant_result, int)):
4541
                    value = sequence_node.mult_factor.constant_result * factor.constant_result
4542
                    sequence_node.mult_factor = ExprNodes.IntNode(
4543
                        sequence_node.mult_factor.pos,
4544
                        value=str(value), constant_result=value)
4545
                else:
4546
                    # don't know if we can combine the factors, so don't
4547
                    return self.visit_BinopNode(node)
4548
            else:
4549
                sequence_node.mult_factor = factor
4550
        return sequence_node
4551

4552
    def visit_ModNode(self, node):
        """
        Try to turn '"..." % (tuple)' on a unicode literal into an f-string
        via _build_fstring(); fall back to visit_BinopNode() otherwise.
        """
        self.visitchildren(node)
        template, values = node.operand1, node.operand2
        if (isinstance(template, ExprNodes.UnicodeNode)
                and isinstance(values, ExprNodes.TupleNode)
                and not values.mult_factor):
            fstring = self._build_fstring(template.pos, template.value, values.args)
            if fstring is not None:
                return fstring
        return self.visit_BinopNode(node)
4560

4561
    _parse_string_format_regex = (
4562
        '(%(?:'              # %...
4563
        '(?:[-0-9]+|[ ])?'   # width (optional) or space prefix fill character (optional)
4564
        '(?:[.][0-9]+)?'     # precision (optional)
4565
        ')?.)'               # format type (or something different for unsupported formats)
4566
    )
4567

4568
    def _build_fstring(self, pos, ustring, format_args):
        """
        Convert the %-format string 'ustring' and its arguments into a
        JoinedStrNode (passed through visit_JoinedStrNode()).

        Returns None if the format string uses anything that is not handled
        here or if the number of arguments does not match the placeholders.
        """
        # Issues formatting warnings instead of errors since we really only catch a few errors by accident.
        remaining_args = iter(format_args)
        parts = []
        optimisable = True
        for chunk in re.split(self._parse_string_format_regex, ustring):
            if not chunk:
                continue
            if chunk == '%%':
                parts.append(ExprNodes.UnicodeNode(pos, value=EncodedString('%')))
                continue
            if not chunk.startswith('%'):
                # plain text between format items
                if chunk.endswith('%'):
                    warning(pos, f"Incomplete format: '...{chunk[-3:]}'", level=1)
                    optimisable = False
                parts.append(ExprNodes.UnicodeNode(pos, value=EncodedString(chunk)))
                continue

            format_type = chunk[-1]
            try:
                arg = next(remaining_args)
            except StopIteration:
                warning(pos, "Too few arguments for format placeholders", level=1)
                optimisable = False
                break
            if arg.is_starred or format_type not in 'asrfdoxX':
                # keep it simple for now ...
                optimisable = False
                break

            format_spec = chunk[1:]
            conversion_char = None
            if format_type in 'doxX' and '.' in format_spec:
                # Precision is not allowed for integers in format(), but ok in %-formatting.
                optimisable = False
            elif format_type in 'ars':
                conversion_char = format_type
                format_spec = format_spec[:-1]
                if format_spec.startswith('0'):
                    format_spec = '>' + format_spec[1:]  # right-alignment '%05s' spells '{:>5}'
            elif format_type == 'd':
                # '%d' formatting supports float, but '{obj:d}' does not => convert to int first.
                conversion_char = 'd'

            if format_spec.startswith('-'):
                format_spec = '<' + format_spec[1:]  # left-alignment '%-5s' spells '{:<5}'

            spec_node = None
            if format_spec:
                spec_node = ExprNodes.UnicodeNode(pos, value=EncodedString(format_spec))
            parts.append(ExprNodes.FormattedValueNode(
                arg.pos, value=arg,
                conversion_char=conversion_char,
                format_spec=spec_node,
            ))

        if not optimisable:
            # Print all warnings we can find before finally giving up here.
            return None

        exhausted = object()
        if next(remaining_args, exhausted) is not exhausted:
            warning(pos, "Too many arguments for format placeholders", level=1)
            return None

        return self.visit_JoinedStrNode(ExprNodes.JoinedStrNode(pos, values=parts))
4637

4638
    def visit_FormattedValueNode(self, node):
        """
        Fold trivially formattable values into plain string literals.

        An empty literal format spec is dropped from the node.  Without a format
        spec, a canonical decimal integer literal is replaced by a Unicode literal
        of its digits, and a string literal using the default 's' conversion is
        returned as is.  In all other cases, the (possibly updated) node is returned.
        """
        self.visitchildren(node)
        conversion_char = node.conversion_char or 's'
        if node.format_spec is not None and node.format_spec.is_string_literal and not node.format_spec.value:
            # an empty format spec is the same as no format spec
            node.format_spec = None
        if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
            value = EncodedString(node.value.value)
            # Only fold literals whose source text already is the canonical decimal form:
            # '00' is a valid spelling of zero, but it formats as '0' at runtime.
            if value.isdigit() and (value == '0' or not value.startswith('0')):
                return ExprNodes.UnicodeNode(node.value.pos, value=value)
        if node.format_spec is None and conversion_char == 's':
            if node.value.is_string_literal:
                return node.value
        return node
4651

4652
    def visit_JoinedStrNode(self, node):
4653
        """
4654
        Clean up after the parser by discarding empty Unicode strings and merging
4655
        substring sequences.  Empty or single-value join lists are not uncommon
4656
        because f-string format specs are always parsed into JoinedStrNodes.
4657
        """
4658
        self.visitchildren(node)
4659

4660
        values = []
4661
        for is_unode_group, substrings in itertools.groupby(node.values, key=attrgetter('is_string_literal')):
4662
            if is_unode_group:
4663
                substrings = list(substrings)
4664
                unode = substrings[0]
4665
                if len(substrings) > 1:
4666
                    value = EncodedString(''.join(value.value for value in substrings))
4667
                    unode = ExprNodes.UnicodeNode(unode.pos, value=value)
4668
                # ignore empty Unicode strings
4669
                if unode.value:
4670
                    values.append(unode)
4671
            else:
4672
                values.extend(substrings)
4673

4674
        if not values:
4675
            node = ExprNodes.UnicodeNode(node.pos, value=EncodedString(''))
4676
        elif len(values) == 1:
4677
            node = values[0]
4678
        elif len(values) == 2:
4679
            # reduce to string concatenation
4680
            node = ExprNodes.binop_node(node.pos, '+', *values)
4681
        else:
4682
            node.values = values
4683
        return node
4684

4685
    def visit_MergedDictNode(self, node):
        """
        Unpack **args in place if we can.

        Adjacent dict literals with the same 'reject_duplicates' setting are merged
        into the first of them (which is modified in place), and nested
        MergedDictNodes with a matching 'reject_duplicates' setting are flattened.
        If a single dict literal or MergedDictNode remains, it is returned instead
        of the node.
        """
        self.visitchildren(node)
        args = []
        # 'items' collects the pending run of dict literals since the last non-literal argument.
        items = []

        def add(parent, arg):
            if arg.is_dict_literal:
                if items and items[-1].reject_duplicates == arg.reject_duplicates:
                    items[-1].key_value_pairs.extend(arg.key_value_pairs)
                else:
                    items.append(arg)
            elif isinstance(arg, ExprNodes.MergedDictNode) and parent.reject_duplicates == arg.reject_duplicates:
                for child_arg in arg.keyword_args:
                    add(arg, child_arg)
            else:
                # anything else terminates the current run of dict literals
                if items:
                    args.extend(items)
                    del items[:]
                args.append(arg)

        for arg in node.keyword_args:
            add(node, arg)
        if items:
            args.extend(items)

        if len(args) == 1:
            arg = args[0]
            if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
                return arg
        node.keyword_args[:] = args
        self._calculate_const(node)
        return node
4718

4719
    def visit_MergedSequenceNode(self, node):
        """
        Unpack *args in place if we can.

        Adjacent literals (set literals when building a set, or sequence
        constructors without a multiplication factor) are merged into the first
        literal of their run, which is modified in place.  Nested
        MergedSequenceNodes are flattened.  If only a single suitable argument
        remains, it is returned instead of the node.
        """
        self.visitchildren(node)

        is_set = node.type is Builtin.set_type
        args = []
        # 'values' holds at most one node: the literal that accumulates the current run.
        values = []

        def add(arg):
            if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
                if values:
                    values[0].args.extend(arg.args)
                else:
                    values.append(arg)
            elif isinstance(arg, ExprNodes.MergedSequenceNode):
                for child_arg in arg.args:
                    add(child_arg)
            else:
                # anything else terminates the current run of literals
                if values:
                    args.append(values[0])
                    del values[:]
                args.append(arg)

        for arg in node.args:
            add(arg)
        if values:
            args.append(values[0])

        if len(args) == 1:
            arg = args[0]
            if ((is_set and arg.is_set_literal) or
                    (arg.is_sequence_constructor and arg.type is node.type) or
                    isinstance(arg, ExprNodes.MergedSequenceNode)):
                return arg
        node.args[:] = args
        self._calculate_const(node)
        return node
4756

4757
    def visit_SequenceNode(self, node):
        """
        Unpack *args in place if we can.

        A starred argument whose target is a sequence constructor without a
        multiplication factor is replaced by the items of that constructor.
        All other arguments are kept unchanged and in order.
        """
        self.visitchildren(node)
        args = []
        for arg in node.args:
            if arg.is_starred and arg.target.is_sequence_constructor and not arg.target.mult_factor:
                args.extend(arg.target.args)
            else:
                args.append(arg)
        node.args[:] = args
        self._calculate_const(node)
        return node
4771

4772
    def visit_PrimaryCmpNode(self, node):
        """
        Fold the constant parts of a (possibly cascaded) comparison.

        Each comparison step with two constant operands gets its constant result
        calculated.  A non-cascaded comparison with a constant result is replaced
        by the node that self._bool_node() builds for it.  For cascades, steps that
        are constant True are dropped, the cascade is cut off at the first step
        that is constant False, and the remaining partial cascades are chained
        together with 'and'.
        """
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        left_node = node.operand1
        cmp_node = node
        while cmp_node is not None:
            self.visitchildren(cmp_node, ['operand2'])
            right_node = cmp_node.operand2
            cmp_node.constant_result = not_a_constant
            if left_node.has_constant_result() and right_node.has_constant_result():
                try:
                    cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            left_node = right_node
            cmp_node = cmp_node.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        # used as an optional value: empty, or the single node for a constant False step
        final_false_result = []

        cmp_node = node
        while cmp_node is not None:
            if cmp_node.has_constant_result():
                if not cmp_node.constant_result:
                    # False => short-circuit
                    final_false_result.append(self._bool_node(cmp_node, False))
                    break
                else:
                    # True => discard and start new cascade
                    cascades.append([cmp_node.operand2])
            else:
                # not constant => append to current cascade
                cascades[-1].append(cmp_node)
            cmp_node = cmp_node.cascade

        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                # only a start value without any comparison step => nothing to evaluate
                continue
            cmp_node = cascade[1]
            pcmp_node = ExprNodes.PrimaryCmpNode(
                cmp_node.pos,
                operand1=cascade[0],
                operator=cmp_node.operator,
                operand2=cmp_node.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(pcmp_node)

            # re-link the remaining steps of this partial cascade and terminate it
            last_cmp_node = pcmp_node
            for cmp_node in cascade[2:]:
                last_cmp_node.cascade = cmp_node
                last_cmp_node = cmp_node
            last_cmp_node.cascade = None

        if final_false_result:
            # last cascade was constant False
            cmp_nodes.append(final_false_result[0])
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)
        node = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
        else:
            for cmp_node in cmp_nodes[1:]:
                node = ExprNodes.BoolBinopNode(
                    node.pos,
                    operand1=node,
                    operator='and',
                    operand2=cmp_node,
                    constant_result=not_a_constant)
        return node
4851

4852
    def visit_CondExprNode(self, node):
        """
        Replace a conditional expression by the selected branch if its test is constant.

        Returns the node itself when the test has no constant result, otherwise
        'true_val' or 'false_val' depending on the truth value of the test result.
        """
        self._calculate_const(node)
        if not node.test.has_constant_result():
            return node
        return node.true_val if node.test.constant_result else node.false_val
4860

4861
    def visit_IfStatNode(self, node):
        """
        Eliminate dead if-clauses based on constant condition results.

        Clauses with a constant false condition are dropped.  The body of the
        first clause with a constant true condition becomes the else clause, and
        all clauses after it are dropped.  Returns the node if any runtime-evaluated
        clause remains, otherwise the else clause if there is one, otherwise an
        empty statement list.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition = if_clause.condition
            if condition.has_constant_result():
                if condition.constant_result:
                    # always true => subsequent clauses can safely be dropped
                    node.else_clause = if_clause.body
                    break
                # else: false => drop clause
            else:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        elif node.else_clause:
            return node.else_clause
        else:
            return Nodes.StatListNode(node.pos, stats=[])
4883

4884
    def visit_SliceIndexNode(self, node):
        """
        Normalise constant slice bounds and cut down sliced constant sequences.

        A start/stop bound that is missing or whose constant result is None is
        reset to None on the node.  If the slice as a whole has a constant result,
        a sequence constructor base without a multiplication factor is sliced in
        place and returned, and a string literal base is replaced by the result
        of its as_sliced_node() method unless that returns None.
        """
        self._calculate_const(node)
        # normalise start/stop values
        if node.start is None or node.start.constant_result is None:
            start = node.start = None
        else:
            start = node.start.constant_result
        if node.stop is None or node.stop.constant_result is None:
            stop = node.stop = None
        else:
            stop = node.stop.constant_result
        # cut down sliced constant sequences
        if node.constant_result is not not_a_constant:
            base = node.base
            if base.is_sequence_constructor and base.mult_factor is None:
                base.args = base.args[start:stop]
                return base
            elif base.is_string_literal:
                base = base.as_sliced_node(start, stop)
                if base is not None:
                    return base
        return node
4906

4907
    def visit_ComprehensionNode(self, node):
        """
        Replace a comprehension whose loop was reduced to an empty statement list
        by an empty list, set or dict literal, depending on the comprehension type.
        Comprehensions of any other type are returned unchanged.
        """
        self.visitchildren(node)
        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
            # loop was pruned already => transform into literal
            if node.type is Builtin.list_type:
                return ExprNodes.ListNode(
                    node.pos, args=[], constant_result=[])
            elif node.type is Builtin.set_type:
                return ExprNodes.SetNode(
                    node.pos, args=[], constant_result=set())
            elif node.type is Builtin.dict_type:
                return ExprNodes.DictNode(
                    node.pos, key_value_pairs=[], constant_result={})
        return node
4921

4922
    def visit_ForInStatNode(self, node):
        """
        Simplify for-in loops over literal sequences.

        A loop over an empty sequence literal is replaced by its else clause, or
        by an empty statement list if it has none.  A list literal that is
        iterated over is converted into a tuple.
        """
        self.visitchildren(node)
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SequenceNode):
            if not sequence.args:
                if node.else_clause:
                    return node.else_clause
                else:
                    # don't break list comprehensions
                    return Nodes.StatListNode(node.pos, stats=[])
            # iterating over a list literal? => tuples are more efficient
            if isinstance(sequence, ExprNodes.ListNode):
                node.iterator.sequence = sequence.as_tuple()
        return node
4936

4937
    def visit_WhileStatNode(self, node):
        """
        Simplify while loops with a constant condition.

        A constant true condition is removed from the node together with the
        else clause.  For a constant false condition, the else clause (which may
        be None) is returned in place of the loop.
        """
        self.visitchildren(node)
        if node.condition and node.condition.has_constant_result():
            if node.condition.constant_result:
                node.condition = None
                node.else_clause = None
            else:
                return node.else_clause
        return node
4946

4947
    def visit_ExprStatNode(self, node):
        """
        Drop expression statements that only evaluate a constant.

        Returns None for a statement whose expression has a constant result, and
        the node itself otherwise, including when 'node.expr' is not an ExprNode.
        """
        self.visitchildren(node)
        if not isinstance(node.expr, ExprNodes.ExprNode):
            # ParallelRangeTransform does this ...
            return node
        # drop unused constant expressions
        if node.expr.has_constant_result():
            return None
        return node
4956

4957
    def visit_GILStatNode(self, node):
        """
        Resolve GIL statements that have a constant condition.

        Returns the node unchanged if it has no condition or a non-constant one,
        the node with its condition removed if the condition is constant true,
        and only the body of the node if the condition is constant false.
        """
        self.visitchildren(node)
        condition = node.condition
        if condition is None:
            return node

        if not condition.has_constant_result():
            # If condition is not constant we keep the GILStatNode as it is.
            # Either it will later become constant (e.g. a `numeric is int`
            # expression in a fused type function) and then when ConstantFolding
            # runs again it will be handled or a later transform (i.e. GilCheck)
            # will raise an error
            return node

        if condition.constant_result:
            # Condition is True - Modify node to be a normal
            # GILStatNode with condition=None
            node.condition = None
            return node

        # Condition is False - the body of the GILStatNode
        # should run without changing the state of the gil
        # return the body of the GILStatNode
        return node.body
4980

4981
    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    # fallback handler for all node types without a specific visit method above
    visit_Node = Visitor.VisitorTransform.recurse_to_children
4985

4986

4987
class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant after tree changes
        - eliminate useless string formatting steps
        - inject branch hints for unlikely if-cases that only raise exceptions
        - replace Python function calls that look like method calls by a faster PyMethodCallNode
    """
    # True while the children of a loop node are being visited (see visit_LoopNode()).
    in_loop = False

    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
        return node

    def _check_optimize_method_calls(self, node):
        """Decide whether the call in 'node' should get the 'unpack' option.

        The result is only truthy for temp results of Python object callables,
        and then is the value of the directive that applies to the call site:
        "optimize.unpack_method_calls_in_pyinit" for calls outside of loops in
        module scope, cdef class scope or module level Python class scope, and
        "optimize.unpack_method_calls" everywhere else.
        """
        function = node.function
        env = self.current_env()
        in_global_scope = (
            env.is_module_scope or
            env.is_c_class_scope or
            (env.is_py_class_scope and env.outer_scope.is_module_scope)
        )
        return (node.is_temp and function.type.is_pyobject and self.current_directives.get(
                "optimize.unpack_method_calls_in_pyinit"
                if not self.in_loop and in_global_scope
                else "optimize.unpack_method_calls"))

    def visit_SimpleCallNode(self, node):
        """
        Replace generic calls to isinstance(x, type) by a more efficient type check.
        Replace likely Python method calls by a specialised PyMethodCallNode.
        """
        self.visitchildren(node)
        function = node.function
        if function.type.is_cfunction and function.is_name:
            if function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    # redirect the call to PyObject_TypeCheck() and cast the type argument accordingly
                    cython_scope = self.context.cython_scope
                    function.entry = cython_scope.lookup('PyObject_TypeCheck')
                    function.type = function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        else:
            # optimise simple Python methods calls
            if ExprNodes.PyMethodCallNode.can_be_used_for_posargs(node.arg_tuple, has_kwargs=False):
                # simple call, now exclude calls to objects that are definitely not methods
                if ExprNodes.PyMethodCallNode.can_be_used_for_function(function):
                    if (node.self and function.is_attribute and
                            isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
                        # function self object was moved into a CloneNode => undo
                        function.obj = function.obj.arg
                    node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
                        node, function=function, arg_tuple=node.arg_tuple, type=node.type,
                        unpack=self._check_optimize_method_calls(node)))
        return node

    def visit_GeneralCallNode(self, node):
        """
        Replace likely Python method calls by a specialised PyMethodCallNode.
        """
        self.visitchildren(node)
        has_kwargs = bool(node.keyword_args)
        kwds_is_dict_node = isinstance(node.keyword_args, ExprNodes.DictNode)
        if not ExprNodes.PyMethodCallNode.can_be_used_for_posargs(
                node.positional_args, has_kwargs=has_kwargs, kwds_is_dict_node=kwds_is_dict_node):
            return node
        function = node.function
        if not ExprNodes.PyMethodCallNode.can_be_used_for_function(function):
            return node

        node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
            node, function=function, arg_tuple=node.positional_args, kwdict=node.keyword_args,
            type=node.type, unpack=self._check_optimize_method_calls(node)))
        return node

    def visit_NumPyMethodCallNode(self, node):
        """Only visit the children and keep the node itself."""
        # Exclude from replacement above.
        self.visitchildren(node)
        return node

    def visit_PyTypeTestNode(self, node):
        """Remove tests for alternatively allowed None values from
        type tests when we know that the argument cannot be None
        anyway.
        """
        self.visitchildren(node)
        if not node.notnone:
            if not node.arg.may_be_none():
                node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Remove None checks from expressions that definitely do not
        carry a None value.
        """
        self.visitchildren(node)
        if not node.arg.may_be_none():
            return node.arg
        return node

    def visit_LoopNode(self, node):
        """Remember when we enter a loop as some expensive optimisations might still be worth it there.
        """
        old_val = self.in_loop
        self.in_loop = True
        self.visitchildren(node)
        self.in_loop = old_val
        return node

    def visit_IfStatNode(self, node):
        """Assign 'unlikely' branch hints to if-clauses that only raise exceptions.
        """
        self.visitchildren(node)
        last_non_unlikely_clause = None
        for if_clause in node.if_clauses:
            self._set_ifclause_branch_hint(if_clause, if_clause.body)
            if not if_clause.branch_hint:
                last_non_unlikely_clause = if_clause
        if node.else_clause and last_non_unlikely_clause:
            # If the 'else' clause is 'unlikely', then set the preceding 'if' clause to 'likely' to reflect that.
            self._set_ifclause_branch_hint(last_non_unlikely_clause, node.else_clause, inverse=True)
        return node

    def _set_ifclause_branch_hint(self, clause, statements_node, inverse=False):
        """Inject a branch hint if the if-clause unconditionally leads to a 'raise' statement.

        'statements_node' is the code to inspect and 'clause' is the if-clause
        that receives the hint: 'unlikely', or 'likely' if 'inverse' is true.
        """
        if not statements_node.is_terminator:
            return
        # Allow simple statements, but no conditions, loops, etc.
        non_branch_nodes = (
            Nodes.ExprStatNode,
            Nodes.AssignmentNode,
            Nodes.AssertStatNode,
            Nodes.DelStatNode,
            Nodes.GlobalNode,
            Nodes.NonlocalNode,
        )
        # The list grows while we iterate over it: nested statements are inserted
        # right behind their container so that they are visited next, in order.
        statements = [statements_node]
        for next_node_pos, node in enumerate(statements, 1):
            if isinstance(node, Nodes.GILStatNode):
                statements.insert(next_node_pos, node.body)
                continue
            if isinstance(node, Nodes.StatListNode):
                statements[next_node_pos:next_node_pos] = node.stats
                continue
            if not isinstance(node, non_branch_nodes):
                if next_node_pos == len(statements) and isinstance(node, (Nodes.RaiseStatNode, Nodes.ReraiseStatNode)):
                    # Anything that unconditionally raises exceptions at the end should be considered unlikely.
                    clause.branch_hint = 'likely' if inverse else 'unlikely'
                break
5149

5150

5151
class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.
    """
    # The outermost overflow checking binop node of the expression that is
    # currently being visited, or None outside of such an expression.
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit the children of any other node outside of the current overflow context."""
        # Subexpressions below an unrelated node must not share the overflow bit
        # of an enclosing arithmetic expression => reset it while visiting them.
        saved = self.overflow_bit_node
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = saved
        return node

    def visit_NumBinopNode(self, node):
        """Let nested overflow checking binops share the overflow bit of the outermost one.

        The outermost node with both 'overflow_check' and 'overflow_fold' set
        keeps its own check.  Such nodes nested below it get it assigned as
        their 'overflow_bit_node' and have their own 'overflow_check' disabled.
        """
        if node.overflow_check and node.overflow_fold:
            top_level_overflow = self.overflow_bit_node is None
            if top_level_overflow:
                self.overflow_bit_node = node
            else:
                node.overflow_bit_node = self.overflow_bit_node
                node.overflow_check = False
            self.visitchildren(node)
            if top_level_overflow:
                self.overflow_bit_node = None
        else:
            self.visitchildren(node)
        return node
5184

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.