gpt-neox

fused_rotary_positional_embedding_cuda.cu 
336 lines · 15.4 KB
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"
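// fused_rotary_positional_embedding.h is expected to declare the
// dispatch_fused_rope_* kernel launchers invoked below; type_shim.h supplies
// the DISPATCH_FLOAT_HALF_AND_BFLOAT dtype-dispatch macro.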

namespace fused_rope {

torch::Tensor fwd_cuda(const torch::Tensor& input,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);

    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);

    DISPATCH_FLOAT_HALF_AND_BFLOAT(input.scalar_type(),
                                   0,
                                   "dispatch_fused_rope_forward",
                                   dispatch_fused_rope_forward(s,
                                                               b,
                                                               h,
                                                               d,
                                                               d2,
                                                               stride_s,
                                                               stride_b,
                                                               stride_h,
                                                               stride_d,
                                                               o_stride_s,
                                                               o_stride_b,
                                                               o_stride_h,
                                                               o_stride_d,
                                                               input.data_ptr<scalar_t_0>(),
                                                               freqs.data_ptr<float>(),
                                                               output.data_ptr<scalar_t_0>()););
    return output;
}
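
// Illustration only, assuming the GPT-NeoX rotate-half convention that the
// kernels in fused_rotary_positional_embedding.h implement: a scalar
// reference of what fwd_cuda computes for one (s, b, h) slice. Channels
// i < d2 are rotated against their partner half a block away; channels
// i >= d2 pass through unchanged.
template <typename scalar_t>
static void rope_forward_reference(
    const scalar_t* in, const float* freqs, scalar_t* out, int d, int d2)
{
    for (int i = 0; i < d2; ++i) {
        // rotate_half(x)[i] is -x[i + d2/2] in the first half and
        // +x[i - d2/2] in the second half.
        const float rot = (i < d2 / 2) ? -static_cast<float>(in[i + d2 / 2])
                                       : static_cast<float>(in[i - d2 / 2]);
        out[i] = static_cast<scalar_t>(
            static_cast<float>(in[i]) * cosf(freqs[i]) + rot * sinf(freqs[i]));
    }
    for (int i = d2; i < d; ++i) { out[i] = in[i]; }  // untouched tail
}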

torch::Tensor bwd_cuda(const torch::Tensor& output_grads,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);

    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        output_grads.scalar_type(),
        0,
        "dispatch_fused_rope_backward",
        dispatch_fused_rope_backward(s,
                                     b,
                                     h,
                                     d,
                                     d2,
                                     stride_s,
                                     stride_b,
                                     stride_h,
                                     stride_d,
                                     o_stride_s,
                                     o_stride_b,
                                     o_stride_h,
                                     o_stride_d,
                                     output_grads.data_ptr<scalar_t_0>(),
                                     freqs.data_ptr<float>(),
                                     input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
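
// The rotation above is orthogonal, so the backward pass is just the forward
// rotation by -theta applied to the incoming gradients: the same per-channel
// mix with the sign of the sine term flipped. That is why bwd_cuda mirrors
// fwd_cuda line for line, differing only in the dispatched kernel.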

#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...) \
    switch (TYPE1) {                                       \
        case at::ScalarType::Float: {                      \
            using scalar_t_0 = float;                      \
            switch (TYPE2) {                               \
                case at::ScalarType::Float: {              \
                    using scalar_t_1 = float;              \
                    __VA_ARGS__;                           \
                    break;                                 \
                }                                          \
                default:                                   \
                    TORCH_CHECK(false,                     \
                                #NAME,                     \
                                " not supported for '",    \
                                toString(TYPE1),           \
                                "' with '",                \
                                toString(TYPE2),           \
                                "'");                      \
            }                                              \
            break;                                         \
        }                                                  \
        case at::ScalarType::Half: {                       \
            using scalar_t_0 = at::Half;                   \
            switch (TYPE2) {                               \
                case at::ScalarType::Float: {              \
                    using scalar_t_1 = float;              \
                    __VA_ARGS__;                           \
                    break;                                 \
                }                                          \
                case at::ScalarType::Half: {               \
                    using scalar_t_1 = at::Half;           \
                    __VA_ARGS__;                           \
                    break;                                 \
                }                                          \
                default:                                   \
                    TORCH_CHECK(false,                     \
                                #NAME,                     \
                                " not supported for '",    \
                                toString(TYPE1),           \
                                "' with '",                \
                                toString(TYPE2),           \
                                "'");                      \
            }                                              \
            break;                                         \
        }                                                  \
        case at::ScalarType::BFloat16: {                   \
            using scalar_t_0 = at::BFloat16;               \
            switch (TYPE2) {                               \
                case at::ScalarType::Float: {              \
                    using scalar_t_1 = float;              \
                    __VA_ARGS__;                           \
                    break;                                 \
                }                                          \
                case at::ScalarType::BFloat16: {           \
                    using scalar_t_1 = at::BFloat16;       \
                    __VA_ARGS__;                           \
                    break;                                 \
                }                                          \
                default:                                   \
                    TORCH_CHECK(false,                     \
                                #NAME,                     \
                                " not supported for '",    \
                                toString(TYPE1),           \
                                "' with '",                \
                                toString(TYPE2),           \
                                "'");                      \
            }                                              \
            break;                                         \
        }                                                  \
        default:                                           \
            TORCH_CHECK(false,                             \
                        #NAME,                             \
                        " not supported for '",            \
                        toString(TYPE1),                   \
                        "' with '",                        \
                        toString(TYPE2),                   \
                        "'");                              \
    }
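
// The dispatch macro above admits the following (TYPE1 = activations,
// TYPE2 = cos/sin) pairs; every other combination trips the TORCH_CHECK:
//
//   TYPE1      TYPE2
//   --------   ----------------
//   Float      Float
//   Half       Float, Half
//   BFloat16   Float, BFloat16
//
// Inside __VA_ARGS__ the selected C++ types are visible as scalar_t_0
// (activations) and scalar_t_1 (cos/sin), mirroring how
// DISPATCH_FLOAT_HALF_AND_BFLOAT exposes scalar_t_0 above.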

torch::Tensor fwd_cached_cuda(const torch::Tensor& input,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);

    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);

    DISPATCH_FUSED_ROPE_TYPES(input.scalar_type(),
                              cos.scalar_type(),
                              "dispatch_fused_rope_cached_forward",
                              dispatch_fused_rope_cached_forward(s,
                                                                 b,
                                                                 h,
                                                                 d,
                                                                 d2,
                                                                 stride_s,
                                                                 stride_b,
                                                                 stride_h,
                                                                 stride_d,
                                                                 o_stride_s,
                                                                 o_stride_b,
                                                                 o_stride_h,
                                                                 o_stride_d,
                                                                 input.data_ptr<scalar_t_0>(),
                                                                 cos.data_ptr<scalar_t_1>(),
                                                                 sin.data_ptr<scalar_t_1>(),
                                                                 output.data_ptr<scalar_t_0>()););
    return output;
}
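
// The *_cached_* entry points take precomputed cos/sin tables of shape
// (s, 1, 1, d2) in place of raw angles, so the kernel can skip per-element
// trig. A hypothetical caller-side sketch of building such tables (inv_freq,
// of shape (d2/2,), is assumed and not defined in this file):
//
//   auto angles = torch::outer(torch::arange(s, inv_freq.options()), inv_freq);
//   auto emb = torch::cat({angles, angles}, -1).reshape({s, 1, 1, -1});
//   auto cos = emb.cos();
//   auto sin = emb.sin();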

torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);

    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);

    DISPATCH_FUSED_ROPE_TYPES(
        output_grads.scalar_type(),
        cos.scalar_type(),
        "dispatch_fused_rope_cached_backward",
        dispatch_fused_rope_cached_backward(s,
                                            b,
                                            h,
                                            d,
                                            d2,
                                            stride_s,
                                            stride_b,
                                            stride_h,
                                            stride_d,
                                            o_stride_s,
                                            o_stride_b,
                                            o_stride_h,
                                            o_stride_d,
                                            output_grads.data_ptr<scalar_t_0>(),
                                            cos.data_ptr<scalar_t_1>(),
                                            sin.data_ptr<scalar_t_1>(),
                                            input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
}  // end namespace fused_rope
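
// Hypothetical illustration (the actual Python bindings are not part of this
// file): the four entry points above would typically be exported through a
// small pybind11 shim such as:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward", &fused_rope::fwd_cuda, "Fused RoPE forward (CUDA)");
//       m.def("backward", &fused_rope::bwd_cuda, "Fused RoPE backward (CUDA)");
//       m.def("forward_cached", &fused_rope::fwd_cached_cuda,
//             "Fused RoPE forward with cached cos/sin (CUDA)");
//       m.def("backward_cached", &fused_rope::bwd_cached_cuda,
//             "Fused RoPE backward with cached cos/sin (CUDA)");
//   }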