pytorch

Форк
0
/
ufunc.py 
209 строк · 6.5 Кб
1
from __future__ import annotations
2

3
from dataclasses import dataclass
4

5
import torchgen.api.types as api_types
6
from torchgen.api import cpp, structured
7
from torchgen.api.types import (
8
    ArgName,
9
    BaseCppType,
10
    BaseCType,
11
    Binding,
12
    ConstRefCType,
13
    CType,
14
    NamedCType,
15
    scalarT,
16
)
17
from torchgen.model import (
18
    Argument,
19
    BaseTy,
20
    BaseType,
21
    DispatchKey,
22
    FunctionSchema,
23
    NativeFunctionsGroup,
24
    Type,
25
)
26

27

28
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return the generated ufunc kernel name for an out= schema.

    The name has the form ``ufunc_{base_name}_{dispatch_key}``.

    Raises:
        AssertionError: if ``func`` is not an out= schema. Raised explicitly
            (rather than via ``assert``) so the check still runs under
            ``python -O``.
    """
    if not func.is_out_fn():
        raise AssertionError(
            "ufunc.kernel_name should only be invoked on out schemas"
        )
    return f"ufunc_{func.name.name}_{dispatch_key}"
31

32

33
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for ``g``, derived from its out= variant."""
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)
35

36

37
# Tensors are omitted (they are stored in the TensorIterator); everything else
# is passed along. (Technically we could pass tensors along too; it would just
# waste argument registers.)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
    """Return the C++ type of ``t`` in a DispatchStub signature.

    Value types are lowered by ``cpp.valuetype_type``; ``Scalar`` is passed by
    const reference; ``Tensor`` yields ``None`` because it is not passed.

    Raises:
        AssertionError: for any other type.
    """
    # Dispatch stubs always take plain (non-symbolic) ints, hence symint=False.
    as_value = cpp.valuetype_type(t, binds=binds, symint=False)
    if as_value is not None:
        return as_value

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if t == BaseType(BaseTy.Tensor):
        # Carried by the TensorIterator instead of being a stub argument.
        return None
    raise AssertionError(f"unrecognized type {t!r}")
54

55

56
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Map a C++ scalar type to its higher-precision computation type.

    Only the generic ``scalar_t`` template parameter is supported; it maps to
    ``opmath_t``.

    Raises:
        NotImplementedError: for any other ``scalar_t``.
    """
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    # Name the offending type so a failing codegen run is diagnosable.
    raise NotImplementedError(f"opmath_type is not implemented for {scalar_t!r}")
60

61

62
# NB: Tensors in the constructor are stored as opmath_t, not scalar_t. A Tensor
# in the constructor is a partially applied scalar tensor, so it may carry
# higher precision, and we want to compute in that higher precision.
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Return the C++ type of ``t`` as taken by a ufunctor constructor.

    Value types are lowered by ``cpp.valuetype_type`` (with ``symint=False``);
    ``Scalar`` and ``Tensor`` both become ``opmath_type(scalar_t)``.

    Raises:
        AssertionError: for any other type.
    """
    as_value = cpp.valuetype_type(t, binds=binds, symint=False)
    if as_value is not None:
        return as_value

    # Scalars and (scalar) Tensors are both captured in the compute precision.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))

    raise AssertionError(f"unrecognized type {t!r}")
78

79

80
# Only Tensors are ever passed directly to operator().
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """Return the C++ type of ``t`` as taken by a ufunctor's operator().

    Only ``Tensor`` is accepted; it is received as a ``scalar_t`` element.

    Raises:
        AssertionError: for any non-Tensor type.
    """
    if t != BaseType(BaseTy.Tensor):
        raise AssertionError(f"unrecognized type {t!r}")
    return NamedCType(binds, BaseCType(scalar_t))
91

92

