deepspeed

Форк
0
/
gather_scatter.cu 
186 строк · 8.2 Кб
1
// Copyright (c) Microsoft Corporation.
2
// SPDX-License-Identifier: Apache-2.0
3

4
// DeepSpeed Team
5

6
#include "custom_cuda_layers.h"
7
#include "memory_access_utils.h"
8

9
namespace cg = cooperative_groups;

// Constants shared by the token gather/scatter kernels in this file.
namespace td_data {
// Size in bytes of every vectorized global load/store issued by the kernels;
// each access therefore moves granularity / sizeof(T) elements of type T.
constexpr int granularity{16};
}  // namespace td_data

/*
 * Copies selected token rows of `activations` into the dense buffer `retained_tokens`.
 *
 * Launch layout (set up by launch_gather_tokens below):
 *   blockIdx.x  - batch entry (gridDim.x == batch size).
 *   blockIdx.y  - first output token slot handled by this block; blocks stride over
 *                 the slots by gridDim.y, which the launcher caps at 65535.
 *   threadIdx.x - the block's threads cooperatively copy one row of `channels`
 *                 elements, td_data::granularity bytes per vectorized access.
 *
 * Output slot `s` of batch entry `b` receives source token
 * gather_indices[b * sampled_tokens + s], i.e. gather_indices is read as a row-major
 * [batch, sampled_tokens] table. All strides are expressed in elements of T.
 *
 * Preconditions (not checked here): every gather index selects a row inside
 * `activations`. The vectorized path assumes each row start (base pointer plus
 * stride offsets) is aligned to td_data::granularity bytes, since that is the
 * access size passed to the mem_access helpers -- confirm against
 * memory_access_utils.h.
 */
template <typename T>
__global__ void gather_tokens_impl(T* retained_tokens,
                                   const T* activations,
                                   int32_t* gather_indices,
                                   int32_t sampled_tokens,
                                   int32_t channels,
                                   int32_t read_batch_stride,
                                   int32_t read_seq_stride,
                                   int32_t write_batch_stride,
                                   int32_t write_seq_stride)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    const int64_t batch = tb.group_index().x;

    // Stride over output slots so the launcher may cap gridDim.y at its hardware limit.
    for (int64_t slot = tb.group_index().y; slot < sampled_tokens; slot += gridDim.y) {
        const int64_t gather_idx = gather_indices[batch * sampled_tokens + slot];

        // 64-bit offsets: stride * index can exceed INT32_MAX for large activations.
        const int64_t read_offset = read_batch_stride * batch + read_seq_stride * gather_idx;
        const int64_t write_offset = write_batch_stride * batch + write_seq_stride * slot;

        for (int i = tb.thread_index().x * mem_vals_t; i < channels;
             i += blockDim.x * mem_vals_t) {
            if (i + mem_vals_t <= channels) {
                T local_data[mem_vals_t];
                mem_access::load_global<td_data::granularity>(local_data,
                                                              activations + read_offset + i);
                mem_access::store_global<td_data::granularity>(
                    retained_tokens + write_offset + i, local_data);
            } else {
                // Ragged tail (channels not a multiple of mem_vals_t): copy element by
                // element so no access reaches past the end of the row.
                for (int j = i; j < channels; ++j) {
                    retained_tokens[write_offset + j] = activations[read_offset + j];
                }
            }
        }
    }
}

/*
 * Host-side launcher for gather_tokens_impl on `stream`.
 *
 * Grid: one block per (batch entry, output token slot). gridDim.y is capped at 65535
 * (the CUDA limit); beyond that each block handles several slots. Block: one thread
 * per td_data::granularity-byte chunk of a row, capped at 1024 threads.
 *
 * The launch is asynchronous and launch errors are not checked here; callers that
 * need them should query cudaGetLastError(). If batch_size, sampled_tokens or
 * channels is <= 0, nothing is launched, because such a problem would otherwise
 * request a zero-sized grid or block (an invalid launch configuration).
 */
template <typename T>
void launch_gather_tokens(T* retained_tokens,
                          T* activations,
                          int32_t* gather_indices,
                          int32_t batch_size,
                          int32_t sampled_tokens,
                          int32_t channels,
                          int32_t read_batch_stride,
                          int32_t read_seq_stride,
                          int32_t write_batch_stride,
                          int32_t write_seq_stride,
                          cudaStream_t stream)
{
    if (batch_size <= 0 || sampled_tokens <= 0 || channels <= 0) { return; }

    constexpr int mem_vals_t = td_data::granularity / sizeof(T);
    constexpr int max_threads = 1024;
    constexpr int max_grid_y = 65535;

    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps < max_threads) ? load_steps : max_threads;
    const int token_blocks = (sampled_tokens < max_grid_y) ? sampled_tokens : max_grid_y;

    dim3 block(threads);
    dim3 grid(batch_size, token_blocks);

    gather_tokens_impl<T><<<grid, block, 0, stream>>>(retained_tokens,
                                                      activations,
                                                      gather_indices,
                                                      sampled_tokens,
                                                      channels,
                                                      read_batch_stride,
                                                      read_seq_stride,
                                                      write_batch_stride,
                                                      write_seq_stride);
}

// Emit launch_gather_tokens for every element type this translation unit supports.
// Argument order: output buffer, input activations, gather indices, batch_size,
// sampled_tokens, channels, read/write batch and sequence strides, stream.
#define INSTANTIATE_GATHER_TOKENS(T)                                      \
    template void launch_gather_tokens<T>(                                \
        T*, T*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t,    \
        int32_t, int32_t, cudaStream_t);

