7
from .Errors import error
8
from . import PyrexTypes
9
from .UtilityCode import CythonUtilityCode
10
from .Code import TempitaUtilityCode, UtilityCode
11
from .Visitor import TreeVisitor
15
class _FindCFuncDefNode(TreeVisitor):
17
Finds the CFuncDefNode in the tree
19
The assumption is that there's only one CFuncDefNode
24
def visit_Node(self, node):
28
self.visitchildren(node)
30
def visit_CFuncDefNode(self, node):
31
self.found_node = node
33
def __call__(self, tree):
35
return self.found_node
38
def get_cfunc_from_tree(tree):
    """Locate the C function definition node inside *tree*.

    A new ``_FindCFuncDefNode`` visitor is built for every call, so no
    lookup state is carried over between calls.  The visitor works on the
    assumption that *tree* contains exactly one ``CFuncDefNode``; the value
    returned is whatever that visitor reports as its found node.
    """
    finder = _FindCFuncDefNode()
    found = finder(tree)
    return found
44
Everything related to defining an input/output argument for a ufunc
47
type_constant - str such as "NPY_INT8" representing numpy dtype constants
48
injected_typename - str representing a name that can be used to look up the type
52
def __init__(self, type, type_constant, injected_typename):
54
self.type_constant = type_constant
55
self.injected_typename = injected_typename
59
def __init__(self, node):
61
self.global_scope = node.local_scope.global_scope()
63
self.injected_typename = "ufunc_typename"
64
while self.node.entry.cname.startswith(self.injected_typename):
65
self.injected_typename += "_"
66
self.injected_types = []
67
self.in_definitions = self.get_in_type_info()
68
self.out_definitions = self.get_out_type_info()
70
def _handle_typedef_type_constant(self, type_, macro_name):
71
decl = type_.empty_declaration_code()
72
substituted_cname = decl.strip().replace('_', '__').replace(' ', '_')
74
type_substituted_cname=substituted_cname,
75
macro_name=macro_name,
78
self.global_scope.use_utility_code(
79
TempitaUtilityCode.load(
84
return f"__Pyx_typedef_ufunc_{substituted_cname}"
86
def _get_type_constant(self, pos, type_):
88
if base_type.is_typedef:
89
base_type = base_type.typedef_base_type
90
base_type = PyrexTypes.remove_cv_ref(base_type)
91
if base_type is PyrexTypes.c_bint_type:
92
# TODO - this would be nice but not obvious it works
93
error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)
96
return self._handle_typedef_type_constant(
98
"__PYX_GET_NPY_COMPLEX_TYPE")
101
if type_.signed == PyrexTypes.SIGNED:
103
elif type_.signed == PyrexTypes.UNSIGNED:
105
return self._handle_typedef_type_constant(
107
f"__PYX_GET_NPY_{signed}INT_TYPE")
109
return self._handle_typedef_type_constant(
111
"__PYX_GET_NPY_FLOAT_TYPE")
112
elif type_.is_pyobject:
114
# TODO possible NPY_BOOL to bint but it needs a cast?
115
# TODO NPY_DATETIME, NPY_TIMEDELTA, NPY_STRING, NPY_UNICODE and maybe NPY_VOID might be handleable
116
error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)
118
def get_in_type_info(self):
120
for n, arg in enumerate(self.node.args):
121
injected_typename = f"{self.injected_typename}_in_{n}"
122
self.injected_types.append(injected_typename)
123
type_const = self._get_type_constant(self.node.pos, arg.type)
124
definitions.append(_ArgumentInfo(arg.type, type_const, injected_typename))
127
def get_out_type_info(self):
128
if self.node.return_type.is_ctuple:
129
components = self.node.return_type.components
131
components = [self.node.return_type]
133
for n, type in enumerate(components):
134
injected_typename = f"{self.injected_typename}_out_{n}"
135
self.injected_types.append(injected_typename)
136
type_const = self._get_type_constant(self.node.pos, type)
138
_ArgumentInfo(type, type_const, injected_typename)
142
def generate_cy_utility_code(self):
143
arg_types = [(a.injected_typename, a.type) for a in self.in_definitions]
144
out_types = [(a.injected_typename, a.type) for a in self.out_definitions]
145
context_types = dict(arg_types + out_types)
146
self.node.entry.used = True
148
ufunc_cname = self.global_scope.next_id(self.node.entry.name + "_ufunc_def")
150
will_be_called_without_gil = not (any(t.is_pyobject for _, t in arg_types) or
151
any(t.is_pyobject for _, t in out_types))
154
func_cname=ufunc_cname,
157
inline_func_call=self.node.entry.cname,
158
nogil=self.node.entry.type.nogil,
159
will_be_called_without_gil=will_be_called_without_gil,
163
ufunc_global_scope = Symtab.ModuleScope(
164
"ufunc_module", None, self.global_scope.context
166
ufunc_global_scope.declare_cfunction(
167
name=self.node.entry.cname,
168
cname=self.node.entry.cname,
169
type=self.node.entry.type,
174
code = CythonUtilityCode.load(
178
from_scope = ufunc_global_scope,
179
#outer_module_scope=ufunc_global_scope,
182
tree = code.get_tree(entries_only=True)
185
def use_generic_utility_code(self):
186
# use the invariant C utility code
187
self.global_scope.use_utility_code(
188
UtilityCode.load_cached("UFuncsInit", "UFuncs_C.c")
190
self.global_scope.use_utility_code(
191
UtilityCode.load_cached("UFuncTypeHandling", "UFuncs_C.c")
193
self.global_scope.use_utility_code(
194
UtilityCode.load_cached("NumpyImportUFunc", "NumpyImportArray.c")
198
def convert_to_ufunc(node):
199
if isinstance(node, Nodes.CFuncDefNode):
200
if node.local_scope.parent_scope.is_c_class_scope:
201
error(node.pos, "Methods cannot currently be converted to a ufunc")
203
converters = [UFuncConversion(node)]
205
elif isinstance(node, FusedNode.FusedCFuncDefNode) and isinstance(
206
node.node, Nodes.CFuncDefNode
208
if node.node.local_scope.parent_scope.is_c_class_scope:
209
error(node.pos, "Methods cannot currently be converted to a ufunc")
211
converters = [UFuncConversion(n) for n in node.nodes]
212
original_node = node.node
214
error(node.pos, "Only C functions can be converted to a ufunc")
218
return # this path probably shouldn't happen
220
del converters[0].global_scope.entries[original_node.entry.name]
221
# the generic utility code is generic, so there's no reason to do it multiple times
222
converters[0].use_generic_utility_code()
223
return [node] + _generate_stats_from_converters(converters, original_node)
226
def generate_ufunc_initialization(converters, cfunc_nodes, original_node):
227
global_scope = converters[0].global_scope
228
ufunc_funcs_name = global_scope.next_id(Naming.pyrex_prefix + "funcs")
229
ufunc_types_name = global_scope.next_id(Naming.pyrex_prefix + "types")
230
ufunc_data_name = global_scope.next_id(Naming.pyrex_prefix + "data")
235
in_const = [d.type_constant for d in c.in_definitions]
236
if narg_in is not None:
237
assert narg_in == len(in_const)
239
narg_in = len(in_const)
240
type_constants.extend(in_const)
241
out_const = [d.type_constant for d in c.out_definitions]
242
if narg_out is not None:
243
assert narg_out == len(out_const)
245
narg_out = len(out_const)
246
type_constants.extend(out_const)
248
func_cnames = [cfnode.entry.cname for cfnode in cfunc_nodes]
251
ufunc_funcs_name=ufunc_funcs_name,
252
func_cnames=func_cnames,
253
ufunc_types_name=ufunc_types_name,
254
type_constants=type_constants,
255
ufunc_data_name=ufunc_data_name,
257
global_scope.use_utility_code(
258
TempitaUtilityCode.load("UFuncConsts", "UFuncs_C.c", context=context)
261
pos = original_node.pos
262
func_name = original_node.entry.name
263
docstr = original_node.doc
265
args_to_func = '%s(), %s, %s(), %s, %s, %s, PyUFunc_None, "%s", %s, 0' % (
273
docstr.as_c_string_literal() if docstr else "NULL",
276
call_node = ExprNodes.PythonCapiCallNode(
278
function_name="PyUFunc_FromFuncAndData",
279
# use a dummy type because it's honestly too fiddly
280
func_type=PyrexTypes.CFuncType(
281
PyrexTypes.py_object_type,
282
[PyrexTypes.CFuncTypeArg("dummy", PyrexTypes.c_void_ptr_type, None)],
286
pos, type=PyrexTypes.c_void_ptr_type, value=args_to_func
290
lhs_entry = global_scope.declare_var(func_name, PyrexTypes.py_object_type, pos)
291
assgn_node = Nodes.SingleAssignmentNode(
293
lhs=ExprNodes.NameNode(
294
pos, name=func_name, type=PyrexTypes.py_object_type, entry=lhs_entry
301
def _generate_stats_from_converters(converters, node):
303
for converter in converters:
304
tree = converter.generate_cy_utility_code()
305
ufunc_node = get_cfunc_from_tree(tree)
306
# merge in any utility code
307
converter.global_scope.utility_code_list.extend(tree.scope.utility_code_list)
308
stats.append(ufunc_node)
310
stats.append(generate_ufunc_initialization(converters, stats, node))