// deepspeed
// 186 lines · 8.2 KB
1// Copyright (c) Microsoft Corporation.
2// SPDX-License-Identifier: Apache-2.0
3
4// DeepSpeed Team
5
6#include "custom_cuda_layers.h"
7#include "memory_access_utils.h"
8
9namespace cg = cooperative_groups;
10
// Shared constants for the token-dropping gather/scatter kernels below.
namespace td_data {
// Width in bytes of each vectorized global-memory access (load/store chunk).
constexpr int granularity = 16;
}  // namespace td_data
14
// Gathers a sampled subset of sequence tokens into a dense output buffer.
//
// Launch layout (set by launch_gather_tokens): grid = (batch_size,
// sampled_tokens); each block copies the channel vector of one gathered token,
// with threads cooperatively streaming 16-byte (td_data::granularity)
// vectorized chunks.
//
// retained_tokens    dense destination [batch, sampled_tokens, channels]
// activations        source activations, indexed by batch/seq strides
// gather_indices     per-(batch, sample) source sequence index; read-only
// sampled_tokens     number of tokens gathered per batch row
// channels           channel (hidden) dimension copied per token
// *_stride           element strides for the read/write address computations
//
// NOTE(review): the vectorized accesses imply `channels` is a multiple of
// 16 / sizeof(T) and both data pointers are 16-byte aligned — confirm at call
// sites; the tail is not guarded here.
template <typename T>
__global__ void gather_tokens_impl(T* retained_tokens,
                                   const T* activations,
                                   const int32_t* __restrict__ gather_indices,
                                   int32_t sampled_tokens,
                                   int32_t channels,
                                   int32_t read_batch_stride,
                                   int32_t read_seq_stride,
                                   int32_t write_batch_stride,
                                   int32_t write_seq_stride)
{
    // Number of T elements moved by each 16-byte vectorized access.
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    // Source sequence position of the token this block is responsible for.
    const int gather_idx = gather_indices[tb.group_index().x * sampled_tokens + tb.group_index().y];

    // Read from the gathered token's location; write densely by sample index.
    const int read_offset = read_batch_stride * tb.group_index().x + read_seq_stride * gather_idx;
    const int write_offset =
        write_batch_stride * tb.group_index().x + write_seq_stride * tb.group_index().y;

    // Block-stride loop over channels: each thread moves mem_vals_t values per step.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += blockDim.x * mem_vals_t) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data, activations + read_offset + i);
        mem_access::store_global<td_data::granularity>(retained_tokens + write_offset + i,
                                                       local_data);
    }
}
43
// Host-side launcher for gather_tokens_impl.
//
// Sizes a 1D block to cover the channel dimension in 16-byte vectorized steps
// (capped at the 1024-thread hardware limit; the kernel's block-stride loop
// handles any remainder), and launches one block per (batch, sampled token)
// on the given stream. Launch is asynchronous with respect to the host.
template <typename T>
void launch_gather_tokens(T* retained_tokens,
                          T* activations,
                          int32_t* gather_indices,
                          int32_t batch_size,
                          int32_t sampled_tokens,
                          int32_t channels,
                          int32_t read_batch_stride,
                          int32_t read_seq_stride,
                          int32_t write_batch_stride,
                          int32_t write_seq_stride,
                          cudaStream_t stream)
{
    // Elements of T transferred per vectorized memory step.
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    // Ceil-divide channels into vectorized steps; one thread per step,
    // up to the per-block maximum.
    const int total_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int block_threads = (total_steps > 1024) ? 1024 : total_steps;

    const dim3 block_dim(block_threads);
    const dim3 grid_dim(batch_size, sampled_tokens);

    gather_tokens_impl<T><<<grid_dim, block_dim, 0, stream>>>(retained_tokens,
                                                              activations,
                                                              gather_indices,
                                                              sampled_tokens,
                                                              channels,
                                                              read_batch_stride,
                                                              read_seq_stride,
                                                              write_batch_stride,
                                                              write_seq_stride);
}
75
// Explicit instantiations of the gather launcher for fp32 and fp16 activations.
template void launch_gather_tokens<float>(float*,
float*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);

