# cython (fork)
# UFuncs.py - 311 lines, 10.6 KB

from . import (
    Nodes,
    ExprNodes,
    FusedNode,
    Naming,
)
from .Errors import error
from . import PyrexTypes
from .UtilityCode import CythonUtilityCode
from .Code import TempitaUtilityCode, UtilityCode
from .Visitor import TreeVisitor
from . import Symtab


class _FindCFuncDefNode(TreeVisitor):
    """
    Tree visitor that records a CFuncDefNode found in a tree.

    The tree is assumed to contain only one CFuncDefNode.  Once a node has
    been recorded, generic nodes are no longer descended into.
    Calling an instance on a tree returns the recorded node (None if none).
    """

    found_node = None

    def visit_Node(self, node):
        # Keep descending only while nothing has been recorded yet.
        if not self.found_node:
            self.visitchildren(node)

    def visit_CFuncDefNode(self, node):
        self.found_node = node

    def __call__(self, tree):
        self.visit(tree)
        return self.found_node


def get_cfunc_from_tree(tree):
    """Return the CFuncDefNode found in ``tree`` (None if there is none)."""
    finder = _FindCFuncDefNode()
    return finder(tree)


class _ArgumentInfo:
43
    """
44
    Everything related to defining an input/output argument for a ufunc
45

46
    type  - PyrexType
47
    type_constant  - str such as "NPY_INT8" representing numpy dtype constants
48
    injected_typename - str representing a name that can be used to look up the type
49
                        in Cython code
50
    """
51

52
    def __init__(self, type, type_constant, injected_typename):
53
        self.type = type
54
        self.type_constant = type_constant
55
        self.injected_typename = injected_typename
56

57

