# fused_types.pyx
# mode: run
# ticket: 1772
# cython: language_level=3str

cimport cython
from cython.view cimport array

from cython cimport integral
from cpython cimport Py_INCREF

from Cython import Shadow as pure_cython
ctypedef char * string_t

# Fused types shared by the tests below.
# `floating` aliases the built-in cython.floating (float and double); the
# built-in `integral` cimported above is used with short, int and long
# specializations further down.
ctypedef cython.floating floating
fused_type1 = cython.fused_type(int, long, float, double, string_t)
fused_type2 = cython.fused_type(string_t)
ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int
fused_type3 = cython.fused_type(int, double)
# A fused type whose members are themselves fused types.
fused_composite = cython.fused_type(fused_type2, fused_type3)
just_float = cython.fused_type(float)

# A fused type whose members are plain typedefs.
ctypedef int inttypedef
ctypedef double doubletypedef
fused_with_typedef = cython.fused_type(inttypedef, doubletypedef)

# Deliberately misleading name: this is a float, not a const int. It is used
# to provoke clashing generated names in the fused dispatch code
# (see constfused_typedef_name_clashes).
ctypedef float const_inttypedef  # misleading name
fused_misleading_name = cython.fused_type(const_inttypedef, char)


def test_pure():
    """
    Check that a pure-Python (Shadow) typedef of a fused type can be called
    on a value.

    >>> test_pure()
    10
    """
    mytype = pure_cython.typedef(pure_cython.fused_type(int, complex))
    print(mytype(10))


cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
    # Print all three arguments (decoding byte strings as ASCII) and return
    # x + y in whichever specialization the caller's argument types select.
    if fused_type1 is string_t:
        print(x.decode('ascii'), y.decode('ascii'), z.decode('ascii'))
    else:
        print(x, y, z.decode('ascii'))

    return x + y

def test_cdef_func_with_fused_args():
    """
    Call a cdef function with fused arguments; the specialization is chosen
    from the types of the arguments (bytes, integer and float literals).

    >>> test_cdef_func_with_fused_args()
    spam ham eggs
    spamham
    10 20 butter
    30
    4.2 8.6 bunny
    12.8
    """
    print(cdef_func_with_fused_args(b'spam', b'ham', b'eggs').decode('ascii'))
    print(cdef_func_with_fused_args(10, 20, b'butter'))
    print(cdef_func_with_fused_args(4.2, 8.6, b'bunny'))

cdef fused_type1 fused_with_pointer(fused_type1 *array):
    # Print the first five elements of `array` and return their sum
    # (a concatenation in the string_t specialization).
    for i in range(5):
        if fused_type1 is string_t:
            print(array[i].decode('ascii'))
        else:
            print(array[i])

    obj = array[0] + array[1] + array[2] + array[3] + array[4]
    # NOTE(review): the extra reference appears intended to keep `obj` alive
    # after return, so that the char* returned by the string_t specialization
    # still points at valid memory. It is applied unconditionally and
    # deliberately leaks one reference per call.
    Py_INCREF(obj)
    return obj

def test_fused_with_pointer():
    """
    Pass C arrays of int, long, float and char* to a function taking a
    pointer to a fused type.

    >>> test_fused_with_pointer()
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0.0
    1.0
    2.0
    3.0
    4.0
    10.0
    <BLANKLINE>
    humpty
    dumpty
    fall
    splatch
    breakfast
    humptydumptyfallsplatchbreakfast
    """
    cdef int[5] int_array
    cdef long[5] long_array
    cdef float[5] float_array
    cdef string_t[5] string_array

    cdef char *s

    # `strings` keeps the bytes objects alive while `string_array` holds raw
    # pointers into them.
    strings = [b"humpty", b"dumpty", b"fall", b"splatch", b"breakfast"]

    for i in range(5):
        int_array[i] = i
        long_array[i] = i
        float_array[i] = i
        s = strings[i]
        string_array[i] = s

    print(fused_with_pointer(int_array))
    print()
    print(fused_with_pointer(long_array))
    print()
    print(fused_with_pointer(float_array))
    print()
    print(fused_with_pointer(string_array).decode('ascii'))

cdef fused_type1* fused_pointer_except_null(fused_type1* x) except NULL:
    # Return `x` unchanged after asserting on its target: a non-empty string
    # for string_t, otherwise a value below 10. A failing assertion is
    # propagated to the caller through the `except NULL` return value.
    if fused_type1 is string_t:
        assert bool(x[0])
    else:
        assert x[0] < 10
    return x

def test_fused_pointer_except_null(value):
    """
    Dispatch on the Python type of `value` to the int, float or string_t
    specialization and print the value read back through the returned pointer.

    >>> test_fused_pointer_except_null(1)
    1
    >>> test_fused_pointer_except_null(2.0)
    2.0
    >>> test_fused_pointer_except_null(b'foo')
    foo
    >>> test_fused_pointer_except_null(16)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(15.1)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(b'')
    Traceback (most recent call last):
    AssertionError
    """
    if isinstance(value, int):
        test_int = cython.declare(cython.int, value)
        print(fused_pointer_except_null(&test_int)[0])
    elif isinstance(value, float):
        test_float = cython.declare(cython.float, value)
        print(fused_pointer_except_null(&test_float)[0])
    elif isinstance(value, bytes):
        test_str = cython.declare(string_t, value)
        print(fused_pointer_except_null(&test_str)[0].decode('ascii'))

include "../testsupport/cythonarrayutil.pxi"

cpdef cython.integral test_fused_memoryviews(cython.integral[:, ::1] a):
    """
    Index a C-contiguous 2D memoryview whose dtype is a fused integral type.

    >>> import cython
    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryviews[cython.int](a)
    7
    """
    return a[1, 2]

ctypedef int[:, ::1] memview_int
ctypedef long[:, ::1] memview_long
# A fused type whose members are memoryview typedefs.
memview_t = cython.fused_type(memview_int, memview_long)

def test_fused_memoryview_def(memview_t a):
    """
    Select a specialization of a def function over memoryview typedefs by
    indexing with the typedef name as a string.

    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryview_def["memview_int"](a)
    7
    """
    return a[1, 2]

cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
    # Print "double pointer" when composed_t is specialized as p_double. For
    # floating specializations of fused_type1, return x plus the three
    # pointed-to values; otherwise fall through and return None.
    cdef fused_type1 result

    if composed_t is p_double:
        print("double pointer")

    if fused_type1 in floating:
        result = x + y[0] + z[0] + a[0]
        return result

def test_specializations():
    """
    Obtain the (double, int) specialization of a cdef fused function through
    a typed function pointer, an explicit cast, and indexing.

    >>> test_specializations()
    double pointer
    double pointer
    double pointer
    double pointer
    double pointer
    """
    cdef object (*f)(double, double *, double *, int *)

    cdef double somedouble = 2.2
    cdef double otherdouble = 3.3
    cdef int someint = 4

    cdef p_double somedouble_p = &somedouble
    cdef p_double otherdouble_p = &otherdouble
    cdef p_int someint_p = &someint

    f = test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    f = <object (*)(double, double *, double *, int *)> test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    assert (<object (*)(double, double *, double *, int *)>
            test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    f = test_specialize[double, int]
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # The following cases are not supported
    # f = test_specialize[double][p_int]
    # print f(1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double](1.1, somedouble_p, otherdouble_p)
    # print

cdef opt_args(integral x, floating y = 4.0):
    # Fused arguments combined with a default value for the fused `y`.
    print(x, y)

def test_opt_args():
    """
    Call explicitly indexed specializations with and without the optional
    argument.

    >>> test_opt_args()
    3 4.0
    3 4.0
    3 4.0
    3 4.0
    """
    opt_args[int,  float](3)
    opt_args[int, double](3)
    opt_args[int,  float](3, 4.0)
    opt_args[int, double](3, 4.0)

class NormalClass(object):
    """Plain Python class with a fused-type method, specialized by indexing."""

    def method(self, cython.integral i):
        # Print the C type of the selected specialization and the value.
        print(cython.typeof(i), i)

def test_normal_class():
    """
    >>> test_normal_class()
    short 10
    """
    NormalClass().method[pure_cython.short](10)

def test_normal_class_refcount():
    """
    Indexing a bound fused method must not leak a reference to the instance.

    >>> test_normal_class_refcount()
    short 10
    0
    """
    import sys
    x = NormalClass()
    c = sys.getrefcount(x)
    x.method[pure_cython.short](10)
    print(sys.getrefcount(x) - c)

def test_fused_declarations(cython.integral i, cython.floating f):
    """
    Local variables declared with a fused type take the same specialization
    as the function's fused arguments.

    >>> test_fused_declarations[pure_cython.short, pure_cython.float](5, 6.6)
    short
    float
    25 43.56
    >>> test_fused_declarations[pure_cython.long, pure_cython.double](5, 6.6)
    long
    double
    25 43.56
    """
    cdef cython.integral squared_int = i * i
    cdef cython.floating squared_float = f * f

    assert cython.typeof(squared_int) == cython.typeof(i)
    assert cython.typeof(squared_float) == cython.typeof(f)

    print(cython.typeof(squared_int))
    print(cython.typeof(squared_float))
    print('%d %.2f' % (squared_int, squared_float))

def test_sizeof_fused_type(fused_type1 b):
    """
    sizeof() works on both a fused-typed value and the fused type itself.

    >>> test_sizeof_fused_type[pure_cython.double](11.1)
    """
    t = sizeof(b), sizeof(fused_type1), sizeof(double)
    assert t[0] == t[1] == t[2], t

def get_array(itemsize, format):
    """
    Return a 1D cython.view.array of 10 items with the given item size and
    buffer format string (e.g. 'd' or 'f'). Only items 5 and 6 are assigned
    (5.0 and 6.0).
    """
    result = array((10,), itemsize, format)
    result[5] = 5.0
    result[6] = 6.0
    return result

def get_intc_array():
    """
    Return a 1D cython.view.array of 10 C ints (format 'i'). Only items 5 and
    6 are assigned (5 and 6).
    """
    result = array((10,), sizeof(int), 'i')
    result[5] = 5
    result[6] = 6
    return result

def test_fused_memslice_dtype(cython.floating[:] array):
    """
    Dispatch on the dtype of a fused memoryview argument, including slicing,
    None handling, and casting a scalar's address to a memoryview.

    Note: the np.ndarray dtype test is in numpy_test.

    >>> import cython
    >>> sorted(test_fused_memslice_dtype.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype[cython.double](get_array(8, 'd'))
    double[:] double[:] 5.0 6.0
    >>> test_fused_memslice_dtype[cython.float](get_array(4, 'f'))
    float[:] float[:] 5.0 6.0

    # None should evaluate to *something* (currently the first
    # in the list, but this shouldn't be a hard requirement)
    >>> test_fused_memslice_dtype(None)
    float[:]
    >>> test_fused_memslice_dtype[cython.double](None)
    double[:]
    """
    if array is None:
        print(cython.typeof(array))
        return
    cdef cython.floating[:] otherarray = array[0:100:1]
    print(cython.typeof(array), cython.typeof(otherarray),
          array[5], otherarray[6])
    # Cast the address of a fused scalar to a one-element memoryview.
    cdef cython.floating value
    cdef cython.floating[:] test_cast = <cython.floating[:1:1]>&value

def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
    """
    Repeating one fused type across arguments forces both memoryviews to
    share a single specialization; mixed dtypes are rejected.

    Note: the np.ndarray dtype test is in numpy_test.

    >>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:]
    >>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
    float[:] float[:]
    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
    Traceback (most recent call last):
    ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
    """
    print(cython.typeof(array1), cython.typeof(array2))

def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
                                         fused_type3[:] array3):
    """
    A repeated fused type combined with an independent one yields the cross
    product of specializations.

    Note: the np.ndarray dtype test is in numpy_test.

    >>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
    ['double|double', 'double|int', 'float|double', 'float|int']

    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:] double[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
    double[:] double[:] int[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
    float[:] float[:] int[:]
    """
    print(cython.typeof(array1), cython.typeof(array2), cython.typeof(array3))

def test_fused_const_memslice_dtype_repeated(const cython.floating[:] array1, cython.floating[:] array2):
    """Test fused-type memoryviews where one of the repeated uses is const.

    >>> sorted(test_fused_const_memslice_dtype_repeated.__signatures__)
    ['double', 'float']

    >>> test_fused_const_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
    const double[:] double[:]
    >>> test_fused_const_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
    const float[:] float[:]
    >>> test_fused_const_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
    Traceback (most recent call last):
    ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
    """
    print(cython.typeof(array1), cython.typeof(array2))

def test_cython_numeric(cython.numeric arg):
    """
    Test to see whether complex numbers have their utility code declared
    properly.

    >>> test_cython_numeric(10.0 + 1j)
    double complex (10+1j)
    """
    print(cython.typeof(arg), arg)


cdef fused int_t:
    int

def test_pylong(int_t i):
    """
    A single-member fused type accepts Python ints (and Python 2 longs), and
    indexing it with a type that is not a member raises KeyError.

    >>> import cython
    >>> try:    long = long # Python 2
    ... except NameError: long = int  # Python 3

    >>> test_pylong[int](int(0))
    int
    >>> test_pylong[cython.int](int(0))
    int
    >>> test_pylong(int(0))
    int

    >>> test_pylong[int](long(0))
    int
    >>> test_pylong[cython.int](long(0))
    int
    >>> test_pylong(long(0))
    int

    >>> test_pylong[cython.long](0)  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    KeyError: ...
    """
    print(cython.typeof(i))


