// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cstdio>
#include "dequantization_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

using rop = reduce::ROpType;

/*
TODO(cmikeh2): Add implementation that better handles larger nodes. It would likely make sense
to leverage some parallel reductions here to improve performance.
*/

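/*
Fused dequantize + reduce + requantize kernel. Each block owns one output group:
it dequantizes that group from every input tensor, accumulates the values in half
precision, derives fresh quantization parameters for the reduced group, and writes
back the requantized data together with the new scales.
*/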
template <int numBits, int numTensors, int totalChunks, quantize::Type quantType>
__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data,
                                                       float* reduced_scales,
                                                       const int8_t* input_data,
                                                       const float* input_scales,
                                                       int elems_per_out_group,
                                                       int elems_per_in_tensor,
                                                       int groups_per_in_tensor,
                                                       int elems_per_in_group,
                                                       int num_tensors)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // NOTE(cmikeh2): This probably could be hardcoded to a larger number,
    // but that means even stronger restrictions on the number of elements per group
    // A performance analysis here might be beneficial
    constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
    constexpr int elems_per_load = mem_granularity / sizeof(int8_t);  // div by 1
    constexpr int storage_values = 16 / sizeof(__half2);

    const int block_offset = tb.group_index().x * elems_per_out_group;
    const int elem_offset = tb.thread_index().x * elems_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.group_dim().x * elems_per_load;

    __half2 local_buffer[totalChunks * storage_values];

    quantize::GroupStats<quantType> stats;

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        __half2* iteration_buffer = local_buffer + i * storage_values;

#pragma unroll
        for (int j = 0; j < storage_values; j++) {
            iteration_buffer[j] = reduce::init<rop::Add, __half2>();
        }

        const int iter_offset = i * stride + base_offset;
        const int iter_scale_idx = iter_offset / elems_per_in_group;
        bool do_loads = i * stride + elem_offset < elems_per_out_group;

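        // numTensors is a compile-time tensor count that lets the inner loop fully
        // unroll; a non-positive value selects the runtime num_tensors loop instead.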
        if (numTensors > 0) {
#pragma unroll
            for (int j = 0; j < numTensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        } else {
#pragma unroll 4
            for (int j = 0; j < num_tensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        }

#pragma unroll
        for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); }
    }

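    // Reduce the accumulated group statistics across the block to produce the
    // quantization parameters for this output group.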
    auto params = stats.template get_params<numBits, 1024>(tb, warp);

    if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }

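    // Requantize the locally reduced values with the new parameters and store them.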
#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        const int iter_offset = i * stride + base_offset;
        if (i * stride + elem_offset < elems_per_out_group) {
            int8_t local_output[elems_per_load];
            quantize::_chunk<numBits, quantType>(
                local_output, local_buffer + i * storage_values, params);
            mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
        }
    }
}

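// Rounds raw_value up to the nearest multiple of 2^Power (e.g. pow2_round<1> rounds up to even).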
template <int Power>
int32_t pow2_round(int32_t raw_value)
{
    return (((raw_value - 1) >> Power) + 1) << Power;
}

#define LAUNCH_DEQUANT_REDUCE(num_chunks)                      \
    dequant_reduce<numBits, numTensors, num_chunks, quantType> \
        <<<grid, block, 0, stream>>>(reduced_data,             \
                                     reduced_scales,           \
                                     input_data,               \
                                     input_scales,             \
                                     elems_per_out_group,      \
                                     elems_per_in_tensor,      \
                                     groups_per_in_tensor,     \
                                     elems_per_in_group,       \
                                     num_tensors);

template <int numBits, int numTensors, quantize::Type quantType>
void launch_dequant_reduce_impl(int8_t* reduced_data,
                                float* reduced_scales,
                                const int8_t* input_data,
                                const float* input_scales,
                                int out_groups,
                                int elems_per_out_group,
                                int elems_per_in_tensor,
                                int groups_per_in_tensor,
                                int elems_per_in_group,
                                int num_tensors,
                                cudaStream_t stream)
{
    // This is a coincidence. This is derived from 8 halves per 16 bytes with 2-way packing for int4
    constexpr int elems_per_thread = numBits;
    const int one_step_threads =
        next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
    // TODO(cmikeh2): Tune this
    const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;

    dim3 block(threads);
    dim3 grid(out_groups);

    const int elems_per_step = threads * elems_per_thread;
    const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;

    const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;

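    // pow2_round<1> guarantees an even unroll factor once unroll_raw >= 4, so only
    // 1-4 and the even values 6, 8, 10 and 12 need kernel instantiations.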
    if (unroll == 1) {
        // 0-4096 elems
        LAUNCH_DEQUANT_REDUCE(1);
    } else if (unroll == 2) {
        // 4097-8192 etc...
        LAUNCH_DEQUANT_REDUCE(2);
    } else if (unroll == 3) {
        LAUNCH_DEQUANT_REDUCE(3);
    } else if (unroll == 4) {
        LAUNCH_DEQUANT_REDUCE(4);
    } else if (unroll == 6) {
        LAUNCH_DEQUANT_REDUCE(6);
    } else if (unroll == 8) {
        LAUNCH_DEQUANT_REDUCE(8);
    } else if (unroll == 10) {
        LAUNCH_DEQUANT_REDUCE(10);
    } else if (unroll == 12) {
        // 48k limit
        LAUNCH_DEQUANT_REDUCE(12);
    } else {
        assert(false);
    }
}

#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE)                   \
    launch_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data,         \
                                                               reduced_scales,       \
                                                               input_data,           \
                                                               input_scales,         \
                                                               out_groups,           \
                                                               elems_per_out_group,  \
                                                               elems_per_in_tensor,  \
                                                               groups_per_in_tensor, \
                                                               elems_per_in_group,   \
                                                               num_gpus,             \
                                                               stream);

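/*
Public entry point: dispatches on quantization type, bit width and GPU count.
num_gpus values of 8 and 16 get fully unrolled specializations; any other count
falls back to the runtime-loop kernel (numTensors = -1).
*/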
void launch_dequant_reduce(int8_t* reduced_data,
                           float* reduced_scales,
                           const int8_t* input_data,
                           const float* input_scales,
                           int num_gpus,
                           int num_bits,
                           quantize::Type quant_type,
                           int out_groups,
                           int elems_per_out_group,
                           int elems_per_in_tensor,
                           int groups_per_in_tensor,
                           int elems_per_in_group,
                           cudaStream_t stream)
{
    if (quant_type == quantize::Type::Symmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
            }
        }
    } else if (quant_type == quantize::Type::Asymmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
            }
        }
    }
}