1
from __future__ import annotations
3
from dataclasses import dataclass
5
import torchgen.api.types as api_types
6
from torchgen.api import cpp, structured
7
from torchgen.api.types import (
17
from torchgen.model import (
28
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for an out-variant schema.

    The name has the form ``ufunc_{func.name.name}_{dispatch_key}``.

    Args:
        func: Schema to name; ``func.is_out_fn()`` must be true.
        dispatch_key: Dispatch key interpolated into the name.

    Raises:
        AssertionError: If ``func`` is not an out function.
    """
    # Raise explicitly (rather than via ``assert``) so the precondition is
    # still enforced when Python runs with -O; the exception type and message
    # are unchanged for callers.
    if not func.is_out_fn():
        raise AssertionError(
            "ufunc.kernel_name should only be invoked on out schemas"
        )
    return f"ufunc_{func.name.name}_{dispatch_key}"
33
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for the out variant of group ``g``.

    Delegates to ``schema_kernel_name`` with ``g.out.func``.
    """
    return schema_kernel_name(g.out.func, dispatch_key)
42
# Map schema type `t` to the NamedCType for a dispatch-stub parameter named
# `binds`; the return annotation also allows None.
#
# Visible behavior:
#   * cpp.valuetype_type(t, binds=binds, symint=False) is evaluated into `r`.
#   * A Scalar `t` yields NamedCType(binds, ConstRefCType(BaseCType(scalarT))).
#   * The last visible statement raises AssertionError("unrecognized type ...").
#
# NOTE(review): this listing is partial -- the embedded original line numbers
# skip 43, 45-47 and 51-52, so how `r` is consumed and what the Tensor branch
# does are not visible here. Confirm against the full file before editing.
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
44
r = cpp.valuetype_type(t, binds=binds, symint=False)
48
if t == BaseType(BaseTy.Scalar):
49
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
50
elif t == BaseType(BaseTy.Tensor):
53
raise AssertionError(f"unrecognized type {repr(t)}")
56
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Return the opmath counterpart of ``scalar_t``.

    Only ``api_types.scalar_t`` is supported; it maps to ``api_types.opmath_t``.

    Raises:
        NotImplementedError: For any other ``scalar_t``.
    """
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    # Name the offending type so an unsupported caller is easy to diagnose.
    raise NotImplementedError(f"no opmath type known for {scalar_t!r}")
67
# Map schema type `t` to the NamedCType for a ufunctor constructor parameter
# named `binds`.
#
# Visible behavior:
#   * cpp.valuetype_type(t, binds=binds, symint=False) is evaluated into `r`.
#   * Scalar and Tensor `t` both yield
#     NamedCType(binds, BaseCType(opmath_type(scalar_t))), i.e. the opmath
#     type derived from `scalar_t` rather than `scalar_t` itself.
#   * The last visible statement raises AssertionError("unrecognized type ...").
#
# NOTE(review): this listing is partial -- the embedded original line numbers
# skip 69-71 and 76, so how `r` is consumed is not visible here. Confirm
# against the full file before editing.
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
68
r = cpp.valuetype_type(t, binds=binds, symint=False)
72
if t == BaseType(BaseTy.Scalar):
73
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
74
elif t == BaseType(BaseTy.Tensor):
75
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
77
raise AssertionError(f"unrecognized type {repr(t)}")
84
# Map schema type `t` to the NamedCType for a ufunctor apply parameter named
# `binds`.
#
# Visible behavior:
#   * A Tensor `t` yields NamedCType(binds, BaseCType(scalar_t)).
#   * The last visible statement raises AssertionError("unrecognized type ...").
#
# NOTE(review): this listing is partial -- the embedded original line numbers
# skip 86 (the end of the signature, including any return annotation) and 89.
# Confirm against the full file before editing.
def ufunctor_apply_type(
85
t: Type, *, binds: ArgName, scalar_t: BaseCppType
87
if t == BaseType(BaseTy.Tensor):
88
return NamedCType(binds, BaseCType(scalar_t))
90
raise AssertionError(f"unrecognized type {repr(t)}")
96
# Map schema type `t` to the NamedCType for a ufunc parameter named `binds`.
#
# Visible behavior:
#   * cpp.valuetype_type(t, binds=binds, symint=False) is evaluated into `r`.
#   * Scalar and Tensor `t` both yield NamedCType(binds, compute_t).
#   * The last visible statement raises AssertionError("unrecognized type ...").
#
# NOTE(review): this listing is partial -- the embedded original line numbers
# skip 98-100 and 105, so how `r` is consumed is not visible here. Confirm
# against the full file before editing.
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
97
r = cpp.valuetype_type(t, binds=binds, symint=False)
101
if t == BaseType(BaseTy.Scalar):
102
return NamedCType(binds, compute_t)
103
elif t == BaseType(BaseTy.Tensor):
104
return NamedCType(binds, compute_t)
106
raise AssertionError(f"unrecognized type {repr(t)}")
109
# Build the Binding for schema argument `a` as a ufunctor constructor
# parameter (per the return annotation).
#
# The only visible piece of the construction is the `nctype` keyword, computed
# by ufunctor_ctor_type from a.type, a.name and scalar_t.
#
# NOTE(review): this listing is partial -- original lines 110 and 112-117 are
# absent, so the call receiving `nctype=` and its other arguments are not
# visible here. Confirm against the full file before editing.
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
111
nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
118
# Build the Binding for schema argument `a` as a ufunctor apply parameter
# (per the return annotation).
#
# The only visible piece of the construction is the `nctype` keyword, computed
# by ufunctor_apply_type from a.type, a.name and scalar_t.
#
# NOTE(review): this listing is partial -- original lines 119 and 121-126 are
# absent, so the call receiving `nctype=` and its other arguments are not
# visible here. Confirm against the full file before editing.
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
120
nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
127
# Build the Binding for schema argument `a` as a ufunc parameter (per the
# return annotation).
#
# The only visible piece of the construction is the `nctype` keyword, computed
# by ufunc_type from a.type, a.name and compute_t.
#
# NOTE(review): this listing is partial -- original lines 128 and 130-135 are
# absent, so the call receiving `nctype=` and its other arguments are not
# visible here. Confirm against the full file before editing.
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
129
nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
136
# Frozen dataclass returned by ufunctor_arguments, which constructs it with
# the keyword arguments `ctor=` and `apply=`.
#
# NOTE(review): the class body (original lines 138 onward) is absent from this
# listing, so the field declarations and their types are not visible here.
# Confirm against the full file before editing.
@dataclass(frozen=True)
137
class UfunctorBindings:
159
# Split the non-out arguments of g.functional.func into ufunctor constructor
# bindings (`ctor`) and apply bindings (`apply`), returned together as a
# UfunctorBindings.
#
# Visible steps, in listing order:
#   * iterate g.functional.func.arguments.flat_non_out;
#   * test a.type.is_tensor_like();
#   * when scalar_tensor_idx == 0, append the argument's ctor binding and
#     reset scalar_tensor_idx to None;
#   * when scalar_tensor_idx is not None, decrement it;
#   * append the argument's apply binding to `apply`;
#   * append the argument's ctor binding to `ctor`;
#   * assert scalar_tensor_idx is None, then return the bindings.
#
# NOTE(review): this listing is partial and has lost its indentation -- the
# embedded original line numbers skip 162-163, 167, 170 and 174, so the
# initialisation of `ctor`/`apply` and the exact nesting of the steps above
# (which of them sit in else-branches) are not visible here. Confirm against
# the full file before editing.
def ufunctor_arguments(
160
g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
161
) -> UfunctorBindings:
164
for a in g.functional.func.arguments.flat_non_out:
165
if a.type.is_tensor_like():
166
if scalar_tensor_idx == 0:
168
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
169
scalar_tensor_idx = None
171
if scalar_tensor_idx is not None:
172
scalar_tensor_idx -= 1
173
apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
175
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
176
assert scalar_tensor_idx is None
177
return UfunctorBindings(ctor=ctor, apply=apply)
189
# Produce ufunc bindings for group `g` (list[Binding] per the return
# annotation).
#
# Visible pieces: ufunc_argument(a, compute_t=compute_t) is evaluated for each
# `a` in g.functional.func.arguments.flat_non_out.
#
# NOTE(review): this listing is partial -- original line 190 and the lines
# after 192 are absent, so the expression enclosing this comprehension is not
# visible here. Confirm against the full file before editing.
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
191
ufunc_argument(a, compute_t=compute_t)
192
for a in g.functional.func.arguments.flat_non_out
201
# Produce stub bindings for group `g` (list[Binding] per the return
# annotation).
#
# Visible pieces: iterate g.out.func.arguments.flat_non_out, keep only
# arguments whose type is NOT tensor-like, and for each kept argument iterate
# the results `r` of structured.argument(a).
#
# NOTE(review): this listing is partial -- original lines 202-205 and the
# lines after 208 are absent, so the expression produced for each `r` is not
# visible here. Confirm against the full file before editing.
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
206
for a in g.out.func.arguments.flat_non_out
207
if not a.type.is_tensor_like()
208
for r in structured.argument(a)