// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <curand_kernel.h>
#include <cuda_fp16.h>
#include <limits>  // std::numeric_limits, used to initialize top-k sentinels

#include "cub/cub.cuh"
#include "paddle/extension.h"

#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

#define FINAL_MASK 0xFFFFFFFF

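// Dispatch helpers: FIXED_BLOCK_DIM expands into a chain of `case` labels so
// a runtime block size can select a kernel instantiation whose block dim is
// the compile-time constant `kBlockDim`.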
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

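// Maps a paddle::DataType onto the CUDA compute type (DataType) and the
// paddle C++ storage type (data_t) expected by the extension API.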
template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

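// Turns a segment (row) index into its element offset; wrapped in a
// cub::TransformInputIterator below to describe segment bounds for the
// segmented radix sort.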
struct SegmentOffsetIter {
    explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

    __host__ __device__ __forceinline__ int operator()(int idx) const {
        return idx * num_cols_;
    }

    int num_cols_;
};

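// Probability/token-id pair. Ordering: the larger value wins; on equal values
// the smaller id wins, making the comparison a strict total order.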
template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}

  __device__ __forceinline__ void set(T value, int id) {
    v = value;
    this->id = id;  // a plain `id = id;` would be a no-op self-assignment
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  __device__ __forceinline__ bool operator<(const T value) const {
    return ((float)v < (float)value);
  }

  __device__ __forceinline__ bool operator>(const T value) const {
    return ((float)v > (float)value);
  }

  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return ((float)v < (float)in.v) ||
           (((float)v == (float)in.v) && (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return ((float)v > (float)in.v) ||
           (((float)v == (float)in.v) && (id < in.id));
  }

  T v;
  int id;
};

inline int div_up(int a, int n)
{
    return (a + n - 1) / n;
}

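// One curand state per batch row: same seed and subsequence, but a per-row
// offset, so each row reads a different position of the same random stream.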
__global__ void setup_kernel(curandState_t* state, const uint64_t seed, const int bs) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
    curand_init(seed, 0, i, &state[i]);
  }
}

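// Inserts p into the descending-sorted topk array of length beam_size,
// shifting smaller entries down and dropping the smallest.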
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}

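// Strided scan over one row (idx += BlockSize): each thread keeps its private
// top `beam_size` candidates. The second overload also rejects candidates not
// strictly below `max`, so values already emitted by BlockReduce are skipped.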
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}

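// Refills a thread's private top-k list between BlockReduce rounds: shifts out
// the *beam entries already consumed, then rescans the row for the largest
// values still below the last block-wide maximum (*max).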
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              int* beam,
                                              int beam_size,
                                              const T* src,
                                              bool* firstStep,
                                              bool* is_empty,
                                              Pair<T>* max,
                                              int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(std::numeric_limits<T>::min(), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(
            topk + MaxLength - *beam, src, tid, dim, *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

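// Warp-level max-reduction via shfl_down; after the loop, lane 0 holds the
// warp's maximum pair.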
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        T tmp_val = __shfl_down_sync(FINAL_MASK, input.v, static_cast<unsigned>(offset), 32);
        int tmp_id = __shfl_down_sync(FINAL_MASK, input.id, static_cast<unsigned>(offset), 32);
        if ((float)input.v < (float)tmp_val) {
            input.v = tmp_val;
            input.id = tmp_id;
        }
    }
    return input;
}

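// Block-wide top-k loop: reduces every thread's current best (topk[0]) to the
// block maximum, appends it to beam_max (*count entries so far), and lets the
// winning thread advance its private list. Exits once *k maxima have been
// emitted or the winning thread's list is exhausted.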
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            Pair<T> topk[],
                                            Pair<T> beam_max[],
                                            int* beam,
                                            int* k,
                                            int* count,
                                            const int tid,
                                            const int wid,
                                            const int lane) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    input_now = (tid < BlockSize / 32)
                    ? shared_max[lane]
                    : Pair<T>(std::numeric_limits<T>::min(), -1);
    if (wid == 0) {
      input_now = WarpReduce(input_now);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();
    if (tid == 0) {
      beam_max[*count] = shared_max[0];
      (*count)++;
    }
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
    }
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == tid_max) {
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }

    if (MaxLength < 5) {
      if (*beam >= MaxLength) break;
    } else {
      unsigned mask = 0u;
      mask = __ballot_sync(FINAL_MASK, true);
      if (tid_max / 32 == wid) {
        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
            MaxLength)
          break;
      }
    }
  }
}

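// Fast path: computes the TopPBeamTopK largest probabilities per row, then
// thread 0 draws rand_top_p = uniform * top_p. If the running sum of those
// candidates already reaches rand_top_p, the sample is emitted here and
// count_iter_begin[bid] is bumped so the sort/scan pass skips this row.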
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T* src,
                                     T* top_ps,
                                     int64_t* out_id,  // topk id
                                     T* out_val,       // topk val
                                     int vocab_size,
                                     curandState_t* state,
                                     int* count_iter,
                                     int* count_iter_begin) {
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;
    const int bid = blockIdx.x;

    int top_num = TopPBeamTopK;
    float top_p_num = (float)top_ps[bid];

    __shared__ Pair<T> shared_max[BlockSize / 32];
    __shared__ Pair<T> beam_max[TopPBeamTopK];

    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;
    __shared__ int count;

    if (tid == 0) {
        count = 0;
    }

    for (int j = 0; j < MaxLength; j++) {
        topk[j].set(std::numeric_limits<T>::min(), -1);
    }

    while (top_num) {
        ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                               &beam,
                                               TopPBeamTopK,
                                               src + bid * vocab_size,
                                               &firststep,
                                               &is_empty,
                                               &max,
                                               vocab_size,
                                               tid);
        BlockReduce<T, MaxLength, BlockSize>(shared_max,
                                             topk,
                                             beam_max,
                                             &beam,
                                             &top_num,
                                             &count,
                                             tid,
                                             wid,
                                             lane);
    }
    if (tid == 0) {
        count_iter_begin[bid] = count_iter[bid];
        float rand_top_p = curand_uniform(state + bid) * top_p_num;
        top_ps[bid] = (T)rand_top_p;
        float sum_prob = 0.0f;
#pragma unroll
        for (int i = 0; i < TopPBeamTopK; i++) {
            sum_prob += (float)(beam_max[i].v);
            if (sum_prob >= rand_top_p) {
                count_iter_begin[bid] += 1;
                out_id[bid] = (int64_t)beam_max[i].id;
                out_val[bid] = beam_max[i].v;
                break;
            }
        }
    }
}

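// Grid-stride fill: count_iter[i] = i, the row indices used as segment bounds.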
__global__ void SetCountIter(int* count_iter, int num) {
    int tid = threadIdx.x;
    int bid = blockIdx.x;
    int idx = bid * blockDim.x + tid;
    for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
        count_iter[i] = i;
    }
}

template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (T j = row_id; j < num_rows; j += gridDim.x) {
    for (T i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}

struct BlockPrefixCallbackOp {
    // Running prefix
    float running_total;
    // Constructor
    __device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ float operator()(float block_aggregate)
    {
        float old_prefix = running_total;
        running_total += block_aggregate;
        return old_prefix;
    }
};

template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_probs,
                              int64_t* sorted_id,
                              T* out_val,
                              int64_t* out_id,
                              const T* top_ps,
                              int p_num,
                              int vocab_size,
                              int* count_iter,
                              int* count_iter_begin) {
    __shared__ int stop_shared;
    __shared__ float rand_p;
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    constexpr int WARP_SIZE = 32;
    constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;
    const float p_t = (float)top_ps[bid];
    if (tid == 0) {
        stop_shared = 0;
        rand_p = p_t;
    }
    if (count_iter_begin[bid] == count_iter[bid + 1]) {
        // this row was already resolved by the top-k fast path
        return;
    }

    typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;
    __shared__ uint32_t selected_shared[NUM_WARPS];

    // Initialize running total
    BlockPrefixCallbackOp prefix_op(0);

    if (lane_id == 0) {
        selected_shared[warp_id] = 0;
    }
    __syncthreads();

    int offset = bid * vocab_size;
    int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
    int i_activate = 0;
    float thread_offset = 0;
    for (int i = tid; i < end; i += BLOCK_SIZE) {
        float thread_count =
            (i < vocab_size) ? (float)sorted_probs[offset + i] : 0.f;
        BlockScan(temp_storage)
            .InclusiveSum(thread_count, thread_offset, prefix_op);

        uint32_t activate_mask = __ballot_sync(FINAL_MASK, rand_p <= thread_offset);

        i_activate = i;
        if (activate_mask != 0) {
            if (lane_id == 0) {
                atomicAdd(&stop_shared, 1);
                selected_shared[warp_id] = activate_mask;
            }
        }
        __syncthreads();
        if (stop_shared > 0) {
            break;
        }
    }

    bool skip = (selected_shared[warp_id] == 0);
    for (int i = 0; i < warp_id; i++) {
        if (selected_shared[i] != 0) {
            skip = true;  // an earlier warp already crossed the threshold
        }
    }
    if (!skip) {
        // The ballot bits form a suffix of the warp, so the first lane whose
        // cumulative sum crossed rand_p is WARP_SIZE - popcount(mask).
        int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
        if (lane_id == active_lane_id) {
            // printf("active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
            // for (int i = 0; i < active_lane_id; i++) {
            //   printf("p %d, value: %f\n", i, (float)(sorted_probs[offset + i]));
            // }
            out_id[bid] = sorted_id[offset + i_activate];
            out_val[bid] = sorted_probs[offset + i_activate];
        }
    }
}

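// Rounds vocab_size up to the next supported block size (64..1024 threads).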
int GetBlockSize(int vocab_size) {
    if (vocab_size > 512) {
        return 1024;
    } else if (vocab_size > 256) {
        return 512;
    } else if (vocab_size > 128) {
        return 256;
    } else if (vocab_size > 64) {
        return 128;
    } else {
        return 64;
    }
}

template <typename T>
__global__ void print_kernel(T* input, int size) {
  printf("[");
  for (int i = 0; i < size; i++) {
    if (i != size - 1) {
      printf("%f, ", (float)input[i]);
    } else {
      printf("%f]\n", (float)input[i]);
    }
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> top_p_sampling_kernel(const paddle::Tensor& x,
                                                  const paddle::Tensor& top_ps,
                                                  int random_seed) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;
    std::vector<int64_t> shape = x.shape();
    auto cu_stream = x.stream();

    int bs = shape[0];
    int p_num = top_ps.numel();
    PD_CHECK(bs == p_num, "expected bs == p_num.");
    int vocab_size = shape[1];
    auto topp_ids = paddle::full({bs, 1}, 1, paddle::DataType::INT64, x.place());
    auto topp_probs = paddle::full({bs, 1}, 1, x.dtype(), x.place());
    auto inds_input = paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place());
    auto sorted_out = paddle::full({bs, vocab_size}, 1, x.dtype(), x.place());
    auto sorted_id = paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place());

    int BlockSize = GetBlockSize(vocab_size);
    switch (BlockSize) {
        FIXED_BLOCK_DIM(FillIndex<int64_t><<<bs, kBlockDim, 0, cu_stream>>>(
            inds_input.data<int64_t>(), bs, vocab_size));
        default:
            PD_THROW("the input data shape has error in the FillIndex kernel.");
    }

    // curand states are allocated once and reused across calls.
    static int count = 0;
    static curandState_t* dev_curand_states;
    if (count == 0) {
#if CUDA_VERSION >= 11020
      cudaMallocAsync(&dev_curand_states, bs * sizeof(curandState_t), cu_stream);
#else
      cudaMalloc(&dev_curand_states, bs * sizeof(curandState_t));
#endif
    }
    srand((unsigned int)(time(NULL)));
    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand() % random_seed, bs);

    auto count_iter = paddle::empty({bs + 1}, paddle::DataType::INT32, x.place());
    auto count_iter_begin = paddle::empty({bs}, paddle::DataType::INT32, x.place());
    SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);

    constexpr int TopKMaxLength = 1;
    constexpr int TopPBeamTopK = 1;
    switch (BlockSize) {
        FIXED_BLOCK_DIM(
            KeMatrixTopPBeamTopK<DataType_, TopKMaxLength, TopPBeamTopK, kBlockDim><<<bs, kBlockDim, 0, cu_stream>>>(
                reinterpret_cast<DataType_*>(const_cast<data_t*>(x.data<data_t>())),
                reinterpret_cast<DataType_*>(const_cast<data_t*>(top_ps.data<data_t>())),
                topp_ids.data<int64_t>(),
                reinterpret_cast<DataType_*>(topp_probs.data<data_t>()),
                vocab_size,
                dev_curand_states,
                count_iter.data<int>(),
                count_iter_begin.data<int>()));
        default:
            PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
    }
//     if (count % random_seed == random_seed - 1) {
// #if CUDA_VERSION >= 11020
//       cudaFreeAsync(dev_curand_states, cu_stream);
// #else
//       cudaFree(dev_curand_states);
// #endif
//     }
    count++;

    size_t temp_storage_bytes = 0;

    cub::TransformInputIterator<int, SegmentOffsetIter, int*>
        segment_offsets_t_begin(count_iter_begin.data<int>(),
                                SegmentOffsetIter(vocab_size));

    cub::TransformInputIterator<int, SegmentOffsetIter, int*>
        segment_offsets_t_end(count_iter.data<int>(),
                              SegmentOffsetIter(vocab_size));

    DataType_* x_ptr = reinterpret_cast<DataType_*>(const_cast<data_t*>(x.data<data_t>()));
    DataType_* sorted_out_ptr = reinterpret_cast<DataType_*>(const_cast<data_t*>(sorted_out.data<data_t>()));
    int64_t* in_id_ptr = inds_input.data<int64_t>();
    int64_t* out_id_ptr = sorted_id.data<int64_t>();

    // First call with a null workspace pointer only computes temp_storage_bytes.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
                                                       temp_storage_bytes,
                                                       x_ptr,
                                                       sorted_out_ptr,
                                                       in_id_ptr,
                                                       out_id_ptr,
                                                       vocab_size * bs,
                                                       bs,
                                                       segment_offsets_t_begin,
                                                       segment_offsets_t_end + 1,
                                                       0,
                                                       sizeof(data_t) * 8,
                                                       cu_stream);

    temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
    int64_t temp_size = temp_storage_bytes;
    auto temp_storage = paddle::empty({temp_size}, paddle::DataType::UINT8, x.place());

    cub::DeviceSegmentedRadixSort::SortPairsDescending(
        temp_storage.data<uint8_t>(),
        temp_storage_bytes,
        x_ptr,
        sorted_out_ptr,
        in_id_ptr,
        out_id_ptr,
        vocab_size * bs,
        bs,
        segment_offsets_t_begin,
        segment_offsets_t_end + 1,
        0,
        sizeof(data_t) * 8,
        cu_stream);

    switch (BlockSize) {
      FIXED_BLOCK_DIM(
          topp_sampling<DataType_, kBlockDim><<<bs, kBlockDim, 0, cu_stream>>>(
              sorted_out_ptr,
              out_id_ptr,
              reinterpret_cast<DataType_*>(topp_probs.data<data_t>()),
              topp_ids.data<int64_t>(),
              reinterpret_cast<DataType_*>(const_cast<data_t*>(top_ps.data<data_t>())),
              p_num,
              vocab_size,
              count_iter.data<int>(),
              count_iter_begin.data<int>()));
      default:
          PD_THROW("the input data shape has error in the topp_sampling kernel.");
    }
    return {topp_probs, topp_ids};
}

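// Custom-op entry point: dispatches on the runtime dtype of x.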
std::vector<paddle::Tensor> TopPSampling(const paddle::Tensor& x,
                                         const paddle::Tensor& top_ps,
                                         int random_seed) {
    switch (x.type()) {
        case paddle::DataType::FLOAT16: {
            return top_p_sampling_kernel<paddle::DataType::FLOAT16>(
                x,
                top_ps,
                random_seed);
        }
        case paddle::DataType::FLOAT32: {
            return top_p_sampling_kernel<paddle::DataType::FLOAT32>(
                x,
                top_ps,
                random_seed);
        }
        default: {
            PD_THROW(
                "NOT supported data type. "
                "Only float16 and float32 are supported. ");
            break;
        }
    }
}

std::vector<std::vector<int64_t>> TopPSamplingInferShape(const std::vector<int64_t>& x_shape,
                                                         const std::vector<int64_t>& top_ps_shape) {
    std::vector<int64_t> out_probs_shape = {x_shape[0], 1};
    std::vector<int64_t> out_ids_shape = {x_shape[0], 1};
    return {out_probs_shape, out_ids_shape};
}

std::vector<paddle::DataType> TopPSamplingInferDtype(const paddle::DataType& x_dtype,
                                                     const paddle::DataType& top_ps_dtype) {
    return {x_dtype, paddle::DataType::INT64};
}

PD_BUILD_OP(topp_sampling)
    .Inputs({"x", "top_ps"})
    .Outputs({"topp_probs", "topp_ids"})
    .Attrs({"random_seed: int"})
    .SetKernelFn(PD_KERNEL(TopPSampling))
    .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingInferDtype));