cython

Форк
0
/
parallel.pyx 
167 строк · 4.4 Кб
1
# tag: run
2
# tag: openmp
3

4
cimport cython.parallel
5
from cython.parallel import prange, threadid
6
cimport openmp
7
from libc.stdlib cimport malloc, free
8

9
openmp.omp_set_nested(1)
10

11
cdef int forward(int x) nogil:
12
    return x
13

14
def test_parallel():
15
    """
16
    >>> test_parallel()
17
    """
18
    cdef int maxthreads = openmp.omp_get_max_threads()
19
    cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)
20

21
    if buf == NULL:
22
        raise MemoryError
23

24
    with nogil, cython.parallel.parallel():
25
        buf[threadid()] = threadid()
26
        # Recognise threadid() also when it's used in a function argument.
27
        # See https://github.com/cython/cython/issues/3594
28
        buf[forward(cython.parallel.threadid())] = forward(threadid())
29

30
    for i in range(maxthreads):
31
        assert buf[i] == i
32

33
    free(buf)
34

35
cdef int get_num_threads() noexcept with gil:
36
    print "get_num_threads called"
37
    return 3
38

39
cdef bint check_size(int size) nogil:
40
    return size > 5
41

42
def test_num_threads(int size):
43
    """
44
    >>> test_num_threads(6)
45
    1
46
    get_num_threads called
47
    3
48
    get_num_threads called
49
    3
50
    get_num_threads called
51
    3
52
    get_num_threads called
53
    3
54
    >>> test_num_threads(4)
55
    1
56
    get_num_threads called
57
    1
58
    get_num_threads called
59
    1
60
    get_num_threads called
61
    1
62
    get_num_threads called
63
    1
64
    """
65
    cdef int dyn = openmp.omp_get_dynamic()
66
    cdef int num_threads
67
    cdef int *p = &num_threads
68

69
    openmp.omp_set_dynamic(0)
70

71
    with nogil, cython.parallel.parallel(num_threads=1):
72
        p[0] = openmp.omp_get_num_threads()
73

74
    print num_threads
75

76
    with nogil, cython.parallel.parallel(num_threads=get_num_threads(), use_threads_if=size > 5):
77
        p[0] = openmp.omp_get_num_threads()
78

79
    print num_threads
80

81
    # Checks that temporary variables are released properly
82
    with nogil, cython.parallel.parallel(num_threads=get_num_threads(), use_threads_if=check_size(size)):
83
        p[0] = openmp.omp_get_num_threads()
84

85
    print num_threads
86

87
    cdef int i
88
    # Checks that temporary variables are released properly
89
    for i in prange(1, nogil=True, num_threads=get_num_threads(), use_threads_if=check_size(size)):
90
        p[0] = openmp.omp_get_num_threads()
91
        break
92

93
    print num_threads
94

95
    num_threads = 0xbad
96
    for i in prange(1, nogil=True, num_threads=get_num_threads(), use_threads_if=size > 5):
97
        p[0] = openmp.omp_get_num_threads()
98
        break
99

100
    openmp.omp_set_dynamic(dyn)
101

102
    return num_threads
103

104
'''
105
def test_parallel_catch():
106
    """
107
    >>> test_parallel_catch()
108
    True
109
    """
110
    cdef int i, j, num_threads
111
    exceptions = []
112

113
    for i in prange(100, nogil=True, num_threads=4):
114
        num_threads = openmp.omp_get_num_threads()
115

116
        with gil:
117
            try:
118
                for j in prange(100, nogil=True):
119
                    if i + j > 60:
120
                        with gil:
121
                            raise Exception("try and catch me if you can!")
122
            except Exception, e:
123
                exceptions.append(e)
124
                break
125

126
    print len(exceptions) == num_threads
127
    assert len(exceptions) == num_threads, (len(exceptions), num_threads)
128
'''
129

130

131
cdef void parallel_exception_checked_function(int* ptr, int id) except * nogil:
132
    # requires the GIL after each call
133
    ptr[0] = id;
134

135
cdef void parallel_call_exception_checked_function_impl(int* arr, int num_threads) nogil:
136
    # Inside a nogil function, parallel can't be sure that the GIL has been released.
137
    # Therefore Cython must release the GIL itself.
138
    # Otherwise, we can experience cause lock-ups if anything inside it acquires the GIL
139
    # (since if any other thread has finished, it will be holding the GIL).
140
    #
141
    # An equivalent test with prange is in "sequential_parallel.pyx"
142
    with cython.parallel.parallel(num_threads=num_threads):
143
        parallel_exception_checked_function(arr+threadid(), threadid())
144

145

146
def test_parallel_call_exception_checked_function():
147
    """
148
    test_parallel_call_exception_checked_function()
149
    """
150
    cdef int maxthreads = openmp.omp_get_max_threads()
151
    cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)
152

153
    if buf == NULL:
154
        raise MemoryError
155

156
    try:
157
        # Note we *don't* release the GIL here
158
        parallel_call_exception_checked_function_impl(buf, maxthreads)
159

160
        for i in range(maxthreads):
161
            assert buf[i] == i
162
    finally:
163
        free(buf)
164

165

166
OPENMP_PARALLEL = True
167
include "sequential_parallel.pyx"
168

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.