cdef fused ints_t:
    int
    long

cdef _test_index_fused_args(cython.floating f, ints_t i):
    print(cython.typeof(f), cython.typeof(i))

def test_index_fused_args(cython.floating f, ints_t i):
    """
    Index a fused cdef function with the caller's own (already specialized)
    fused types.

    >>> import cython
    >>> test_index_fused_args[cython.double, cython.int](2.0, 3)
    double int
    """
    _test_index_fused_args[cython.floating, ints_t](f, i)

cdef _test_index_const_fused_args(const cython.floating f, const ints_t i):
    print((cython.typeof(f), cython.typeof(i)))

def test_index_const_fused_args(const cython.floating f, const ints_t i):
    """Test indexing a function implementation with const fused-type arguments.

    >>> import cython
    >>> test_index_const_fused_args[cython.double, cython.int](2.0, 3)
    ('const double', 'const int')
    """
    _test_index_const_fused_args[cython.floating, ints_t](f, i)


def test_composite(fused_composite x):
    """
    A fused type composed of other fused types: char* values are returned
    unchanged, numbers are doubled.

    >>> print(test_composite(b'a').decode('ascii'))
    a
    >>> test_composite(3)
    6
    >>> test_composite(3.0)
    6.0
    """
    if fused_composite is string_t:
        return x
    else:
        return 2 * x


cdef cdef_func_const_fused_arg(const cython.floating val,
                               const fused_type1 * ptr_to_const,
                               const (cython.floating *) const_ptr):
    # Print each argument's value together with its specialized C type, then
    # perform the assignments that the const qualifiers permit.
    print((val, cython.typeof(val)))
    print((ptr_to_const[0], cython.typeof(ptr_to_const[0])))
    print((const_ptr[0], cython.typeof(const_ptr[0])))

    ptr_to_const = NULL  # pointer is not const, value is const
    const_ptr[0] = 0.0  # pointer is const, value is not const

def test_cdef_func_with_const_fused_arg():
    """Test a cdef function with const fused-type arguments.

    >>> test_cdef_func_with_const_fused_arg()
    (0.0, 'const float')
    (1, 'const int')
    (2.0, 'float')
    """
    cdef float arg0 = 0.0
    cdef int arg1 = 1
    cdef float arg2 = 2.0
    cdef_func_const_fused_arg(arg0, &arg1, &arg2)


cdef in_check_1(just_float x):
    return just_float in floating

cdef in_check_2(just_float x, floating y):
    # The "floating" on the right-hand side of the `in` expression must not be
    # specialized, so that the membership test still works.
    return just_float in floating

cdef in_check_3(floating x):
    # The "floating" on the left-hand side of the `in` expression should be
    # specialized, but the one on the right-hand side should not (so that the
    # test can still work).
    return floating in floating

def test_fused_in_check():
    """
    It should be possible to use fused types in "x in fused_type" expressions
    even if that fused type is specialized in the function.

    >>> test_fused_in_check()
    True
    True
    True
    True
    """
    print(in_check_1(1.0))
    print(in_check_2(1.0, 2.0))
    print(in_check_2[float, double](1.0, 2.0))
    print(in_check_3[float](1.0))


### See GH3642: the presence of a cdef declaration inside "unrelated" used to
### cause a type to be inferred incorrectly.
cdef unrelated(cython.floating x):
    cdef cython.floating t = 1
    return t

cdef handle_float(float* x):
    return 'float'

cdef handle_double(double* x):
    return 'double'

def convert_to_ptr(cython.floating x):
    """
    Taking the address of a fused argument yields a pointer of the
    specialized type (float* or double*).

    >>> convert_to_ptr(1.0)
    'double'
    >>> convert_to_ptr['double'](1.0)
    'double'
    >>> convert_to_ptr['float'](1.0)
    'float'
    """
    if cython.floating is float:
        return handle_float(&x)
    elif cython.floating is double:
        return handle_double(&x)

def constfused_with_typedef(const fused_with_typedef[:] x):
    """
    Const memoryview over a fused type whose members are typedefs.

    >>> constfused_with_typedef(get_array(8, 'd'))
    5.0
    >>> constfused_with_typedef(get_intc_array())
    5
    """
    return x[5]

def constfused_typedef_name_clashes(const fused_with_typedef[:] x, fused_misleading_name[:] y):
    """
    This deliberately ends up with two typedefs that generate the same name
    in the dispatch code (so one of them has to be numbered for it to work).
    It is mainly a compile test; the runtime check is only a token one.

    >>> constfused_typedef_name_clashes(get_intc_array(), get_array(4, 'f'))
    (5, 5.0)
    """
    return x[5], y[5]

