// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>
#include <cstdio>
#include "dequantization_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

using rop = reduce::ROpType;

/*
TODO(cmikeh2): Add an implementation that better handles larger nodes. It would likely make
sense to leverage some parallel reductions here to improve performance.
*/

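/*
Summary (added; not in the original source, but follows directly from the kernel body):
each thread block owns one output group. For every chunk of that group, each participating
thread loads the corresponding quantized bytes from each of the input tensors, dequantizes
them to __half2 using the per-group input scales, and accumulates an elementwise sum.
GroupStats then derives fresh quantization parameters for the summed group, which are
written to reduced_scales before the accumulated values are requantized into reduced_data.
*/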
template <int numBits, int numTensors, int totalChunks, quantize::Type quantType>
__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data,
                                                       float* reduced_scales,
                                                       const int8_t* input_data,
                                                       const float* input_scales,
                                                       int elems_per_out_group,
                                                       int elems_per_in_tensor,
                                                       int groups_per_in_tensor,
                                                       int elems_per_in_group,
                                                       int num_tensors)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // NOTE(cmikeh2): This probably could be hardcoded to a larger number,
    // but that means even stronger restrictions on the number of elements per group
    // A performance analysis here might be beneficial
    constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
    constexpr int elems_per_load = mem_granularity / sizeof(int8_t);  // div by 1
    constexpr int storage_values = 16 / sizeof(__half2);
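    // Added note (derived from the constants above): each thread handles 16 bytes of
    // dequantized half data per chunk, i.e. storage_values (4) __half2 values = 8 halves.
    // For numBits == 8 that corresponds to an 8-byte load of 8 int8 values; for
    // numBits == 4 it is a 4-byte load of 8 packed int4 values, so both paths dequantize
    // to the same number of halves per iteration.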

    const int block_offset = tb.group_index().x * elems_per_out_group;
    const int elem_offset = tb.thread_index().x * elems_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.group_dim().x * elems_per_load;

    __half2 local_buffer[totalChunks * storage_values];

    quantize::GroupStats<quantType> stats;

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        __half2* iteration_buffer = local_buffer + i * storage_values;

#pragma unroll
        for (int j = 0; j < storage_values; j++) {
            iteration_buffer[j] = reduce::init<rop::Add, __half2>();
        }

        const int iter_offset = i * stride + base_offset;
        const int iter_scale_idx = iter_offset / elems_per_in_group;
        bool do_loads = i * stride + elem_offset < elems_per_out_group;

        if (numTensors > 0) {
#pragma unroll
            for (int j = 0; j < numTensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        } else {
#pragma unroll 4
            for (int j = 0; j < num_tensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        }

#pragma unroll
        for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); }
    }

    auto params = stats.template get_params<numBits, 1024>(tb, warp);

    if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        const int iter_offset = i * stride + base_offset;
        if (i * stride + elem_offset < elems_per_out_group) {
            int8_t local_output[elems_per_load];
            quantize::_chunk<numBits, quantType>(
                local_output, local_buffer + i * storage_values, params);
            mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
        }
    }
}

template <int Power>
int32_t pow2_round(int32_t raw_value)
{
    return (((raw_value - 1) >> Power) + 1) << Power;
}
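
// Illustrative values (added; not in the original source): pow2_round rounds raw_value up
// to the nearest multiple of 2^Power, e.g. pow2_round<1>(5) == 6, pow2_round<1>(6) == 6,
// and pow2_round<1>(7) == 8. The dispatch in launch_dequant_reduce_impl relies on this to
// map raw unroll factors >= 4 onto the instantiated chunk counts {4, 6, 8, 10, 12}.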

#define LAUNCH_DEQUANT_REDUCE(num_chunks)                      \
    dequant_reduce<numBits, numTensors, num_chunks, quantType> \
        <<<grid, block, 0, stream>>>(reduced_data,             \
                                     reduced_scales,           \
                                     input_data,               \
                                     input_scales,             \
                                     elems_per_out_group,      \
                                     elems_per_in_tensor,      \
                                     groups_per_in_tensor,     \
                                     elems_per_in_group,       \
                                     num_tensors);

template <int numBits, int numTensors, quantize::Type quantType>
void launch_dequant_reduce_impl(int8_t* reduced_data,
                                float* reduced_scales,
                                const int8_t* input_data,
                                const float* input_scales,
                                int out_groups,
                                int elems_per_out_group,
                                int elems_per_in_tensor,
                                int groups_per_in_tensor,
                                int elems_per_in_group,
                                int num_tensors,
                                cudaStream_t stream)
{
    // The fact that this equals numBits is a coincidence: each thread covers 16 bytes of
    // half data per chunk (8 halves), which is 8 quantized bytes for int8 and, with 2-way
    // packing, 4 quantized bytes for int4.
    constexpr int elems_per_thread = numBits;
    const int one_step_threads =
        next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
    // TODO(cmikeh2): Tune this
    const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;

    dim3 block(threads);
    dim3 grid(out_groups);

    const int elems_per_step = threads * elems_per_thread;
    const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;

    const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;

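    // Worked example (added for illustration): with numBits == 8 and
    // elems_per_out_group == 16384, one_step_threads = next_pow2(16384 / 8) = 2048,
    // which is clamped to 1024 threads. Then elems_per_step = 1024 * 8 = 8192 and
    // unroll_raw = 2; since unroll_raw < 4 no rounding occurs, so the unroll == 2
    // branch below is taken.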
    if (unroll == 1) {
        // 0-4096 elems
        LAUNCH_DEQUANT_REDUCE(1);
    } else if (unroll == 2) {
        // 4097-8192 etc...
        LAUNCH_DEQUANT_REDUCE(2);
    } else if (unroll == 3) {
        LAUNCH_DEQUANT_REDUCE(3);
    } else if (unroll == 4) {
        LAUNCH_DEQUANT_REDUCE(4);
    } else if (unroll == 6) {
        LAUNCH_DEQUANT_REDUCE(6);
    } else if (unroll == 8) {
        LAUNCH_DEQUANT_REDUCE(8);
    } else if (unroll == 10) {
        LAUNCH_DEQUANT_REDUCE(10);
    } else if (unroll == 12) {
        // 48k limit
        LAUNCH_DEQUANT_REDUCE(12);
    } else {
        assert(false);
    }
}

#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE)           \
    launch_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
                                                               reduced_scales, \
                                                               input_data,   \
                                                               input_scales, \
                                                               out_groups,   \
                                                               elems_per_out_group, \
                                                               elems_per_in_tensor, \
                                                               groups_per_in_tensor, \
                                                               elems_per_in_group, \
                                                               num_gpus,     \
                                                               stream);

void launch_dequant_reduce(int8_t* reduced_data,
                           float* reduced_scales,
                           const int8_t* input_data,
                           const float* input_scales,
                           int num_gpus,
                           int num_bits,
                           quantize::Type quant_type,
                           int out_groups,
                           int elems_per_out_group,
                           int elems_per_in_tensor,
                           int groups_per_in_tensor,
                           int elems_per_in_group,
                           cudaStream_t stream)
{
    if (quant_type == quantize::Type::Symmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
            }
        }
    } else if (quant_type == quantize::Type::Asymmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
            }
        }
    }
}
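
/*
Minimal host-side usage sketch (added; not part of the original file). It assumes eight
ranks' quantized tensors have been gathered contiguously into input_data (one
4096-element int8 tensor per rank, 8 groups of 512), with one float of scale storage per
group for symmetric quantization. The sizes and the example_dequant_reduce name are
hypothetical; the reduction is elementwise across tensors, so the output mirrors the
input grouping.
*/
void example_dequant_reduce(cudaStream_t stream)
{
    const int num_gpus = 8;
    const int groups_per_in_tensor = 8;
    const int elems_per_in_group = 512;
    const int elems_per_in_tensor = groups_per_in_tensor * elems_per_in_group;  // 4096
    const int out_groups = groups_per_in_tensor;
    const int elems_per_out_group = elems_per_in_group;

    int8_t* input_data;
    int8_t* reduced_data;
    float* input_scales;
    float* reduced_scales;
    cudaMalloc(&input_data, num_gpus * elems_per_in_tensor * sizeof(int8_t));
    cudaMalloc(&input_scales, num_gpus * groups_per_in_tensor * sizeof(float));
    cudaMalloc(&reduced_data, out_groups * elems_per_out_group * sizeof(int8_t));
    cudaMalloc(&reduced_scales, out_groups * sizeof(float));

    // ... populate input_data / input_scales, e.g. from an all-gather of per-rank
    // quantized chunks ...

    launch_dequant_reduce(reduced_data,
                          reduced_scales,
                          input_data,
                          input_scales,
                          num_gpus,
                          /*num_bits=*/8,
                          quantize::Type::Symmetric,
                          out_groups,
                          elems_per_out_group,
                          elems_per_in_tensor,
                          groups_per_in_tensor,
                          elems_per_in_group,
                          stream);

    // ... consume reduced_data / reduced_scales, then release the buffers ...
    cudaFree(input_data);
    cudaFree(input_scales);
    cudaFree(reduced_data);
    cudaFree(reduced_scales);
}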