deepspeed
218 lines · 9.7 KB
1// Copyright (c) Microsoft Corporation.
2// SPDX-License-Identifier: Apache-2.0
3
4// DeepSpeed Team
5
6#include <ATen/cuda/CUDAContext.h>
7#include <torch/extension.h>
8#include <type_traits>
9#include "gemm_kernel_utils.h"
10#include "kernel_backward.h"
11#include "transform/bias_broadcast.h"
12
13constexpr auto kBlockSizeI = 64;
14constexpr auto kBlockSizeJ = 64;
15
16template <typename arch,
17typename scalar_t,
18typename torch_scalar_t,
19template <typename, typename, typename>
20class Broadcast1_,
21template <typename, typename, typename>
22class Broadcast2_>
23typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
24torch::Tensor& go,
25torch::Tensor& q,
26torch::Tensor& k,
27torch::Tensor& v,
28torch::Tensor& o,
29torch::Tensor& lse,
30torch::Tensor& delta,
31torch::Tensor& bias1,
32torch::Tensor& bias2,
33torch::Tensor& gq,
34torch::Tensor& gk,
35torch::Tensor& gv,
36torch::Tensor& gb1,
37torch::Tensor& gb2)
38{
39EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
40}
41
42template <typename arch,
43typename scalar_t,
44typename torch_scalar_t,
45template <typename, typename, typename>
46class Broadcast1_,
47template <typename, typename, typename>
48class Broadcast2_>
49typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
50torch::Tensor& go,
51torch::Tensor& q,
52torch::Tensor& k,
53torch::Tensor& v,
54torch::Tensor& o,
55torch::Tensor& lse,
56torch::Tensor& delta,
57torch::Tensor& bias1,
58torch::Tensor& bias2,
59torch::Tensor& gq,
60torch::Tensor& gk,
61torch::Tensor& gv,
62torch::Tensor& gb1,
63torch::Tensor& gb2)
64{
65constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
66using Kernel = AttentionBackwardKernel<arch,
67scalar_t, // scalar_t
68true, // kIsAligned_
69false, // kApplyDropout_
70kPreload_, // kPreload_
71kBlockSizeI, // kBlockSizeI_,
72kBlockSizeJ, // kBlockSizeJ_,
7364, // kMaxK
74Broadcast1_,
75Broadcast2_>;
76int head_size = q.size(-1);
77int head_number = q.size(-2);
78int seq_length = q.size(-3);
79auto q_view = q.view({-1, seq_length, head_number, head_size});
80auto k_view = k.view({-1, seq_length, head_number, head_size});
81auto v_view = v.view({-1, seq_length, head_number, head_size});
82auto o_view = o.view({-1, seq_length, head_number, head_size});
83auto do_view = go.view({-1, seq_length, head_number, head_size});
84auto dk_view = gk.view({-1, seq_length, head_number, head_size});
85auto dv_view = gv.view({-1, seq_length, head_number, head_size});
86auto dq_view = gq.view({-1, seq_length, head_number, head_size});
87auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
88auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
89auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
90auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
91auto do_ptr = reinterpret_cast<scalar_t*>(go.data_ptr<torch_scalar_t>());
92auto dk_ptr = reinterpret_cast<scalar_t*>(gk.data_ptr<torch_scalar_t>());
93auto dv_ptr = reinterpret_cast<scalar_t*>(gv.data_ptr<torch_scalar_t>());
94auto dq_ptr = reinterpret_cast<scalar_t*>(gq.data_ptr<torch_scalar_t>());
95auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast<float*>(gb1.data_ptr<float>()) : nullptr;
96auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast<float*>(gb2.data_ptr<float>()) : nullptr;
97auto lse_ptr = reinterpret_cast<float*>(lse.data_ptr<float>());
98auto delta_ptr = reinterpret_cast<float*>(delta.data_ptr<float>());
99auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
100auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());
101static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
102
103typename Kernel::Params p;
104p.query_ptr = q_ptr;
105p.key_ptr = k_ptr;
106p.value_ptr = v_ptr;
107p.logsumexp_ptr = lse_ptr;
108p.output_ptr = o_ptr;
109p.grad_output_ptr = do_ptr;
110p.delta_ptr = delta_ptr;
111p.grad_query_ptr = dq_ptr;
112p.grad_key_ptr = dk_ptr;
113p.grad_value_ptr = dv_ptr;
114
115p.grad_bias1_ptr = db1_ptr;
116p.grad_bias2_ptr = db2_ptr;
117p.B = q.size(0);
118p.N = q.size(1);
119p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
120p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
121
122p.scale = 1.0f / sqrtf(head_size);
123
124p.head_dim = head_size;
125p.head_dim_value = head_size;
126p.num_queries = seq_length;
127p.num_keys = seq_length;
128p.num_heads = head_number;
129
130p.q_strideM = q_view.stride(-3);
131p.k_strideM = k_view.stride(-3);
132p.v_strideM = v_view.stride(-3);
133p.gO_strideM = do_view.stride(-3);
134p.o_strideH = o_view.stride(-2);
135p.q_strideH = q_view.stride(-2);
136p.k_strideH = k_view.stride(-2);
137p.v_strideH = v_view.stride(-2);
138p.o_strideB = o_view.stride(-4);
139p.q_strideB = q_view.stride(-4);
140p.k_strideB = k_view.stride(-4);
141p.v_strideB = v_view.stride(-4);
142p.lse_strideB = lse.stride(-3);
143p.lse_strideH = lse.stride(-2);
144p.delta_strideB = delta.stride(-3);
145p.delta_strideH = delta.stride(-2);
146p.num_batches = q_view.size(-4);
147
148p.gO_strideB = do_view.stride(-4);
149p.gQ_strideB = dq_view.stride(-4);
150p.gK_strideB = dk_view.stride(-4);
151p.gV_strideB = dv_view.stride(-4);
152p.gO_strideH = do_view.stride(-2);
153p.gQ_strideH = dq_view.stride(-2);
154p.gK_strideH = dk_view.stride(-2);
155p.gV_strideH = dv_view.stride(-2);
156
157torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
158p.workspace = workspace.data_ptr<float>();
159
160auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
161size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
162cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
163if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
164kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
165}
166
167#define CODE(scalar_t, torch_scalar_t) \
168do { \
169if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
170attention_back_impl_template<ArchTag, \
171scalar_t, \
172torch_scalar_t, \
173BroadcastNoLoad, \
174BroadcastNoLoad>( \
175go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
176} else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \
177attention_back_impl_template<ArchTag, \
178scalar_t, \
179torch_scalar_t, \
180BroadcastA, \
181BroadcastB>( \
182go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
183} else if (bias1.size(0) > 0) { \
184attention_back_impl_template<ArchTag, \
185scalar_t, \
186torch_scalar_t, \
187BroadcastA, \
188BroadcastNoLoad>( \
189go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
190} else { \
191attention_back_impl_template<ArchTag, \
192scalar_t, \
193torch_scalar_t, \
194BroadcastNoLoad, \
195BroadcastB>( \
196go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
197} \
198} while (0)
199
200void attention_back_impl(torch::Tensor& go,
201torch::Tensor& q,
202torch::Tensor& k,
203torch::Tensor& v,
204torch::Tensor& o,
205torch::Tensor& lse,
206torch::Tensor& delta,
207torch::Tensor& bias1,
208torch::Tensor& bias2,
209torch::Tensor& gq,
210torch::Tensor& gk,
211torch::Tensor& gv,
212torch::Tensor& gb1,
213torch::Tensor& gb2)
214{
215cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
216DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
217DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
218}
219