pytorch

Форк
0
/
ufunc.py 
551 строка · 17.4 Кб
1
from __future__ import annotations
2

3
from dataclasses import dataclass
4
from typing import Sequence, TYPE_CHECKING
5

6
import torchgen.api.ufunc as ufunc
7
from torchgen.api.translate import translate
8
from torchgen.api.types import (
9
    BaseCType,
10
    Binding,
11
    CType,
12
    Expr,
13
    NamedCType,
14
    opmath_t,
15
    scalar_t,
16
    StructuredImplSignature,
17
    VectorizedCType,
18
)
19
from torchgen.context import with_native_function
20
from torchgen.model import (
21
    Argument,
22
    BaseTy,
23
    BaseType,
24
    DispatchKey,
25
    NativeFunctionsGroup,
26
    ScalarType,
27
    UfuncKey,
28
)
29
from torchgen.utils import OrderedSet
30

31

32
if TYPE_CHECKING:
33
    from torchgen.api.ufunc import UfunctorBindings
34

35

36
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
37
#
38
#                                  CUDA STUFF
39
#
40
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
41

42
# NB: not bothering to generate dispatch stub forward declaration in header,
43
# we can just paste it wherever necessary
44

45
# TODO: use BackendIndex
46
# dispatch_key: DispatchKey  # only CPU/CUDA right now
47

48

49
# Represents functors for implementing CUDA ufuncs.
50
# Functors are templated by scalar_t because when USERS instantiate functors
51
# they are templated.  A functor looks something like this:
52
#
53
#   template <typename scalar_t>
54
#   struct CUDAFunctorOnSelf_add {
55
#     using opmath_t = at::opmath_type<scalar_t>;
56
#     opmath_t other_;
57
#     opmath_t alpha_;
58
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
59
#         : other_(other), alpha_(alpha) {}
60
#     __device__ scalar_t operator()(scalar_t self) const {
61
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
62
#     }
63
#   };
64
#
65
@dataclass(frozen=True)
class UfunctorSignature:
    """Signature of a CUDA functor struct that wraps a ufunc.

    Attributes:
        g: the native function group the functor implements.
        scalar_tensor_idx: forwarded to ``ufunc.ufunctor_arguments``; selects
            which tensor argument (if any) is captured by the constructor
            instead of being passed to ``operator()``.  ``None`` for the plain
            ``CUDAFunctor``.
        name: the C++ struct name of the functor.
    """

    g: NativeFunctionsGroup
    scalar_tensor_idx: int | None
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the constructor (``ctor``) and ``operator()`` (``apply``) bindings."""
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        """Return one member field per constructor argument."""
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        """Return the C++ return type of ``operator()`` (currently always scalar_t)."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Return the member field declarations, one per line."""
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        """Return an inline constructor definition that initializes every field."""
        ctor_args = self.arguments().ctor
        args_str = ", ".join(a.decl() for a in ctor_args)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in ctor_args)
        if not init_str:
            # A C++ member-initializer list may not be empty ("F() :  {}" is
            # ill-formed), so a functor without fields gets a plain ctor.
            return f"{self.name}({args_str}) {{}}"
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        """Return the declaration of the const ``operator()`` (without body)."""
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
98

99

100
@dataclass(frozen=True)
class UfuncSignature:
    """Signature of a templated ``ufunc::`` inner-loop function.

    Attributes:
        g: the native function group the ufunc implements.
        name: the (namespace-qualified) C++ function name to call.
        compute_t: the C++ type handed to ``ufunc.ufunc_arguments``; callers in
            this file use ``opmath_t`` (CUDA), ``scalar_t`` or
            ``Vectorized<scalar_t>`` (CPU).
    """

    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> list[Binding]:
        """Return the parameter bindings built by ``ufunc.ufunc_arguments``."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        """Return a C++ call to this ufunc with arguments translated from ``ctx``."""
        exprs = [a.expr for a in translate(ctx, self.arguments())]
        return f"{self.name}({', '.join(exprs)})"
111

112

113
# steps:
114
#   1. take the functional signature
115
#   2. use api.ufunc to convert it to template signature.  this establishes
116
#      the type of the template function
117
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
118
#      this establishes the context in which we call the template signature
119
#
120
# StructuredImplSignature context
121
#   ~> functor constructor sig
122
#
123
# Functor constructor context
124
#   ~> functor fields sig
125
#
126
# Functor apply context (functor fields + functor apply sig)
127
#   ~> template sig
128
#
129

130

131
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True when the functional overload has exactly two tensor-like inputs.

    Only such functions may use the CUDAFunctorOnSelf / CUDAFunctorOnOther
    specializations, where one of the two inputs is a CPU scalar.
    """
    num_tensors = sum(
        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    )
    return num_tensors == 2
