/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"
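
// The DISPATCH_FLOAT_HALF_AND_BFLOAT macro used below is expected to come from
// type_shim.h: given a runtime ScalarType and a level index (0 here), it binds
// the matching C++ type to scalar_t_0 and evaluates the dispatch expression.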

namespace fused_rope {

torch::Tensor fwd_cuda(const torch::Tensor& input,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);
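
    // Descriptive note (added): freqs holds the per-position, per-channel
    // rotation angles theta. Under the rotate-half convention the kernel
    // rotates the first d2 channels of every head pairwise, roughly
    //     out[i]        = x[i]        * cos(theta[i])        - x[i + d2/2] * sin(theta[i])
    //     out[i + d2/2] = x[i + d2/2] * cos(theta[i + d2/2]) + x[i]        * sin(theta[i + d2/2])
    // for i in [0, d2/2), with theta typically repeated across the two halves,
    // and copies channels [d2, d) through unchanged. The exact pairing and
    // angle layout are defined by the kernels declared in
    // fused_rotary_positional_embedding.h.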

    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
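    // Note (added): with transpose_output the buffer is allocated contiguously
    // as (b, s, h, d) and then viewed as (s, b, h, d); the kernels write through
    // the o_stride_* values below, so they are agnostic to the physical layout.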
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);

    DISPATCH_FLOAT_HALF_AND_BFLOAT(input.scalar_type(),
                                   0,
                                   "dispatch_fused_rope_forward",
                                   dispatch_fused_rope_forward(s,
                                                               b,
                                                               h,
                                                               d,
                                                               d2,
                                                               stride_s,
                                                               stride_b,
                                                               stride_h,
                                                               stride_d,
                                                               o_stride_s,
                                                               o_stride_b,
                                                               o_stride_h,
                                                               o_stride_d,
                                                               input.data_ptr<scalar_t_0>(),
                                                               freqs.data_ptr<float>(),
                                                               output.data_ptr<scalar_t_0>()););
    return output;
}
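
// Usage sketch (added; hypothetical binding, not part of this file): these
// entry points are typically exposed to Python via pybind11, e.g.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward", &fused_rope::fwd_cuda, "Fused RoPE forward (CUDA)");
//       m.def("backward", &fused_rope::bwd_cuda, "Fused RoPE backward (CUDA)");
//   }
//
// with a torch.autograd.Function on the Python side routing to forward/backward.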

torch::Tensor bwd_cuda(const torch::Tensor& output_grads,
                       const torch::Tensor& freqs,
                       const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
    // different memory formats
    const int d2 = freqs.size(3);

    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        output_grads.scalar_type(),
        0,
        "dispatch_fused_rope_backward",
        dispatch_fused_rope_backward(s,
                                     b,
                                     h,
                                     d,
                                     d2,
                                     stride_s,
                                     stride_b,
                                     stride_h,
                                     stride_d,
                                     o_stride_s,
                                     o_stride_b,
                                     o_stride_h,
                                     o_stride_d,
                                     output_grads.data_ptr<scalar_t_0>(),
                                     freqs.data_ptr<float>(),
                                     input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
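
// Note (added): since RoPE is an orthogonal rotation, the backward pass is the
// same rotation applied with the angle negated (the sin terms change sign),
// which is what dispatch_fused_rope_backward is expected to implement.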

#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)    \
    switch (TYPE1) {                                          \
        case at::ScalarType::Float: {                         \
            using scalar_t_0 = float;                         \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        case at::ScalarType::Half: {                          \
            using scalar_t_0 = at::Half;                      \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                case at::ScalarType::Half: {                  \
                    using scalar_t_1 = at::Half;              \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        case at::ScalarType::BFloat16: {                      \
            using scalar_t_0 = at::BFloat16;                  \
            switch (TYPE2) {                                  \
                case at::ScalarType::Float: {                 \
                    using scalar_t_1 = float;                 \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                case at::ScalarType::BFloat16: {              \
                    using scalar_t_1 = at::BFloat16;          \
                    __VA_ARGS__;                              \
                    break;                                    \
                }                                             \
                default:                                      \
                    TORCH_CHECK(false,                        \
                                #NAME,                        \
                                " not supported for '",       \
                                toString(TYPE1),              \
                                "' with '",                   \
                                toString(TYPE2),              \
                                "'");                         \
            }                                                 \
            break;                                            \
        }                                                     \
        default:                                              \
            TORCH_CHECK(false,                                \
                        #NAME,                                \
                        " not supported for '",               \
                        toString(TYPE1),                      \
                        "' with '",                           \
                        toString(TYPE2),                      \
                        "'");                                 \
    }
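
// Illustration (added): DISPATCH_FUSED_ROPE_TYPES pairs an activation dtype
// (TYPE1 -> scalar_t_0) with a cos/sin dtype (TYPE2 -> scalar_t_1). The valid
// pairs are (float, float), (half, float), (half, half), (bfloat16, float),
// and (bfloat16, bfloat16); anything else raises a TORCH_CHECK error. For
// half activations with float cos/sin, the body expands roughly to
//
//   using scalar_t_0 = at::Half;
//   using scalar_t_1 = float;
//   __VA_ARGS__;  // the dispatch_* call passed as the variadic argument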

torch::Tensor fwd_cached_cuda(const torch::Tensor& input,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // input sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = input.size(0);
    const int b = input.size(1);
    const int h = input.size(2);
    const int d = input.size(3);
    // input strides
    const int stride_s = input.stride(0);
    const int stride_b = input.stride(1);
    const int stride_h = input.stride(2);
    const int stride_d = input.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);
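
    // Note (added): cos and sin are expected to share the same shape and dtype;
    // both are read through scalar_t_1 in the dispatch below.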

    // output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor output;
    if (transpose_output) {
        output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        output = torch::empty({s, b, h, d}, act_options);
    }
    // output strides
    const int o_stride_s = output.stride(0);
    const int o_stride_b = output.stride(1);
    const int o_stride_h = output.stride(2);
    const int o_stride_d = output.stride(3);

    DISPATCH_FUSED_ROPE_TYPES(input.scalar_type(),
                              cos.scalar_type(),
                              "dispatch_fused_rope_cached_forward",
                              dispatch_fused_rope_cached_forward(s,
                                                                 b,
                                                                 h,
                                                                 d,
                                                                 d2,
                                                                 stride_s,
                                                                 stride_b,
                                                                 stride_h,
                                                                 stride_d,
                                                                 o_stride_s,
                                                                 o_stride_b,
                                                                 o_stride_h,
                                                                 o_stride_d,
                                                                 input.data_ptr<scalar_t_0>(),
                                                                 cos.data_ptr<scalar_t_1>(),
                                                                 sin.data_ptr<scalar_t_1>(),
                                                                 output.data_ptr<scalar_t_0>()););
    return output;
}
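
// Note (added): the "cached" variants take precomputed cos/sin tensors instead
// of raw freqs, so the trigonometric terms can be computed once on the caller's
// side and reused across layers and steps; the dual-dtype dispatch also lets
// cos/sin stay in a different precision than the activations.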

torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads,
                              const torch::Tensor& cos,
                              const torch::Tensor& sin,
                              const bool transpose_output)
{
    // output_grads sizes: (s, b, h, d)
    // s: sequence length
    // b: batch size
    // h: head num
    // d: dim of each head
    const int s = output_grads.size(0);
    const int b = output_grads.size(1);
    const int h = output_grads.size(2);
    const int d = output_grads.size(3);
    // output_grads strides
    const int stride_s = output_grads.stride(0);
    const int stride_b = output_grads.stride(1);
    const int stride_h = output_grads.stride(2);
    const int stride_d = output_grads.stride(3);
    // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
    // under different memory formats
    const int d2 = cos.size(3);

    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads;
    if (transpose_output) {
        input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
    } else {
        input_grads = torch::empty({s, b, h, d}, act_options);
    }
    const int o_stride_s = input_grads.stride(0);
    const int o_stride_b = input_grads.stride(1);
    const int o_stride_h = input_grads.stride(2);
    const int o_stride_d = input_grads.stride(3);

    DISPATCH_FUSED_ROPE_TYPES(
        output_grads.scalar_type(),
        cos.scalar_type(),
        "dispatch_fused_rope_cached_backward",
        dispatch_fused_rope_cached_backward(s,
                                            b,
                                            h,
                                            d,
                                            d2,
                                            stride_s,
                                            stride_b,
                                            stride_h,
                                            stride_d,
                                            o_stride_s,
                                            o_stride_b,
                                            o_stride_h,
                                            o_stride_d,
                                            output_grads.data_ptr<scalar_t_0>(),
                                            cos.data_ptr<scalar_t_1>(),
                                            sin.data_ptr<scalar_t_1>(),
                                            input_grads.data_ptr<scalar_t_0>()););
    return input_grads;
}
}  // end namespace fused_rope