// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <curand_kernel.h>

#include <cstdlib>
#include <ctime>
#include <limits>

#include <cub/cub.cuh>

#include "paddle/extension.h"
#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

#define FINAL_MASK 0xFFFFFFFF
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
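
// These macros expand to the `case` arms of a switch over the runtime block
// size, binding the matching compile-time constant kBlockDim, as the host
// code below does:
//
//   switch (BlockSize) {
//     FIXED_BLOCK_DIM(
//         FillIndex<int64_t><<<bs, kBlockDim, 0, cu_stream>>>(...));
//     default:
//       PD_THROW("...");
//   }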
// Maps a paddle::DataType onto the CUDA compute type (DataType) and the
// host-facing paddle element type (data_t).
template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
 public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
 public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};
// Maps a segment index to its starting offset in the flattened
// [num_rows, num_cols] matrix; used via cub::TransformInputIterator to build
// the begin/end offsets for the segmented sort.
struct SegmentOffsetIter {
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  __host__ __device__ __forceinline__ int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};
// A (value, index) pair whose ordering breaks ties toward the smaller index,
// so top-k results are deterministic for equal probabilities.
template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}

  __device__ __forceinline__ void set(T value, int id) {
    this->v = value;
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  __device__ __forceinline__ bool operator<(const T value) const {
    return ((float)v < (float)value);
  }

  __device__ __forceinline__ bool operator>(const T value) const {
    return ((float)v > (float)value);
  }

  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return ((float)v < (float)in.v) ||
           (((float)v == (float)in.v) && (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return ((float)v > (float)in.v) ||
           (((float)v == (float)in.v) && (id < in.id));
  }

  T v;
  int id;
};
inline int div_up(int a, int n) {
  return (a + n - 1) / n;
}
// Initializes one curand state per batch row; each row later draws its own
// uniform sample.
__global__ void setup_kernel(curandState_t* state,
                             const uint64_t seed,
                             const int bs) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
    curand_init(seed, 0, i, &state[i]);
  }
}
// Inserts p into the descending-sorted beam of length beam_size, shifting
// smaller entries down.
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}
// Each thread scans its strided slice of src and maintains a private top-k
// beam.
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
    Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}
// Same as above, but only considers candidates strictly smaller than `max`,
// i.e. elements not yet emitted in a previous reduction round.
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}
// Refills the exhausted part of a thread-local beam. On the first step the
// whole beam is filled; afterwards surviving entries are shifted up and the
// freed slots are refilled with elements smaller than the last emitted max.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              int* beam,
                                              int beam_size,
                                              const T* src,
                                              bool* firststep,
                                              bool* is_empty,
                                              Pair<T>* max,
                                              int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firststep) {
      *firststep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(std::numeric_limits<T>::min(), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(
            topk + MaxLength - *beam, src, tid, dim, *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}
// Warp-wide argmax reduction via shuffle-down; the result is valid in lane 0.
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    T tmp_val = __shfl_down_sync(
        FINAL_MASK, input.v, static_cast<unsigned>(offset), 32);
    int tmp_id = __shfl_down_sync(
        FINAL_MASK, input.id, static_cast<unsigned>(offset), 32);
    if ((float)input.v < (float)tmp_val) {
      input.v = tmp_val;
      input.id = tmp_id;
    }
  }
  return input;
}
// Block-wide reduction that repeatedly extracts the current maximum from the
// per-thread beams into beam_max, until *k maxima have been collected or the
// owning thread exhausts its beam.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            Pair<T> topk[],
                                            Pair<T> beam_max[],
                                            int* beam,
                                            int* k,
                                            int* count,
                                            const int tid,
                                            const int wid,
                                            const int lane) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    input_now = (tid < BlockSize / 32)
                    ? shared_max[lane]
                    : Pair<T>(std::numeric_limits<T>::min(), -1);
    if (wid == 0) {
      input_now = WarpReduce(input_now);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();
    if (tid == 0) {
      beam_max[*count] = shared_max[0];
      (*count)++;
    }
    // The thread that owned the emitted maximum advances its beam cursor.
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
    }
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == tid_max) {
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }
    // Stop once the owning thread has run out of beam entries.
    if (MaxLength < 5) {
      if (*beam >= MaxLength) break;
    } else {
      unsigned mask = 0u;
      mask = __ballot_sync(FINAL_MASK, true);
      if (tid_max / 32 == wid) {
        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
            MaxLength) {
          break;
        }
      }
    }
  }
}
// Stage-1 kernel: each block computes the TopPBeamTopK largest probabilities
// of one row. Thread 0 then draws rand_top_p in [0, top_p) and, if the top-k
// prefix mass already reaches it, samples directly and marks the row as done
// via count_iter_begin. The sampled threshold is written back to top_ps for
// the scan kernel below.
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T* src,
                                     T* top_ps,
                                     int64_t* out_id,  // topk id
                                     T* out_val,       // topk val
                                     int vocab_size,
                                     curandState_t* state,
                                     int* count_iter,
                                     int* count_iter_begin) {
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane = tid % 32;
  const int bid = blockIdx.x;

  int top_num = TopPBeamTopK;
  float top_p_num = (float)top_ps[bid];

  __shared__ Pair<T> shared_max[BlockSize / 32];
  __shared__ Pair<T> beam_max[TopPBeamTopK];

  Pair<T> topk[MaxLength];
  int beam = MaxLength;
  Pair<T> max;
  bool is_empty = false;
  bool firststep = true;
  __shared__ int count;

  if (tid == 0) {
    count = 0;
  }

  for (int j = 0; j < MaxLength; j++) {
    topk[j].set(std::numeric_limits<T>::min(), -1);
  }

  while (top_num) {
    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                           &beam,
                                           TopPBeamTopK,
                                           src + bid * vocab_size,
                                           &firststep,
                                           &is_empty,
                                           &max,
                                           vocab_size,
                                           tid);
    BlockReduce<T, MaxLength, BlockSize>(
        shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
  }

  if (tid == 0) {
    count_iter_begin[bid] = count_iter[bid];
    float rand_top_p = curand_uniform(state + bid) * top_p_num;
    top_ps[bid] = (T)rand_top_p;
    float sum_prob = 0.0f;

    for (int i = 0; i < TopPBeamTopK; i++) {
      sum_prob += (float)(beam_max[i].v);
      if (sum_prob >= rand_top_p) {
        count_iter_begin[bid] += 1;
        out_id[bid] = (int64_t)beam_max[i].id;
        out_val[bid] = beam_max[i].v;
        break;
      }
    }
  }
}
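
// Note on count_iter / count_iter_begin: SetCountIter initializes
// count_iter[i] = i, and stage 1 sets count_iter_begin[bid] = bid, bumping it
// to bid + 1 when the row was already sampled from its top-k prefix. Through
// SegmentOffsetIter these become the begin/end offsets of each sort segment,
// so an already-resolved row yields an empty segment and is skipped by both
// the segmented sort and the scan kernel below.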
__global__ void SetCountIter(int* count_iter, int num) {
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int idx = bid * blockDim.x + tid;
  for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
    count_iter[i] = i;
  }
}
// Writes 0..num_cols-1 into every row of `indices`.
template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (T j = row_id; j < num_rows; j += gridDim.x) {
    for (T i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}
// Stateful prefix functor for cub::BlockScan.
struct BlockPrefixCallbackOp {
  float running_total;

  __device__ BlockPrefixCallbackOp(float running_total)
      : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // Thread-0 is responsible for returning a value for seeding the block-wide
  // scan.
  __device__ float operator()(float block_aggregate) {
    float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};
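
// cub::BlockScan invokes this functor once per tile (from the first warp,
// applying lane 0's return value), so returning the old running total chains
// consecutive InclusiveSum calls into a single scan across the whole row,
// BLOCK_SIZE elements at a time.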
// Stage-3 kernel: one block per row. Runs a block-wide inclusive prefix sum
// over the sorted probabilities, tile by tile, and emits the first token
// whose cumulative probability reaches the sampled threshold rand_p.
template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_probs,
                              int64_t* sorted_id,
                              T* out_val,
                              int64_t* out_id,
                              const T* top_ps,
                              int p_num,
                              int vocab_size,
                              int* count_iter,
                              int* count_iter_begin) {
  __shared__ int stop_shared;
  __shared__ float rand_p;
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  constexpr int WARP_SIZE = 32;
  constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
  const int lane_id = tid % WARP_SIZE;
  const int warp_id = tid / WARP_SIZE;
  const float p_t = (float)top_ps[bid];
  if (tid == 0) {
    stop_shared = 0;
    rand_p = p_t;
  }
  // Row already resolved by the top-k prefilter: its segment is empty.
  if (count_iter_begin[bid] == count_iter[bid + 1]) {
    return;
  }

  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  __shared__ uint32_t selected_shared[NUM_WARPS];

  // Initialize running total
  BlockPrefixCallbackOp prefix_op(0);

  if (lane_id == 0) {
    selected_shared[warp_id] = 0;
  }
  __syncthreads();

  int offset = bid * vocab_size;
  int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int i_activate = 0;
  float thread_offset = 0;
  for (int i = tid; i < end; i += BLOCK_SIZE) {
    float thread_count =
        (i < vocab_size) ? (float)sorted_probs[offset + i] : 0.f;
    BlockScan(temp_storage)
        .InclusiveSum(thread_count, thread_offset, prefix_op);

    uint32_t activate_mask = __ballot_sync(FINAL_MASK, rand_p <= thread_offset);

    i_activate = i;
    if (activate_mask != 0) {
      if (lane_id == 0) {
        atomicAdd(&stop_shared, 1);
        selected_shared[warp_id] = activate_mask;
      }
    }
    __syncthreads();
    if (stop_shared > 0) {
      break;
    }
  }
  __syncthreads();

  // Only the first warp that crossed the threshold reports a result.
  bool skip = (selected_shared[warp_id] > 0) ? false : true;
  for (int i = 0; i < warp_id; i++) {
    if (selected_shared[i] != 0) {
      // A previous warp already crossed the threshold; skip this one.
      skip = true;
    }
  }
  if (!skip) {
    int active_lane_id =
        WARP_SIZE - __popc(selected_shared[warp_id]);  // first set lane
    if (lane_id == active_lane_id) {
      out_id[bid] = sorted_id[offset + i_activate];
      out_val[bid] = sorted_probs[offset + i_activate];
    }
  }
}
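
// Recovering the winning lane above: __ballot_sync sets bit l iff lane l's
// inclusive prefix sum reached rand_p. Prefix sums are non-decreasing with
// lane id, so the set bits form a contiguous high run and the first
// qualifying lane is WARP_SIZE - __popc(mask).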
int GetBlockSize(int vocab_size) {
  if (vocab_size > 512) {
    return 1024;
  } else if (vocab_size > 256) {
    return 512;
  } else if (vocab_size > 128) {
    return 256;
  } else if (vocab_size > 64) {
    return 128;
  } else {
    return 64;
  }
}
// Debug helper: prints a device array as a comma-separated list.
template <typename T>
__global__ void print_kernel(T* input, int size) {
  for (int i = 0; i < size; i++) {
    if (i != size - 1) {
      printf("%f, ", (float)input[i]);
    } else {
      printf("%f]\n", (float)input[i]);
    }
  }
}
template <paddle::DataType D>
std::vector<paddle::Tensor> top_p_sampling_kernel(const paddle::Tensor& x,
                                                  const paddle::Tensor& top_ps,
                                                  int random_seed) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  std::vector<int64_t> shape = x.shape();
  auto cu_stream = x.stream();

  int bs = shape[0];
  int p_num = top_ps.numel();
  PD_CHECK(bs == p_num, "PD_CHECK returns ", false, ", expected bs == p_num.");
  int vocab_size = shape[1];

  auto topp_ids = paddle::full({bs, 1}, 1, paddle::DataType::INT64, x.place());
  auto topp_probs = paddle::full({bs, 1}, 1, x.dtype(), x.place());
  auto inds_input =
      paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place());
  auto sorted_out = paddle::full({bs, vocab_size}, 1, x.dtype(), x.place());
  auto sorted_id =
      paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place());

  int BlockSize = GetBlockSize(vocab_size);
  // Fill each row of inds_input with 0..vocab_size-1; these values are
  // carried through the segmented sort to recover original token ids.
  switch (BlockSize) {
    FIXED_BLOCK_DIM(FillIndex<int64_t><<<bs, kBlockDim, 0, cu_stream>>>(
        inds_input.data<int64_t>(), bs, vocab_size));
    default:
      PD_THROW("the input data shape has error in the FillIndex kernel.");
  }
  // Lazily allocate one curand state per row on the first call and seed it;
  // the states persist across calls via the static pointer. The counter also
  // drives the (currently disabled) periodic free below.
  static int count = 0;
  static curandState_t* dev_curand_states;
  if (count == 0) {
#if CUDA_VERSION >= 11020
    cudaMallocAsync(&dev_curand_states, bs * sizeof(curandState_t), cu_stream);
#else
    cudaMalloc(&dev_curand_states, bs * sizeof(curandState_t));
#endif
    srand((unsigned int)(time(NULL)));
    setup_kernel<<<1, 256, 0, cu_stream>>>(
        dev_curand_states, rand() % random_seed, bs);
  }
  count = (count + 1) % random_seed;
  auto count_iter = paddle::empty({bs + 1}, paddle::DataType::INT32, x.place());
  auto count_iter_begin =
      paddle::empty({bs}, paddle::DataType::INT32, x.place());
  SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);

  constexpr int TopKMaxLength = 1;
  constexpr int TopPBeamTopK = 1;
  // Stage 1: per-row top-k prefilter; rows whose top-k mass already covers
  // the sampled threshold are resolved here and skipped by the sort below.
  switch (BlockSize) {
    FIXED_BLOCK_DIM(
        KeMatrixTopPBeamTopK<DataType_, TopKMaxLength, TopPBeamTopK, kBlockDim>
        <<<bs, kBlockDim, 0, cu_stream>>>(
            reinterpret_cast<DataType_*>(const_cast<data_t*>(x.data<data_t>())),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(top_ps.data<data_t>())),
            topp_ids.data<int64_t>(),
            reinterpret_cast<DataType_*>(topp_probs.data<data_t>()),
            vocab_size,
            dev_curand_states,
            count_iter.data<int>(),
            count_iter_begin.data<int>()));
    default:
      PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
  }
  // Periodically free the curand states (currently disabled):
  // if (count % random_seed == random_seed - 1) {
  // #if CUDA_VERSION >= 11020
  //   cudaFreeAsync(dev_curand_states, cu_stream);
  // #else
  //   cudaFree(dev_curand_states);
  // #endif
  // }
  // Stage 2: segmented descending sort of each remaining row; keys are
  // probabilities, values are token ids. cub is queried with a null buffer
  // first to size the temporary storage, then called for real.
  size_t temp_storage_bytes = 0;

  cub::TransformInputIterator<int, SegmentOffsetIter, int*>
      segment_offsets_t_begin(count_iter_begin.data<int>(),
                              SegmentOffsetIter(vocab_size));

  cub::TransformInputIterator<int, SegmentOffsetIter, int*>
      segment_offsets_t_end(count_iter.data<int>(),
                            SegmentOffsetIter(vocab_size));

  DataType_* x_ptr =
      reinterpret_cast<DataType_*>(const_cast<data_t*>(x.data<data_t>()));
  DataType_* sorted_out_ptr = reinterpret_cast<DataType_*>(
      const_cast<data_t*>(sorted_out.data<data_t>()));
  int64_t* in_id_ptr = inds_input.data<int64_t>();
  int64_t* out_id_ptr = sorted_id.data<int64_t>();

  cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
                                                     temp_storage_bytes,
                                                     x_ptr,
                                                     sorted_out_ptr,
                                                     in_id_ptr,
                                                     out_id_ptr,
                                                     bs * vocab_size,
                                                     bs,
                                                     segment_offsets_t_begin,
                                                     segment_offsets_t_end + 1,
                                                     0,
                                                     sizeof(DataType_) * 8,
                                                     cu_stream);

  temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
  int64_t temp_size = temp_storage_bytes;
  auto temp_storage =
      paddle::empty({temp_size}, paddle::DataType::UINT8, x.place());

  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      temp_storage.data<uint8_t>(),
      temp_storage_bytes,
      x_ptr,
      sorted_out_ptr,
      in_id_ptr,
      out_id_ptr,
      bs * vocab_size,
      bs,
      segment_offsets_t_begin,
      segment_offsets_t_end + 1,
      0,
      sizeof(DataType_) * 8,
      cu_stream);
  // Stage 3: prefix-sum scan over each sorted row; the first position whose
  // cumulative probability reaches the sampled threshold is emitted.
  switch (BlockSize) {
    FIXED_BLOCK_DIM(
        topp_sampling<DataType_, kBlockDim><<<bs, kBlockDim, 0, cu_stream>>>(
            sorted_out_ptr,
            out_id_ptr,
            reinterpret_cast<DataType_*>(topp_probs.data<data_t>()),
            topp_ids.data<int64_t>(),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(top_ps.data<data_t>())),
            p_num,
            vocab_size,
            count_iter.data<int>(),
            count_iter_begin.data<int>()));
    default:
      PD_THROW("the input data shape has error in the topp_sampling kernel.");
  }

  return {topp_probs, topp_ids};
}
std::vector<paddle::Tensor> TopPSampling(const paddle::Tensor& x,
                                         const paddle::Tensor& top_ps,
                                         int random_seed) {
  switch (x.type()) {
    case paddle::DataType::FLOAT16: {
      return top_p_sampling_kernel<paddle::DataType::FLOAT16>(
          x, top_ps, random_seed);
    }
    case paddle::DataType::FLOAT32: {
      return top_p_sampling_kernel<paddle::DataType::FLOAT32>(
          x, top_ps, random_seed);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and float32 are supported. ");
      break;
    }
  }
}
std::vector<std::vector<int64_t>> TopPSamplingInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& top_ps_shape) {
  std::vector<int64_t> out_probs_shape = {x_shape[0], 1};
  std::vector<int64_t> out_ids_shape = {x_shape[0], 1};
  return {out_probs_shape, out_ids_shape};
}

std::vector<paddle::DataType> TopPSamplingInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::DataType& top_ps_dtype) {
  return {x_dtype, paddle::DataType::INT64};
}
PD_BUILD_OP(topp_sampling)
    .Inputs({"x", "top_ps"})
    .Outputs({"topp_probs", "topp_ids"})
    .Attrs({"random_seed: int"})
    .SetKernelFn(PD_KERNEL(TopPSampling))
    .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingInferDtype));
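
// A minimal usage sketch (module and file names are hypothetical), assuming
// this file is compiled as a Paddle custom op via paddle.utils.cpp_extension:
//
//   from paddle.utils.cpp_extension import load
//   ops = load(name="topp_sampling_ops", sources=["topp_sampling.cu"])
//   # x: [bs, vocab_size] probabilities, top_ps: [bs] per-row thresholds
//   probs, ids = ops.topp_sampling(x, top_ps, 100)  # random_seed = 100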