136

137

138
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures (and any functor definitions) for ``g``.

    Returns:
        A pair ``(ufunctor_sigs, code)``.  ``ufunctor_sigs`` maps every
        supported dtype to the functor signature to use for each CUDA functor
        key.  ``code`` is the C++ source of the functors generated here;
        functors given directly in ``ufunc_inner_loop`` are only referenced by
        name and not generated.
    """
    # First, build the functors.
    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: list[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert (
            ufunc_name is not None
        ), f"{k} is not defined and there is no ScalarOnly or Generic loop to derive it from"

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        # The generated functor computes in opmath_t and calls ufunc::<name>.
        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
219

220

221
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """How to dispatch a binary CUDA ufunc when one of its inputs is a CPU scalar."""

    # Index (0 = first, 1 = second tensor input) of the input that may be a
    # CPU scalar.
    scalar_idx: int
    # Name under which that scalar's value is passed to the functor ctor.
    ctor_tensor: str
    # Functor key to use when the input at `scalar_idx` is a CPU scalar.
    ufunc_key: UfuncKey
226

227

228
# The CPU-scalar specializations of a binary CUDA ufunc.  When the input at
# `scalar_idx` is a CPU scalar, its value is passed to the functor ctor as
# `ctor_tensor`, the operand is removed from the iterator, and the functor
# keyed by `ufunc_key` runs over the remaining tensor.
BinaryScalarSpecializationConfigs = [
    # `self` is a CPU scalar: run the functor over `other`.
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    # `other` is a CPU scalar: run the functor over `self`.
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]
240

241

242
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Generate the C++ body of one dtype case of the CUDA ufunc kernel.

    Emits an if / else-if chain.  For every CPU-scalar specialization present
    in ``inner_loops``, if that input is a CPU scalar, the functor is built
    with the scalar's value, the operand is removed from ``iter`` and the
    kernel is launched.  Otherwise the plain ``CUDAFunctor`` is launched.

    Args:
        g: the native function group (unused here beyond the signature).
        dtype: the dtype of this case; only used in the assertion message.
        inner_loops: functor signature per CUDA functor key for this dtype;
            must contain ``UfuncKey.CUDAFunctor``.
        parent_ctx: bindings available in the enclosing kernel.
    """
    assert UfuncKey.CUDAFunctor in inner_loops, f"{dtype}, {inner_loops.keys()}"
    body = "using opmath_t = at::opmath_type<scalar_t>;\n"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        # NOTE(review): presumably TensorIterator operand 0 is the output, so
        # tensor input `scalar_idx` is operand `scalar_idx + 1` -- confirm.
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: list[Expr | Binding] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    # Fallback: no operand is a CPU scalar, use the generic functor.
    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    return body
287

288