93
# The actual ufunc template function the user writes. Everything here is done
# in the computation type: compute_t is opmath_t on CUDA and scalar_t on CPU.
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Return the C++ type of ``t`` in a user-written ufunc template.

    Value types are lowered by ``cpp.valuetype_type`` (with ``symint=False``);
    ``Scalar`` and ``Tensor`` arguments both take ``compute_t``.

    Raises:
        AssertionError: for any other type.
    """
    lowered = cpp.valuetype_type(t, binds=binds, symint=False)
    if lowered is not None:
        return lowered

    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)

    raise AssertionError(f"unrecognized type {t!r}")
107

108

109
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Bind schema argument ``a`` as a ufunctor constructor parameter (no default)."""
    ctor_nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=ctor_nctype, argument=a, default=None)
116

117

118
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Bind schema argument ``a`` as a ufunctor operator() parameter (no default)."""
    apply_nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=apply_nctype, argument=a, default=None)
125

126

127
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Bind schema argument ``a`` as a ufunc template parameter (no default)."""
    ufunc_nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(name=a.name, nctype=ufunc_nctype, argument=a, default=None)
134

135

136
@dataclass(frozen=True)
class UfunctorBindings:
    """Bindings for a generated ufunctor, split by where each argument is passed.

    Instances are immutable (``frozen=True``).
    """

    # Arguments taken by the host-side constructor.
    ctor: list[Binding]
    # Arguments taken by the device-side operator() ("apply").
    apply: list[Binding]
140

141

142
# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments on a host-side constructor, and the rest in the device-side
# apply.  E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
#   using opmath_t = at::opmath_type<scalar_t>;
#   opmath_t other_;
#   opmath_t alpha_;
#   CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
#   __device__ scalar_t operator()(scalar_t self) {
#     return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#   }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
    """Split the functional schema's non-out arguments into ctor/apply bindings.

    Non-tensor arguments always go to the constructor. Tensor-like arguments
    go to ``operator()``, except the one at position ``scalar_tensor_idx``
    (counted among tensor-like arguments only), which is captured in the
    constructor instead.

    Args:
        g: native function group; arguments come from ``g.functional``.
        scalar_tensor_idx: zero-based index, among tensor-like arguments, of
            the tensor to move into the constructor, or ``None`` to keep every
            tensor in ``operator()``.
        scalar_t: C++ scalar type used to form the argument types.

    Returns:
        The constructor and apply bindings, each in schema order.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` does not select a tensor-like
            argument. Raised explicitly so the check survives ``python -O``.
    """
    ctor: list[Binding] = []
    apply: list[Binding] = []
    # Tensor-like arguments still to skip before reaching the one moved into
    # the ctor; None once it has been placed (or if none was requested).
    # A local copy keeps the caller-supplied value for the error message.
    remaining = scalar_tensor_idx
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
        elif remaining == 0:
            # put it in the ctor anyway
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
            remaining = None
        else:
            if remaining is not None:
                remaining -= 1
            apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
    if remaining is not None:
        raise AssertionError(
            f"scalar_tensor_idx={scalar_tensor_idx} does not refer to a tensor "
            f"argument of {g.functional.func.name}"
        )
    return UfunctorBindings(ctor=ctor, apply=apply)
178

179

180
# ufuncs are the inner-loop template functions you write in ufunc/add.h, which
# do the actual computation in question.  E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
#   return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t, which is bound by the caller.
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
    """Return ufunc template bindings for every non-out argument of ``g.functional``.

    Bindings are returned in schema order, each typed via ``compute_t``.
    """
    schema_args = g.functional.func.arguments.flat_non_out
    return [ufunc_argument(arg, compute_t=compute_t) for arg in schema_args]
194

195

196
# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions.  E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the structured bindings for the non-tensor arguments of ``g.out``.

    Stubs drop all tensor arguments (they are implicit in the TensorIterator
    argument) and keep everything else, in schema order.
    """
    bindings: list[Binding] = []
    for arg in g.out.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            continue  # carried by the TensorIterator
        bindings.extend(structured.argument(arg))
    return bindings
210

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.