INSTANTIATE_GATHER_TOKENS(float)
INSTANTIATE_GATHER_TOKENS(__half)

#undef INSTANTIATE_GATHER_TOKENS

/*
 * Writes each dense row of `layer_activations` back into `all_activations` at the
 * token position named by `gather_indices` (the mirror of gather_tokens_impl).
 *
 * Launch layout (set up by launch_scatter_tokens below):
 *   blockIdx.x  - batch entry (gridDim.x == batch size).
 *   blockIdx.y  - first dense row handled by this block; blocks stride over the rows
 *                 by gridDim.y, which the launcher caps at 65535.
 *   threadIdx.x - the block's threads cooperatively copy one row of `channels`
 *                 elements, td_data::granularity bytes per vectorized access.
 *
 * Dense row `s` of batch entry `b` is written to token
 * gather_indices[b * retained_tokens + s] of `all_activations`. Rows of
 * `all_activations` not named by an index are not written. All strides are in
 * elements of T.
 *
 * Preconditions (not checked here): every index selects a row inside
 * `all_activations`. If an index repeats within one batch entry, the
 * corresponding writes race. The vectorized path assumes each row start is aligned
 * to td_data::granularity bytes, since that is the access size passed to the
 * mem_access helpers -- confirm against memory_access_utils.h.
 */
template <typename T>
__global__ void scatter_tokens_impl(T* all_activations,
                                    const T* layer_activations,
                                    int32_t* gather_indices,
                                    int32_t retained_tokens,
                                    int32_t channels,
                                    int32_t read_batch_stride,
                                    int32_t read_seq_stride,
                                    int32_t write_batch_stride,
                                    int32_t write_seq_stride)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    const int64_t batch = tb.group_index().x;

    // Stride over dense rows so the launcher may cap gridDim.y at its hardware limit.
    for (int64_t row = tb.group_index().y; row < retained_tokens; row += gridDim.y) {
        const int64_t gather_idx = gather_indices[batch * retained_tokens + row];

        // 64-bit offsets: stride * index can exceed INT32_MAX for large activations.
        const int64_t read_offset = read_batch_stride * batch + read_seq_stride * row;
        const int64_t write_offset = write_batch_stride * batch + write_seq_stride * gather_idx;

        for (int i = tb.thread_index().x * mem_vals_t; i < channels;
             i += blockDim.x * mem_vals_t) {
            if (i + mem_vals_t <= channels) {
                T local_data[mem_vals_t];
                mem_access::load_global<td_data::granularity>(
                    local_data, layer_activations + read_offset + i);
                mem_access::store_global<td_data::granularity>(
                    all_activations + write_offset + i, local_data);
            } else {
                // Ragged tail (channels not a multiple of mem_vals_t): copy element by
                // element so no access reaches past the end of the row.
                for (int j = i; j < channels; ++j) {
                    all_activations[write_offset + j] = layer_activations[read_offset + j];
                }
            }
        }
    }
}

/*
 * Host-side launcher for scatter_tokens_impl on `stream`. `sampled_tokens` is the
 * number of dense rows per batch entry (forwarded as the kernel's `retained_tokens`).
 *
 * Grid: one block per (batch entry, dense row). gridDim.y is capped at 65535 (the
 * CUDA limit); beyond that each block handles several rows. Block: one thread per
 * td_data::granularity-byte chunk of a row, capped at 1024 threads.
 *
 * The launch is asynchronous and launch errors are not checked here; callers that
 * need them should query cudaGetLastError(). If batch_size, sampled_tokens or
 * channels is <= 0, nothing is launched, because such a problem would otherwise
 * request a zero-sized grid or block (an invalid launch configuration).
 */
template <typename T>
void launch_scatter_tokens(T* all_activations,
                           T* layer_activations,
                           int32_t* gather_indices,
                           int32_t batch_size,
                           int32_t sampled_tokens,
                           int32_t channels,
                           int32_t read_batch_stride,
                           int32_t read_seq_stride,
                           int32_t write_batch_stride,
                           int32_t write_seq_stride,
                           cudaStream_t stream)
{
    if (batch_size <= 0 || sampled_tokens <= 0 || channels <= 0) { return; }

    constexpr int mem_vals_t = td_data::granularity / sizeof(T);
    constexpr int max_threads = 1024;
    constexpr int max_grid_y = 65535;

    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps < max_threads) ? load_steps : max_threads;
    const int token_blocks = (sampled_tokens < max_grid_y) ? sampled_tokens : max_grid_y;

    dim3 block(threads);
    dim3 grid(batch_size, token_blocks);

    scatter_tokens_impl<T><<<grid, block, 0, stream>>>(all_activations,
                                                       layer_activations,
                                                       gather_indices,
                                                       sampled_tokens,
                                                       channels,
                                                       read_batch_stride,
                                                       read_seq_stride,
                                                       write_batch_stride,
                                                       write_seq_stride);
}

// Emit launch_scatter_tokens for every element type this translation unit supports.
// Argument order: full activations (output), dense layer activations (input), gather
// indices, batch_size, sampled_tokens, channels, read/write batch and sequence
// strides, stream.
#define INSTANTIATE_SCATTER_TOKENS(T)                                     \
    template void launch_scatter_tokens<T>(                               \
        T*, T*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t,    \
        int32_t, int32_t, cudaStream_t);

INSTANTIATE_SCATTER_TOKENS(float)
INSTANTIATE_SCATTER_TOKENS(__half)

#undef INSTANTIATE_SCATTER_TOKENS

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.