1
from __future__ import annotations
3
from dataclasses import dataclass
4
from typing import Sequence, TYPE_CHECKING
6
import torchgen.api.ufunc as ufunc
7
from torchgen.api.translate import translate
8
from torchgen.api.types import (
16
StructuredImplSignature,
19
from torchgen.context import with_native_function
20
from torchgen.model import (
29
from torchgen.utils import OrderedSet
33
from torchgen.api.ufunc import UfunctorBindings
# NOTE(review): this excerpt is lossy. The bare integers interleaved with the
# code look like line numbers from an earlier copy of the file, and some lines
# are absent (the numbering skips, e.g. 68 -> 71). Read this as a partial view.
#
# UfunctorSignature: codegen description of a CUDA functor struct for the
# NativeFunctionsGroup `g`. `scalar_tensor_idx` is forwarded to
# ufunc.ufunctor_arguments; callers in this file pass 1 for CUDAFunctorOnSelf,
# 0 for CUDAFunctorOnOther and None for CUDAFunctor.
# NOTE(review): `self.name` is read in inline_defn_ctor, and callers construct
# this class with `name=...`, but no `name` field is visible in this excerpt.
# Confirm against the full file. `scalar_t` is also not defined in the visible
# lines (presumably imported from torchgen.api.types).
65
@dataclass(frozen=True)
66
class UfunctorSignature:
67
g: NativeFunctionsGroup
68
scalar_tensor_idx: int | None
# Constructor/apply bindings of the functor, as computed by ufunc.ufunctor_arguments.
71
def arguments(self) -> UfunctorBindings:
72
return ufunc.ufunctor_arguments(
73
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
# Member fields: the constructor bindings renamed with a trailing underscore.
76
def fields(self) -> list[Binding]:
78
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
# Return type of the generated operator(): always scalar_t here.
80
def returns_type(self) -> CType:
83
return BaseCType(scalar_t)
# C++ member declarations, one "<type> <name>;" per line.
85
def decl_fields(self) -> str:
86
return "\n".join(f"{f.type} {f.name};" for f in self.fields())
# Inline C++ constructor that initialises each `<name>_` member from the
# ctor argument `<name>`.
# NOTE(review): if `ctor` is empty, init_str is "" and this emits
# `Name() :  {}`, which is not valid C++. Confirm ctor is never empty for the
# ops that use this codegen.
88
def inline_defn_ctor(self) -> str:
89
args_str = ", ".join(a.decl() for a in self.arguments().ctor)
92
init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
93
return f"{self.name}({args_str}) : {init_str} {{}}"
# Declaration of the const call operator over the `apply` bindings.
95
def decl_apply(self) -> str:
96
args_str = ", ".join(a.decl() for a in self.arguments().apply)
97
return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
# NOTE(review): lossy excerpt. The bare integers look like old line numbers.
# The `class` line and the field declarations that would follow `g` are
# missing (numbering jumps 100 -> 102 -> 106). Call sites construct
# `UfuncSignature(g, name=..., compute_t=...)`, and this class reads
# `self.name` and `self.compute_t`, so those fields presumably exist in the
# full file.
#
# Describes a call to an inner `ufunc::` function for the group `g`,
# computing in `compute_t`.
100
@dataclass(frozen=True)
102
g: NativeFunctionsGroup
# Parameter bindings of the ufunc, as computed by ufunc.ufunc_arguments.
106
def arguments(self) -> list[Binding]:
107
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
# Render a C++ call "<name>(<exprs>)", where the argument expressions are
# translated from `ctx` into this ufunc's parameters.
109
def call(self, ctx: Sequence[Binding | Expr]) -> str:
110
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# Return True when the functional variant of `g` has exactly two tensor-like
# non-out arguments (`num_tensors == 2`). Only such functions get the
# CUDAFunctorOnSelf / CUDAFunctorOnOther scalar specializations.
# NOTE(review): lossy excerpt. The line that opens the
# `num_tensors = sum(` expression (original 132) and its closing paren are
# missing. Only the generator body counting tensor-like arguments is visible.
131
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
133
1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
135
return num_tensors == 2
# Build the CUDA functor signatures for the ufunc group `g`.
#
# Returns a pair:
#   * dtype -> {UfuncKey -> UfunctorSignature}, i.e. which functor serves each
#     CUDA key for each supported dtype;
#   * generated C++ source (joined with newlines) for the synthesized functor
#     struct templates.
#
# Where `loops[k]` is indexed directly, the functor is named after that inner
# loop. Otherwise a functor named "<key>_<ufunc_name>" is synthesized from the
# ScalarOnly/Generic loops, and its struct text is generated. That struct has
# an opmath_t alias, the member fields, the constructor, and a __device__
# operator() that calls ufunc::<ufunc_name> computing in opmath_t.
# NOTE(review): lossy excerpt. The bare integers look like old line numbers.
# Several lines are missing: the `keys = [` opener, the `for k in keys` loop,
# the `if k in loops` branch and its `continue`, the `ufunc_name`
# initialisation, and the f-string appended to `ufunctors`. Confirm the exact
# control flow against the full file.
138
def compute_ufunc_cuda_functors(
139
g: NativeFunctionsGroup,
140
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
142
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
143
ufunctors: list[str] = []
144
loops = g.out.ufunc_inner_loop
# scalar_tensor_idx passed to UfunctorSignature for each CUDA functor key.
145
scalar_tensor_idx_lookup = {
146
UfuncKey.CUDAFunctorOnSelf: 1,
147
UfuncKey.CUDAFunctorOnOther: 0,
148
UfuncKey.CUDAFunctor: None,
# Binary functions get the OnSelf/OnOther specializations in addition to the
# plain CUDAFunctor. Any other function may only use CUDAFunctor, and
# supplying an OnSelf/OnOther loop for it is an error.
150
if eligible_for_binary_scalar_specialization(g):
152
UfuncKey.CUDAFunctorOnSelf,
153
UfuncKey.CUDAFunctorOnOther,
154
UfuncKey.CUDAFunctor,
157
keys = [UfuncKey.CUDAFunctor]
158
for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
159
assert k not in loops, f"cannot use {k} on non-binary function"
164
ufunctor_sig = UfunctorSignature(
165
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
167
for dtype in loops[k].supported_dtypes:
168
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
# Synthesized functor: ScalarOnly and Generic must agree on the ufunc name,
# and their supported dtypes are unioned.
179
supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
180
for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
183
if ufunc_name is None:
184
ufunc_name = loops[lk].name
188
ufunc_name == loops[lk].name
189
), "ScalarOnly and Generic must have same ufunc name"
190
supported_dtypes |= loops[lk].supported_dtypes
191
assert ufunc_name is not None
193
name = f"{k}_{ufunc_name}"
194
ufunctor_sig = UfunctorSignature(
195
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
197
for dtype in supported_dtypes:
198
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
200
ufunc_sig = UfuncSignature(
201
g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
# The generated operator() forwards the functor's members plus its apply
# arguments to the ufunc:: call.
203
apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
206
template <typename scalar_t>
207
struct {ufunctor_sig.name} {{
208
using opmath_t = at::opmath_type<scalar_t>;
209
{ufunctor_sig.decl_fields()}
210
{ufunctor_sig.inline_defn_ctor()}
211
__device__ {ufunctor_sig.decl_apply()} {{
212
return {ufunc_sig.call(apply_ctx)};
218
return ufunctor_sigs, "\n".join(ufunctors)
# Configuration for one binary CPU-scalar specialization of a CUDA ufunc.
# NOTE(review): lossy excerpt. The field declarations are missing (numbering
# jumps 222 -> 228). compute_ufunc_cuda_dtype_body reads `scalar_idx`,
# `ctor_tensor` and `ufunc_key` from instances, so those fields presumably
# exist in the full file.
221
@dataclass(frozen=True)
222
class BinaryScalarSpecializationConfig:
# The binary scalar specializations that compute_ufunc_cuda_dtype_body tries,
# in list order: one keyed on CUDAFunctorOnOther, then one keyed on
# CUDAFunctorOnSelf.
# NOTE(review): lossy excerpt. The other keyword arguments of each config
# and the closing brackets are missing here.
228
BinaryScalarSpecializationConfigs = [
229
BinaryScalarSpecializationConfig(
232
ufunc_key=UfuncKey.CUDAFunctorOnOther,
234
BinaryScalarSpecializationConfig(
237
ufunc_key=UfuncKey.CUDAFunctorOnSelf,
# Generate the C++ body of one AT_DISPATCH_CASE for the CUDA ufunc kernel.
#
# The emitted code starts with an `opmath_t` alias and `if (false) {}`, so every
# specialization can be appended uniformly as an `else if` branch. Each entry of
# BinaryScalarSpecializationConfigs whose key is present in `inner_loops` gets
# a branch that does the following:
#   * tests `iter.is_cpu_scalar(idx)`, where idx is `config.scalar_idx + 1`;
#   * constructs the functor, passing that operand's value, read via
#     `iter.scalar_value<opmath_t>(idx)`, as the `config.ctor_tensor` ctor
#     argument;
#   * removes that operand from `iter`;
#   * launches `gpu_kernel`.
# The CUDAFunctor entry of `inner_loops` is then used for the remaining
# (presumably `else`) launch.
# NOTE(review): no newline separates the `using` line from `if (false) {}`.
# That is harmless for C++ but makes the generated code harder to read.
# NOTE(review): inner_loops[UfuncKey.CUDAFunctor] is indexed without a visible
# guard. Confirm that every dtype always has a CUDAFunctor entry.
# NOTE(review): lossy excerpt. The call site passes a dtype as the second
# positional argument, but that parameter (original line 244) is missing here.
# The `continue`, the `ctx.append(Expr(` opener and the f-string openings are
# also missing.
242
def compute_ufunc_cuda_dtype_body(
243
g: NativeFunctionsGroup,
245
inner_loops: dict[UfuncKey, UfunctorSignature],
246
parent_ctx: Sequence[Binding],
248
body = "using opmath_t = at::opmath_type<scalar_t>;"
249
body += "if (false) {}\n"
250
for config in BinaryScalarSpecializationConfigs:
251
if config.ufunc_key not in inner_loops:
253
ufunctor_sig = inner_loops[config.ufunc_key]
# NOTE(review): the +1 presumably skips the output operand at index 0 of the
# TensorIterator. Confirm.
254
scalar_idx = config.scalar_idx + 1
257
ctx: list[Expr | Binding] = list(parent_ctx)
260
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
261
type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
264
ufunctor_ctor_exprs_str = ", ".join(
265
a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
271
else if (iter.is_cpu_scalar({scalar_idx})) {{
272
{ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
273
iter.remove_operand({scalar_idx});
274
gpu_kernel(iter, ufunctor);
277
ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
278
ufunctor_ctor_exprs_str = ", ".join(
279
a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
283
gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
# Generate the CUDA source for the ufunc group `g`.
#
# Calls compute_ufunc_cuda_functors to get the per-dtype functor signatures and
# the generated functor text. It then builds one AT_DISPATCH_CASE per dtype
# (body from compute_ufunc_cuda_dtype_body) and emits the following:
#   * the stub type alias and DECLARE_DISPATCH;
#   * a kernel definition that runs AT_DISPATCH_SWITCH on `iter.common_dtype()`,
#     labelled with the structured-impl signature name;
#   * REGISTER_DISPATCH of the stub to that kernel;
#   * a direct call to the kernel with the structured-impl arguments.
# NOTE(review): lossy excerpt. The bare integers look like old line numbers.
# The `dtype_cases` list, its append call, the opening and closing of the
# generated f-strings, and the use of `ufunctors` are not visible here.
290
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
292
ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
295
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
297
for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
300
AT_DISPATCH_CASE(at::ScalarType::{dtype},
302
{compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
308
dtype_cases_str = "\n".join(dtype_cases)
310
stub_sig = StubSignature(g)
315
{stub_sig.type_defn()};
316
{stub_sig.dispatch_decl()};
318
{stub_sig.kernel_defn()} {{
319
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
323
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
326
{stub_sig.direct_call(sig.arguments())};
# StubSignature: C++ naming and declaration helpers for the DispatchStub of the
# ufunc group `g`. The class name comes from the `StubSignature(g)` call sites;
# the `class` line (original 339) is missing from this excerpt.
# Names derive from the functional op's base name:
#   <op>_stub   - the dispatch stub
#   <op>_kernel - the kernel function
#   <op>_fn     - the function-pointer type alias
# NOTE(review): `name`, `kernel_name` and `type_name` are interpolated as
# attributes (e.g. `{self.name}`, `{stub_sig.kernel_name}`). That only produces
# the intended string if they are properties; plain methods would render as
# bound-method reprs. The lines just above each of these defs are missing
# (numbering gaps 340->343, 344->347, 348->351). Confirm the `@property`
# decorators exist in the full file.
338
@dataclass(frozen=True)
340
g: NativeFunctionsGroup
343
def name(self) -> str:
344
return f"{str(self.g.functional.func.name.name)}_stub"
347
def kernel_name(self) -> str:
348
return f"{str(self.g.functional.func.name.name)}_kernel"
351
def type_name(self) -> str:
352
return f"{str(self.g.functional.func.name.name)}_fn"
# Stub arguments that follow the leading TensorIteratorBase&.
354
def arguments(self) -> list[Binding]:
355
return ufunc.stub_arguments(self.g)
# C++ function-pointer type: void(*)(TensorIteratorBase&, <arg types>).
357
def type(self) -> str:
358
cpp_args = self.arguments()
359
return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
361
def dispatch_decl(self) -> str:
362
return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
364
def dispatch_defn(self) -> str:
365
return f"DEFINE_DISPATCH({self.name})"
# Kernel function header; the iterator parameter is named `iter`.
367
def kernel_defn(self) -> str:
368
return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
370
def type_defn(self) -> str:
371
return f"using {self.type_name} = {self.type()}"
# Call through the stub with device_type() and *this. The emitted C++ refers
# to `*this`, so it is presumably placed inside a member function.
374
def call(self, ctx: Sequence[Binding]) -> str:
375
return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# Call the kernel function directly rather than going through the stub.
378
def direct_call(self, ctx: Sequence[Binding]) -> str:
379
return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# Generate the CPU-side stub plumbing for the ufunc group `g`:
#   * the stub type alias, DECLARE_DISPATCH and DEFINE_DISPATCH;
#   * a call through the stub, using the arguments of the CPU structured-impl
#     signature.
# The kernel registered to this stub is generated by compute_ufunc_cpu_kernel.
# NOTE(review): lossy excerpt. The f-string opening/closing and the
# surrounding generated lines (original 386-387, 391-392, 394 onward) are
# missing here.
383
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
384
stub_sig = StubSignature(g)
385
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
388
{stub_sig.type_defn()};
389
{stub_sig.dispatch_decl()};
390
{stub_sig.dispatch_defn()};
393
{stub_sig.call(sig.arguments())};
# Generate the C++ body of one AT_DISPATCH_CASE for the CPU ufunc kernel.
#
# `inner_loops` must contain a CPUScalar signature and may also contain a
# CPUVector one; no other keys are allowed (asserted below).
#
# Bindings that pass the `isinstance(b.argument, Argument) and
# b.argument.type != BaseType(...)` filter are converted once to scalar_t
# locals (`_s_<name>` via `.to<scalar_t>()`). The filter's argument and
# consequence are not visible; the bindings presumably come from `parent_ctx`
# and are presumably the Scalar-typed ones. When a vector loop exists, they are
# also broadcast into `at::vec::Vectorized<scalar_t>` locals (`_v_<name>`).
#
# Each tensor-like non-out argument of the functional schema (asserted to be a
# plain Tensor) becomes a lambda parameter: scalar_t in the scalar lambda, and
# Vectorized<scalar_t> in the vector lambda. The generated code passes a scalar
# lambda calling `scalar_loop`, plus a vectorized lambda calling `vec_loop`
# when one exists.
# NOTE(review): lossy excerpt. The bare integers look like old line numbers.
# Missing lines include the `dtype` parameter (used in the first assert
# message and passed positionally by the call site), the `vec_loop = None`
# default, the loops over `parent_ctx`, the list initialisations, the body of
# `with_ctx`, and the name of the C++ kernel helper the lambdas are passed to.
398
def compute_ufunc_cpu_dtype_body(
399
g: NativeFunctionsGroup,
401
inner_loops: dict[UfuncKey, UfuncSignature],
402
parent_ctx: Sequence[Binding],
404
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
405
assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
406
scalar_loop = inner_loops[UfuncKey.CPUScalar]
408
if UfuncKey.CPUVector in inner_loops:
409
vec_loop = inner_loops[UfuncKey.CPUVector]
421
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
425
body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
426
ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
427
if vec_loop is not None:
429
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
434
f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
439
NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
# Lambda parameters: one per tensor-like non-out argument. Non-tensor-like
# arguments are presumably skipped by a `continue` that is not visible here.
447
for a in g.functional.func.arguments.flat_non_out:
448
if not a.type.is_tensor_like():
450
assert a.type == BaseType(BaseTy.Tensor)
451
scalar_bindings.append(
454
nctype=NamedCType(a.name, BaseCType(scalar_t)),
458
if vec_loop is not None:
462
nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
# Combines lambda parameter bindings with the precomputed context. The
# combination logic is not visible in this excerpt.
467
def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
468
r: list[Expr | Binding] = []
473
body_str = "\n".join(body)
474
if vec_loop is not None:
478
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
479
[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
486
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
# Generate the CPU kernel source for the ufunc group `g` and register it to
# the stub.
#
# For each of CPUScalar and CPUVector, candidate inner loops are gathered in
# `lks`: ScalarOnly first (only for CPUScalar, if present in `loops`), then
# Generic (if present). For every supported dtype of a candidate, a
# UfuncSignature "ufunc::<loop name>" is recorded. Its compute_t is scalar_t for
# CPUScalar and Vectorized<scalar_t> for CPUVector. A key/dtype pair that
# already has a signature is not overwritten, so with the visible ordering
# ScalarOnly takes precedence over Generic for CPUScalar.
#
# The generated code has the following parts:
#   * a kernel definition running AT_DISPATCH_SWITCH on `iter.common_dtype()`,
#     with one AT_DISPATCH_CASE per dtype (body from
#     compute_ufunc_cpu_dtype_body);
#   * the end of an anonymous namespace;
#   * the stub type alias, DECLARE_DISPATCH, and REGISTER_DISPATCH of the stub
#     to the kernel.
# NOTE(review): lossy excerpt that runs past the end of this view. Missing
# lines include 499-502 (presumably handling an explicit CPUScalar/CPUVector
# loop), the `lks` initialisation, the `for lk in lks` loop, the
# `dtype_cases` list, the opening of the anonymous namespace, and the end of
# the returned string.
492
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
493
stub_sig = StubSignature(g)
496
loops = g.out.ufunc_inner_loop
497
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
498
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
503
if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
504
lks.append(UfuncKey.ScalarOnly)
505
if UfuncKey.Generic in loops:
506
lks.append(UfuncKey.Generic)
509
for dtype in loops[lk].supported_dtypes:
511
if k is UfuncKey.CPUScalar:
512
compute_t = BaseCType(scalar_t)
513
elif k is UfuncKey.CPUVector:
514
compute_t = VectorizedCType(BaseCType(scalar_t))
517
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
518
if k not in inner_ufunc_sigs:
519
inner_ufunc_sigs[k] = UfuncSignature(
520
g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
525
for dtype, inner_ufunc_sigs in ufunc_sigs.items():
528
AT_DISPATCH_CASE(at::ScalarType::{dtype},
530
{compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
536
dtype_cases_str = "\n".join(dtype_cases)
540
{stub_sig.kernel_defn()} {{
541
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
546
}} // anonymous namespace
548
{stub_sig.type_defn()};
549
{stub_sig.dispatch_decl()};
550
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});