// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <type_traits>
#include "gemm_kernel_utils.h"
#include "kernel_backward.h"
#include "transform/bias_broadcast.h"

constexpr auto kBlockSizeI = 64;
constexpr auto kBlockSizeJ = 64;
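
// Fallback overload, selected via SFINAE when the (arch, scalar_t) pair fails
// CheckArch (e.g. bf16 on a pre-Ampere GPU); it only raises a diagnostic.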
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
    torch::Tensor& go,
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& o,
    torch::Tensor& lse,
    torch::Tensor& delta,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& gq,
    torch::Tensor& gk,
    torch::Tensor& gv,
    torch::Tensor& gb1,
    torch::Tensor& gb2)
{
    EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}
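
// Primary overload, selected when CheckArch passes: instantiates the CUTLASS
// backward attention kernel for this architecture and launches it.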
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
    torch::Tensor& go,
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& o,
    torch::Tensor& lse,
    torch::Tensor& delta,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& gq,
    torch::Tensor& gk,
    torch::Tensor& gv,
    torch::Tensor& gb1,
    torch::Tensor& gb2)
{
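    // SM80+ (Ampere and newer) supports the kernel's operand-preloading path;
    // older architectures take the non-preloading variant.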
    constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
    using Kernel = AttentionBackwardKernel<arch,
                                           scalar_t,     // scalar_t
                                           true,         // kIsAligned_
                                           false,        // kApplyDropout_
                                           kPreload_,    // kPreload_
                                           kBlockSizeI,  // kBlockSizeI_
                                           kBlockSizeJ,  // kBlockSizeJ_
                                           64,           // kMaxK
                                           Broadcast1_,
                                           Broadcast2_>;
    int head_size = q.size(-1);
    int head_number = q.size(-2);
    int seq_length = q.size(-3);
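    // Collapse all leading batch-like dimensions into one, yielding 4-D views
    // of shape [batch, seq_length, head_number, head_size] for stride queries.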
    auto q_view = q.view({-1, seq_length, head_number, head_size});
    auto k_view = k.view({-1, seq_length, head_number, head_size});
    auto v_view = v.view({-1, seq_length, head_number, head_size});
    auto o_view = o.view({-1, seq_length, head_number, head_size});
    auto do_view = go.view({-1, seq_length, head_number, head_size});
    auto dk_view = gk.view({-1, seq_length, head_number, head_size});
    auto dv_view = gv.view({-1, seq_length, head_number, head_size});
    auto dq_view = gq.view({-1, seq_length, head_number, head_size});
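    // Reinterpret the torch element type (e.g. at::Half) as the kernel's
    // scalar type (e.g. cutlass::half_t); the layouts are bit-compatible.
    // Bias gradients may be empty tensors, in which case null pointers tell
    // the kernel to skip those accumulations.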
    auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
    auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
    auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
    auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
    auto do_ptr = reinterpret_cast<scalar_t*>(go.data_ptr<torch_scalar_t>());
    auto dk_ptr = reinterpret_cast<scalar_t*>(gk.data_ptr<torch_scalar_t>());
    auto dv_ptr = reinterpret_cast<scalar_t*>(gv.data_ptr<torch_scalar_t>());
    auto dq_ptr = reinterpret_cast<scalar_t*>(gq.data_ptr<torch_scalar_t>());
    auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast<float*>(gb1.data_ptr<float>()) : nullptr;
    auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast<float*>(gb2.data_ptr<float>()) : nullptr;
    auto lse_ptr = reinterpret_cast<float*>(lse.data_ptr<float>());
    auto delta_ptr = reinterpret_cast<float*>(delta.data_ptr<float>());
    auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
    auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());
    static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
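
    // Populate the kernel's parameter struct with pointers, sizes, and strides.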
102

103
    typename Kernel::Params p;
104
    p.query_ptr = q_ptr;
105
    p.key_ptr = k_ptr;
106
    p.value_ptr = v_ptr;
107
    p.logsumexp_ptr = lse_ptr;
108
    p.output_ptr = o_ptr;
109
    p.grad_output_ptr = do_ptr;
110
    p.delta_ptr = delta_ptr;
111
    p.grad_query_ptr = dq_ptr;
112
    p.grad_key_ptr = dk_ptr;
113
    p.grad_value_ptr = dv_ptr;
114

115
    p.grad_bias1_ptr = db1_ptr;
116
    p.grad_bias2_ptr = db2_ptr;
117
    p.B = q.size(0);
118
    p.N = q.size(1);
119
    p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
120
    p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
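
    // Standard scaled-dot-product attention scaling factor, 1/sqrt(head_dim).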
    p.scale = 1.0f / sqrtf(head_size);

    p.head_dim = head_size;
    p.head_dim_value = head_size;
    p.num_queries = seq_length;
    p.num_keys = seq_length;
    p.num_heads = head_number;
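
    // Strides are taken from the 4-D views and are in elements; the suffixes
    // denote the sequence (M), head (H), and batch (B) axes.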
    p.q_strideM = q_view.stride(-3);
    p.k_strideM = k_view.stride(-3);
    p.v_strideM = v_view.stride(-3);
    p.gO_strideM = do_view.stride(-3);
    p.o_strideH = o_view.stride(-2);
    p.q_strideH = q_view.stride(-2);
    p.k_strideH = k_view.stride(-2);
    p.v_strideH = v_view.stride(-2);
    p.o_strideB = o_view.stride(-4);
    p.q_strideB = q_view.stride(-4);
    p.k_strideB = k_view.stride(-4);
    p.v_strideB = v_view.stride(-4);
    p.lse_strideB = lse.stride(-3);
    p.lse_strideH = lse.stride(-2);
    p.delta_strideB = delta.stride(-3);
    p.delta_strideH = delta.stride(-2);
    p.num_batches = q_view.size(-4);

    p.gO_strideB = do_view.stride(-4);
    p.gQ_strideB = dq_view.stride(-4);
    p.gK_strideB = dk_view.stride(-4);
    p.gV_strideB = dv_view.stride(-4);
    p.gO_strideH = do_view.stride(-2);
    p.gQ_strideH = dq_view.stride(-2);
    p.gK_strideH = dk_view.stride(-2);
    p.gV_strideH = dv_view.stride(-2);
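
    // workspace_size() is in bytes; dividing by 4 gives the element count for
    // a float tensor (lse is fp32, so lse.options() carries the right dtype).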
    torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
    p.workspace = workspace.data_ptr<float>();
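
    // Opt in to dynamic shared memory beyond the default 48 KB cap, validate
    // the parameters, and launch on the default CUDA stream.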
    auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
    size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
    cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
    if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
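
// Dispatch over the four bias configurations: each combination of present and
// absent bias1/bias2 selects a pair of broadcast functors, so a bias that is
// not supplied is neither loaded nor differentiated.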
#define CODE(scalar_t, torch_scalar_t)                                           \
    do {                                                                         \
        if (bias1.size(0) == 0 && bias2.size(0) == 0) {                          \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastNoLoad,                        \
                                         BroadcastNoLoad>(                       \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else if (bias1.size(0) > 0 && bias2.size(0) > 0) {                     \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastA,                             \
                                         BroadcastB>(                            \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else if (bias1.size(0) > 0) {                                          \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastA,                             \
                                         BroadcastNoLoad>(                       \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else {                                                                 \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastNoLoad,                        \
                                         BroadcastB>(                            \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        }                                                                        \
    } while (0)
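
// Public entry point: dispatches on the current device's compute capability
// and on q's dtype, then expands CODE for the matching instantiation. Inputs
// are expected as [..., seq_length, head_number, head_size].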
void attention_back_impl(torch::Tensor& go,
                         torch::Tensor& q,
                         torch::Tensor& k,
                         torch::Tensor& v,
                         torch::Tensor& o,
                         torch::Tensor& lse,
                         torch::Tensor& delta,
                         torch::Tensor& bias1,
                         torch::Tensor& bias2,
                         torch::Tensor& gq,
                         torch::Tensor& gk,
                         torch::Tensor& gv,
                         torch::Tensor& gb1,
                         torch::Tensor& gb2)
{
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
                     DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
}
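
// A minimal sketch (an assumption, not part of this file) of how the entry
// point could be exposed to Python through the PyTorch extension mechanism;
// the function name "attention_bwd" is illustrative:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("attention_bwd",
//             &attention_back_impl,
//             "Evoformer attention backward (CUDA)");
//   }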