289
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Generate the complete CUDA ufunc source for ``g``.

    Emits the functor definitions, the stub's function-type alias and
    DECLARE_DISPATCH, a kernel that switches on ``iter.common_dtype()``, the
    REGISTER_DISPATCH for that kernel, and the structured impl, which calls
    the kernel directly (no dynamic dispatch).
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    impl_args = sig.arguments()
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, impl_args)}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(impl_args)};
}}
"""
329

330

331
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
332
#
333
#                                   CPU STUFF
334
#
335
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
336

337

338
@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the DispatchStub that backs a ufunc.

    The stub's function type takes a ``TensorIteratorBase&`` followed by the
    bindings returned by ``ufunc.stub_arguments(g)``.  Parameter lists are
    built as a whole, so functions without stub arguments do not get a
    dangling ", " (which would be invalid C++).
    """

    g: NativeFunctionsGroup

    def _base_name(self) -> str:
        # Shared prefix of the stub, kernel and function-type names.
        return str(self.g.functional.func.name.name)

    @property
    def name(self) -> str:
        return f"{self._base_name()}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{self._base_name()}_kernel"

    @property
    def type_name(self) -> str:
        return f"{self._base_name()}_fn"

    def arguments(self) -> list[Binding]:
        """Return the non-iterator arguments of the stub."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Return the C++ function-pointer type of the stub."""
        params = ["TensorIteratorBase&", *(a.type for a in self.arguments())]
        return f"void(*)({', '.join(params)})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Return the kernel's definition header (without body)."""
        params = ["TensorIteratorBase& iter", *(a.defn() for a in self.arguments())]
        return f"void {self.kernel_name}({', '.join(params)})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Return a call through the stub, translating arguments from ``ctx``."""
        args = ["device_type()", "*this", *self._translated_exprs(ctx)]
        return f"{self.name}({', '.join(args)})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Return a direct call to the kernel, translating arguments from ``ctx``."""
        args = ["*this", *self._translated_exprs(ctx)]
        return f"{self.kernel_name}({', '.join(args)})"

    def _translated_exprs(self, ctx: Sequence[Binding]) -> list[str]:
        # C++ expressions for the stub arguments, derived from `ctx`.
        return [a.expr for a in translate(ctx, self.arguments())]
380

381

382
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Generate the CPU structured impl for ``g``, which calls its dispatch stub.

    Also emits the stub's function-type alias, DECLARE_DISPATCH and
    DEFINE_DISPATCH.  The kernel registered on the stub is generated
    separately by ``compute_ufunc_cpu_kernel``.
    """
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""
396

397

398
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Generate the C++ body of one dtype case of the CPU ufunc kernel.

    Unpacks each Scalar-typed binding of ``parent_ctx`` into a local
    ``scalar_t`` (and, when a CPUVector loop exists, a ``Vectorized<scalar_t>``).
    Then emits a ``cpu_kernel`` (or ``cpu_kernel_vec``) call whose lambdas
    invoke the CPUScalar (and CPUVector) inner loops.  Bindings whose
    ``argument`` is not an ``Argument`` are unpacked the same way.

    Args:
        g: the native function group; its tensor inputs become lambda params.
        dtype: the dtype of this case; only used in the assertion message.
        inner_loops: must contain CPUScalar and may contain CPUVector.
        parent_ctx: bindings available in the enclosing kernel.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    scalar_ty = BaseType(BaseTy.Scalar)
    scalar_args = [
        b
        for b in parent_ctx
        if not (isinstance(b.argument, Argument) and b.argument.type != scalar_ty)
    ]
    body: list[str] = []
    ctx: list[Expr] = []
    # All scalar unpackings come first, then (optionally) all vector splats.
    for b in scalar_args:
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in scalar_args:
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings: list[Binding] = []
    vec_bindings: list[Binding] = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
        # Unpacked scalars first, then the lambda's own parameters.
        return [*ctx, *b]

    def lambda_str(bindings: Sequence[Binding], loop: UfuncSignature) -> str:
        # `[=]` captures the locals declared in `body` by value.
        params = ", ".join(b.decl() for b in bindings)
        return f"[=]({params}) {{ return {loop.call(with_ctx(bindings))}; }}"

    body_str = "\n".join(body)
    scalar_lambda = lambda_str(scalar_bindings, scalar_loop)
    if vec_loop is not None:
        vec_lambda = lambda_str(vec_bindings, vec_loop)
        return f"""
{body_str}
cpu_kernel_vec(iter,
  {scalar_lambda},
  {vec_lambda}
);
"""
    return f"""
{body_str}
cpu_kernel(iter,
  {scalar_lambda}
);
"""
489

490

491
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Generate the CPU ufunc kernel for ``g`` and register it on its stub.

    Inner loops are indexed by dtype.  For CPUScalar the precedence is an
    explicit CPUScalar loop, then ScalarOnly, then Generic.  For CPUVector it
    is an explicit CPUVector loop, then Generic.  The first loop that
    supports a dtype wins.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        # The type the inner loop computes in depends only on k.
        compute_t: CType = (
            BaseCType(scalar_t)
            if k is UfuncKey.CPUScalar
            else VectorizedCType(BaseCType(scalar_t))
        )

        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                # Earlier entries in `lks` take precedence.
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    stub_args = stub_sig.arguments()
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_args)}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""
552

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.