58
class UFuncConversion:
    """
    Collects what is needed to turn one C function (or one specialisation of
    a fused C function) into a loop of a NumPy ufunc.

    node - the CFuncDefNode being converted

    On construction, dtype information for every input argument and every
    output is computed and stored in ``in_definitions`` and
    ``out_definitions`` (lists of _ArgumentInfo).
    """

    def __init__(self, node):
        self.node = node
        self.global_scope = node.local_scope.global_scope()

        # Prefix for the names under which argument types are injected into
        # the generated Cython code; underscores are appended until the
        # function's own cname no longer starts with it.
        self.injected_typename = "ufunc_typename"
        while self.node.entry.cname.startswith(self.injected_typename):
            self.injected_typename += "_"
        # every injected type name handed out below, inputs first
        self.injected_types = []
        self.in_definitions = self.get_in_type_info()
        self.out_definitions = self.get_out_type_info()

    def _handle_typedef_type_constant(self, type_, macro_name):
        """
        Request the "UFuncTypedef" Tempita utility code for ``type_`` and
        return the name ``__Pyx_typedef_ufunc_<substituted cname>``
        (presumably defined by that template in UFuncs_C.c).

        macro_name - the __PYX_GET_NPY_*_TYPE macro passed to the template
        """
        decl = type_.empty_declaration_code()
        # Double existing underscores before turning spaces into underscores,
        # so e.g. "unsigned int" and "unsigned_int" map to different names.
        substituted_cname = decl.strip().replace('_', '__').replace(' ', '_')
        context = dict(
            type_substituted_cname=substituted_cname,
            macro_name=macro_name,
            type_cname=decl,
        )
        self.global_scope.use_utility_code(
            TempitaUtilityCode.load(
                'UFuncTypedef',
                'UFuncs_C.c',
                context=context
            ))
        return f"__Pyx_typedef_ufunc_{substituted_cname}"

    def _get_type_constant(self, pos, type_):
        """
        Return the C expression naming the NumPy dtype for ``type_``.

        Complex, integer and floating-point types go through
        _handle_typedef_type_constant with the matching macro; Python
        objects map to "NPY_OBJECT".  For ``bint`` (also behind typedefs)
        and any other unsupported type an error is reported at ``pos`` and
        None is returned.
        """
        base_type = type_
        # Resolve the whole typedef chain, not just one level, so that a
        # typedef of a typedef of bint is rejected just like bint itself.
        while base_type.is_typedef:
            base_type = base_type.typedef_base_type
        base_type = PyrexTypes.remove_cv_ref(base_type)
        if base_type is PyrexTypes.c_bint_type:
            # TODO - this would be nice but not obvious it works
            error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)
            return
        if type_.is_complex:
            return self._handle_typedef_type_constant(
                type_,
                "__PYX_GET_NPY_COMPLEX_TYPE")
        elif type_.is_int:
            signed = ""
            if type_.signed == PyrexTypes.SIGNED:
                signed = "S"
            elif type_.signed == PyrexTypes.UNSIGNED:
                signed = "U"
            return self._handle_typedef_type_constant(
                type_,
                f"__PYX_GET_NPY_{signed}INT_TYPE")
        elif type_.is_float:
            return self._handle_typedef_type_constant(
                type_,
                "__PYX_GET_NPY_FLOAT_TYPE")
        elif type_.is_pyobject:
            return "NPY_OBJECT"
        # TODO possible NPY_BOOL to bint but it needs a cast?
        # TODO NPY_DATETIME, NPY_TIMEDELTA, NPY_STRING, NPY_UNICODE and maybe NPY_VOID might be handleable
        error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)

    def get_in_type_info(self):
        """Return an _ArgumentInfo for each argument of the function, in order."""
        definitions = []
        for n, arg in enumerate(self.node.args):
            injected_typename = f"{self.injected_typename}_in_{n}"
            self.injected_types.append(injected_typename)
            type_const = self._get_type_constant(self.node.pos, arg.type)
            definitions.append(_ArgumentInfo(arg.type, type_const, injected_typename))
        return definitions

    def get_out_type_info(self):
        """
        Return an _ArgumentInfo for each output: one per component of a
        ctuple return type, otherwise a single one for the return type.
        """
        return_type = self.node.return_type
        if return_type.is_ctuple:
            components = return_type.components
        else:
            components = [return_type]
        definitions = []
        for n, component_type in enumerate(components):
            injected_typename = f"{self.injected_typename}_out_{n}"
            self.injected_types.append(injected_typename)
            type_const = self._get_type_constant(self.node.pos, component_type)
            definitions.append(
                _ArgumentInfo(component_type, type_const, injected_typename)
            )
        return definitions

    def generate_cy_utility_code(self):
        """
        Generate the Cython-level ufunc loop function for this specialisation.

        Loads the "UFuncDefinition" template from UFuncs.pyx, with the input
        and output types injected under their ``injected_typename`` names,
        marks the function entry as used and returns the resulting tree
        (built with ``entries_only=True``).
        """
        arg_types = [(a.injected_typename, a.type) for a in self.in_definitions]
        out_types = [(a.injected_typename, a.type) for a in self.out_definitions]
        context_types = dict(arg_types + out_types)
        self.node.entry.used = True

        ufunc_cname = self.global_scope.next_id(self.node.entry.name + "_ufunc_def")

        # The loop is flagged for calling without the GIL only when none of
        # its inputs or outputs is a Python object.
        will_be_called_without_gil = not any(
            t.is_pyobject for _, t in arg_types + out_types)

        context = dict(
            func_cname=ufunc_cname,
            in_types=arg_types,
            out_types=out_types,
            inline_func_call=self.node.entry.cname,
            nogil=self.node.entry.type.nogil,
            will_be_called_without_gil=will_be_called_without_gil,
            **context_types
        )

        # Separate module scope in which the converted function is declared
        # as an extern cfunction under its cname - presumably so that the
        # template code can call it by that name.
        ufunc_global_scope = Symtab.ModuleScope(
            "ufunc_module", None, self.global_scope.context
        )
        ufunc_global_scope.declare_cfunction(
            name=self.node.entry.cname,
            cname=self.node.entry.cname,
            type=self.node.entry.type,
            pos=self.node.pos,
            visibility="extern",
        )

        code = CythonUtilityCode.load(
            "UFuncDefinition",
            "UFuncs.pyx",
            context=context,
            from_scope=ufunc_global_scope,
        )

        tree = code.get_tree(entries_only=True)
        return tree

    def use_generic_utility_code(self):
        """Register the specialisation-independent C utility code with the module."""
        # use the invariant C utility code
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("UFuncsInit", "UFuncs_C.c")
        )
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("UFuncTypeHandling", "UFuncs_C.c")
        )
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("NumpyImportUFunc", "NumpyImportArray.c")
        )


def convert_to_ufunc(node):
    """
    Replace a C function (plain or fused) by the statements defining a ufunc.

    node - a CFuncDefNode, or a FusedCFuncDefNode whose ``node`` is a
           CFuncDefNode

    If ``node`` cannot be converted (it is not a C function, or it is a method
    of a cdef class), an error is reported and ``node`` is returned unchanged.
    Otherwise the module-scope entry for the function's name is removed and
    ``[node, <loop functions...>, <ufunc assignment>]`` is returned.
    """
    if isinstance(node, Nodes.CFuncDefNode):
        original_node = node
        specialisations = [node]
    elif isinstance(node, FusedNode.FusedCFuncDefNode) and isinstance(
        node.node, Nodes.CFuncDefNode
    ):
        original_node = node.node
        specialisations = node.nodes
    else:
        error(node.pos, "Only C functions can be converted to a ufunc")
        return node

    if original_node.local_scope.parent_scope.is_c_class_scope:
        error(node.pos, "Methods cannot currently be converted to a ufunc")
        return node

    converters = [UFuncConversion(n) for n in specialisations]
    if not converters:
        # Only reachable for a fused function without specialisations, which
        # is not expected.  Return the node unchanged like the other failure
        # paths instead of None (a None result presumably removes the node
        # when this is used from a tree transform).
        return node

    # Free the function's name in the module scope; generate_ufunc_initialization
    # re-declares it as the Python-level ufunc object.
    del converters[0].global_scope.entries[original_node.entry.name]
    # the generic utility code is generic, so there's no reason to do it multiple times
    converters[0].use_generic_utility_code()
    return [node] + _generate_stats_from_converters(converters, original_node)


