1
from __future__ import annotations
3
from torchgen.api import cpp
4
from torchgen.api.types import (
22
from torchgen.model import (
30
TensorOptionsArguments,
33
from torchgen.utils import assert_never
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT schema argument type into a structured-kernel C++ type.

    Value types are delegated to ``cpp.valuetype_type`` (called with
    ``symint=False``).  Everything that translation does not handle is mapped
    here to a reference/view C++ type.

    Args:
        t: The schema type to translate.
        mutable: Forwarded to ``cpp.valuetype_type`` and to the recursive
            element translations; the reference-type branches below do not
            consult it directly.
        binds: The argument name the resulting ``NamedCType`` binds.

    Returns:
        A ``NamedCType`` binding ``binds`` to the translated C++ type.

    Raises:
        AssertionError: If ``t`` is a ``BaseType`` other than Tensor/Scalar
            that the value-type translation did not handle, or if ``t`` is
            not a ``BaseType``/``OptionalType``/``ListType`` at all.
    """
    # Value types must be returned from here: the BaseType branch below only
    # knows Tensor and Scalar and treats any other base type as an error.
    # Assumes cpp.valuetype_type returns None for non-value types.
    r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        # Optional Tensor / Scalar / int[] have dedicated optional-ref types.
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Otherwise wrap the translated element type in an optional.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # Tensor lists, optional-Tensor lists, int[] and Dimname[] have
        # dedicated list-ref types.
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        elif str(t.elem) == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        # Otherwise view the translated element type as an ArrayRef.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the structured-kernel C++ type of argument ``a``, bound to ``binds``.

    The argument's mutability is taken from ``a.is_write``.
    """
    schema_type, is_mutable = a.type, a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
    """Translate one schema argument into the Bindings it contributes.

    Args:
        a: A plain ``Argument``, a ``SelfArgument`` wrapper (unwrapped and
            translated recursively), or ``TensorOptionsArguments``.

    Returns:
        A list of ``Binding`` objects; a plain ``Argument`` yields exactly one,
        named after the argument and with no default value.

    Raises:
        AssertionError: For ``TensorOptionsArguments``, which structured
            kernels do not support.
    """
    if isinstance(a, Argument):
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=None,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    else:
        # Exhaustiveness guard: an unexpected variant is a programming error,
        # not something to silently turn into an implicit ``None`` return.
        assert_never(a)
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the impl function of the structured group ``g``.

    Argument order:
      1. the out-variant's non-out arguments.  When ``g.out.precomputed`` is
         set, each argument named in ``precomputed.replace`` is substituted
         by its replacement list, and ``precomputed.add`` is appended after
         them;
      2. the out-variant's out arguments.

    Each argument is expanded into bindings via ``argument``.
    """
    args: list[Argument | TensorOptionsArguments | SelfArgument] = []

    if g.out.precomputed:
        non_out_args_replaced: list[
            Argument | TensorOptionsArguments | SelfArgument
        ] = []
        for a in g.out.func.arguments.non_out:
            if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
                # Emit the replacement parameters in place of ``a``.
                non_out_args_replaced.extend(g.out.precomputed.replace[a.name])
            else:
                # Only unreplaced arguments are kept as-is; appending
                # unconditionally would emit ``a`` next to its replacement.
                non_out_args_replaced.append(a)

        args.extend(non_out_args_replaced)
        # Added precomputed parameters go after the non-out args and just
        # before the out args.
        args.extend(g.out.precomputed.add)
    else:
        # No precomputation: the non-out args are used verbatim (exactly once).
        args.extend(g.out.func.arguments.non_out)

    args.extend(g.out.func.arguments.out)
    return [r for arg in args for r in argument(arg)]
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the meta function of the structured group ``g``.

    These are the non-out arguments of the group's functional variant, each
    expanded into bindings via ``argument``.
    """
    bindings: list[Binding] = []
    for schema_arg in g.functional.func.arguments.non_out:
        bindings.extend(argument(schema_arg))
    return bindings
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the out arguments of the structured group ``g``.

    Each out argument of the group's out variant is expanded via ``argument``.
    """
    out_args = g.out.func.arguments.out
    return [binding for out_arg in out_args for binding in argument(out_arg)]