template void launch_gather_tokens<__half>(__half*,
__half*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
99
// Scatters a dense buffer of layer activations back to their original token
// positions (inverse of gather_tokens_impl).
//
// Launch layout (set by launch_scatter_tokens): grid = (batch_size,
// retained_tokens); each block copies one retained token's channel vector,
// with threads streaming 16-byte (td_data::granularity) vectorized chunks.
//
// all_activations    destination indexed by the original sequence positions
// layer_activations  dense source [batch, retained_tokens, channels]
// gather_indices     per-(batch, sample) destination sequence index; read-only
// retained_tokens    number of retained tokens per batch row
// channels           channel (hidden) dimension copied per token
// *_stride           element strides for the read/write address computations
//
// NOTE(review): as with the gather kernel, the vectorized accesses imply
// `channels` is a multiple of 16 / sizeof(T) and 16-byte-aligned pointers —
// no tail guard here; confirm at call sites.
template <typename T>
__global__ void scatter_tokens_impl(T* all_activations,
                                    const T* layer_activations,
                                    const int32_t* __restrict__ gather_indices,
                                    int32_t retained_tokens,
                                    int32_t channels,
                                    int32_t read_batch_stride,
                                    int32_t read_seq_stride,
                                    int32_t write_batch_stride,
                                    int32_t write_seq_stride)
{
    // Number of T elements moved by each 16-byte vectorized access.
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    // Original sequence position this retained token is written back to.
    const int gather_idx =
        gather_indices[tb.group_index().x * retained_tokens + tb.group_index().y];

    // Read densely by sample index; write to the token's original location.
    const int read_offset =
        read_batch_stride * tb.group_index().x + read_seq_stride * tb.group_index().y;
    const int write_offset =
        write_batch_stride * tb.group_index().x + write_seq_stride * gather_idx;

    // Block-stride loop over channels: each thread moves mem_vals_t values per step.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += mem_vals_t * blockDim.x) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data,
                                                      layer_activations + read_offset + i);
        mem_access::store_global<td_data::granularity>(all_activations + write_offset + i,
                                                       local_data);
    }
}
131
// Host-side launcher for scatter_tokens_impl.
//
// Chooses a 1D block large enough to stream the channel dimension in 16-byte
// vectorized steps (never exceeding 1024 threads; the kernel's block-stride
// loop covers the rest) and enqueues one block per (batch, retained token)
// on the supplied stream. The launch does not synchronize with the host.
template <typename T>
void launch_scatter_tokens(T* all_activations,
                           T* layer_activations,
                           int32_t* gather_indices,
                           int32_t batch_size,
                           int32_t sampled_tokens,
                           int32_t channels,
                           int32_t read_batch_stride,
                           int32_t read_seq_stride,
                           int32_t write_batch_stride,
                           int32_t write_seq_stride,
                           cudaStream_t stream)
{
    // Elements of T moved per vectorized memory access.
    constexpr int vals_per_access = td_data::granularity / sizeof(T);

    // One thread per vectorized step over the channel dimension (ceil-div),
    // clamped to the hardware per-block thread limit.
    const int steps = (channels + vals_per_access - 1) / vals_per_access;
    const int num_threads = (steps < 1024) ? steps : 1024;

    const dim3 block(num_threads);
    const dim3 grid(batch_size, sampled_tokens);

    scatter_tokens_impl<T><<<grid, block, 0, stream>>>(all_activations,
                                                       layer_activations,
                                                       gather_indices,
                                                       sampled_tokens,
                                                       channels,
                                                       read_batch_stride,
                                                       read_seq_stride,
                                                       write_batch_stride,
                                                       write_seq_stride);
}
163
// Explicit instantiations of the scatter launcher for fp32 and fp16 activations.
template void launch_scatter_tokens<float>(float*,
float*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);

template void launch_scatter_tokens<__half>(__half*,
__half*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
187