def generate_ufunc_initialization(converters, cfunc_nodes, original_node):
    """
    Build the module-level assignment that creates the ufunc object.

    converters    - UFuncConversion instances, one per specialisation
    cfunc_nodes   - the generated loop functions, in the same order
    original_node - the user's C function; supplies the name, doc and pos

    Registers the "UFuncConsts" utility code (function, type and data tables)
    and returns a SingleAssignmentNode that binds the function's name in the
    module scope to the result of PyUFunc_FromFuncAndData(...).
    """
    scope = converters[0].global_scope
    funcs_table = scope.next_id(Naming.pyrex_prefix + "funcs")
    types_table = scope.next_id(Naming.pyrex_prefix + "types")
    data_table = scope.next_id(Naming.pyrex_prefix + "data")

    # Flatten the dtype constants as [inputs..., outputs...] per
    # specialisation; every specialisation must have the same arity.
    all_constants = []
    n_in = n_out = None
    for conv in converters:
        ins = [info.type_constant for info in conv.in_definitions]
        if n_in is None:
            n_in = len(ins)
        assert n_in == len(ins)
        all_constants += ins

        outs = [info.type_constant for info in conv.out_definitions]
        if n_out is None:
            n_out = len(outs)
        assert n_out == len(outs)
        all_constants += outs

    loop_cnames = [cfunc.entry.cname for cfunc in cfunc_nodes]

    scope.use_utility_code(
        TempitaUtilityCode.load(
            "UFuncConsts",
            "UFuncs_C.c",
            context={
                "ufunc_funcs_name": funcs_table,
                "func_cnames": loop_cnames,
                "ufunc_types_name": types_table,
                "type_constants": all_constants,
                "ufunc_data_name": data_table,
            },
        )
    )

    pos = original_node.pos
    name = original_node.entry.name
    doc = original_node.doc

    # The complete C argument list of PyUFunc_FromFuncAndData, pre-rendered.
    c_args = ", ".join([
        "%s()" % funcs_table,
        str(data_table),
        "%s()" % types_table,
        str(len(loop_cnames)),
        str(n_in),
        str(n_out),
        "PyUFunc_None",
        '"%s"' % name,
        str(doc.as_c_string_literal() if doc else "NULL"),
        "0",
    ])

    # A placeholder signature is used: the whole argument list is passed as a
    # single opaque constant expression.
    placeholder_sig = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dummy", PyrexTypes.c_void_ptr_type, None)],
    )
    rendered_args = ExprNodes.ConstNode(
        pos, type=PyrexTypes.c_void_ptr_type, value=c_args
    )
    call_node = ExprNodes.PythonCapiCallNode(
        pos,
        function_name="PyUFunc_FromFuncAndData",
        func_type=placeholder_sig,
        args=[rendered_args],
    )

    name_entry = scope.declare_var(name, PyrexTypes.py_object_type, pos)
    target = ExprNodes.NameNode(
        pos, name=name, type=PyrexTypes.py_object_type, entry=name_entry
    )
    return Nodes.SingleAssignmentNode(pos, lhs=target, rhs=call_node)


def _generate_stats_from_converters(converters, node):
    """
    Return the statements implementing the ufunc described by ``converters``.

    For each converter, its Cython loop function is generated, the
    CFuncDefNode is taken from the resulting tree, and that tree's utility
    code is merged into the module's utility code list.  The last statement
    is the ufunc-creating assignment from generate_ufunc_initialization
    (``node`` is the original C function).
    """
    loop_funcs = []
    for conv in converters:
        utility_tree = conv.generate_cy_utility_code()
        loop_func = get_cfunc_from_tree(utility_tree)
        # carry over utility code requested while building the loop function
        module_utility = conv.global_scope.utility_code_list
        module_utility.extend(utility_tree.scope.utility_code_list)
        loop_funcs.append(loop_func)

    init_stat = generate_ufunc_initialization(converters, loop_funcs, node)
    loop_funcs.append(init_stat)
    return loop_funcs

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give JSC "SberTech" consent to the processing of your personal data for the purpose of improving our website and the GitVerse service and making them more convenient to use.
#
# You can disable cookies yourself in your browser settings.