cdef double get_double():
    return 1.0

cdef float get_float():
    return 0.0

cdef call_func_pointer(cython.floating (*f)()):
    # The specialization is selected from the return type of the function
    # pointer that is passed in.
    return f()

def test_fused_func_pointer():
    """
    >>> test_fused_func_pointer()
    1.0
    0.0
    """
    print(call_func_pointer(get_double))
    print(call_func_pointer(get_float))

cdef double get_double_from_int(int i):
    return i

cdef call_func_pointer_with_1(cython.floating (*f)(cython.integral)):
    # Both the return type and the argument type of the function pointer are
    # fused; call it with the integer 1.
    return f(1)

def test_fused_func_pointer2():
    """
    >>> test_fused_func_pointer2()
    1.0
    """
    print(call_func_pointer_with_1(get_double_from_int))

cdef call_function_that_calls_fused_pointer(object (*f)(cython.floating (*)(cython.integral))):
    # Only the (double, int) specialization has a matching function
    # (get_double_from_int) to pass on; every other specialization returns None.
    if cython.floating is double and cython.integral is int:
        return 5 * f(get_double_from_int)
    else:
        return None  # practically it's hard to make this kind of function useful...

def test_fused_func_pointer_multilevel():
    """
    >>> test_fused_func_pointer_multilevel()
    5.0
    None
    """
    print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[double, int]))
    print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[float, int]))

cdef null_default(cython.floating x, cython.floating *x_minus_1_out=NULL):
    # In C++ a void* cannot be assigned to a regular pointer, so the NULL
    # default has to be set up without going through a void* temporary.
    if x_minus_1_out:
        x_minus_1_out[0] = x-1
    return x

def test_null_default():
    """
    A fused pointer argument with a NULL default, called with and without it.

    >>> test_null_default()
    2.0 1.0
    2.0
    2.0 1.0
    2.0
    """
    cdef double xd = 2.
    cdef double xd_minus_1
    result = null_default(xd, &xd_minus_1)
    print(result, xd_minus_1)
    result = null_default(xd)
    print(result)

    cdef float xf = 2.
    cdef float xf_minus_1
    result = null_default(xf, &xf_minus_1)
    print(result, xf_minus_1)
    result = null_default(xf)
    print(result)


cdef cython.numeric fused_numeric_default(int a = 1, cython.numeric x = 0):
    # Both arguments have defaults; the fused one has an integer default of 0
    # that must convert to every cython.numeric specialization.
    return x + a

def test_fused_numeric_default(int a, x):
    """
    Call the int, float and double-complex specializations, omitting whichever
    arguments equal their defaults.

    >>> test_fused_numeric_default(1, 0)
    [1, 1.0, (1+0j)]

    >>> test_fused_numeric_default(1, 2)
    [3, 3.0, (3+0j)]

    >>> test_fused_numeric_default(2, 0)
    [2, 2.0, (2+0j)]

    >>> test_fused_numeric_default(2, 1)
    [3, 3.0, (3+0j)]
    """
    result = []

    # Specializations must be chosen at compile time by indexing, so each one
    # gets its own explicit branch chain.
    if a == 1 and x == 0:
        result.append(fused_numeric_default[int]())
    elif x == 0:
        result.append(fused_numeric_default[int](a))
    elif a == 1:
        result.append(fused_numeric_default[int](1, x))
    else:
        result.append(fused_numeric_default[int](a, x))

    if a == 1 and x == 0:
        result.append(fused_numeric_default[float]())
    elif x == 0:
        result.append(fused_numeric_default[float](a))
    elif a == 1:
        result.append(fused_numeric_default[float](1, x))
    else:
        result.append(fused_numeric_default[float](a, x))

    if a == 1 and x == 0:
        result.append(fused_numeric_default[cython.doublecomplex]())
    elif x == 0:
        result.append(fused_numeric_default[cython.doublecomplex](a))
    elif a == 1:
        result.append(fused_numeric_default[cython.doublecomplex](1, x))
    else:
        result.append(fused_numeric_default[cython.doublecomplex](a, x))

    return result

# NOTE: this copy of the file was captured from the GitVerse web page, which
# appended a cookie-consent notice (cookies are used per the site's Privacy and
# Cookie Policies; accepting consents to processing of personal data by
# SberTech JSC; cookies can be disabled in the browser settings).
# That notice is not part of the test module.