onnxruntime
100 lines · 3.5 KB
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..5ad610c8 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -189,6 +189,7 @@ struct AttentionKernel {
 
     // Scale
     accum_t scale = 0.0;
+    accum_t softcap = 0.0;
 
     // Dimensions/strides
     int32_t head_dim = 0;
@@ -221,6 +222,8 @@ struct AttentionKernel {
     int32_t num_batches = 0;
     int32_t num_heads = 0;
 
+    bool use_smooth_softmax = false;
+
     // dropout
     bool use_dropout = false;
     unsigned long long dropout_batch_head_rng_offset = 0;
@@ -818,6 +821,15 @@ struct AttentionKernel {
         accum =
             cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
       }
+
+      // apply softcap if applicable
+      if (p.softcap > 0.0) {
+        accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(1.0 / p.softcap, accum);
+        for (int i = 0; i < accum.size(); ++i) {
+          accum[i] = cutlass::fast_tanh(accum[i]);
+        }
+        accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.softcap, accum);
+      }
 
       // apply attention bias if applicable
       if (kSupportsBias && p.attn_bias_ptr != nullptr) {
@@ -897,7 +909,8 @@ struct AttentionKernel {
           p.num_keys - iter_key_start,
           iter_key_start == 0,
           iteratorC_tile_offset,
-          kSupportsBias ? 1.0f : p.scale);
+          kSupportsBias ? 1.0f : p.scale,
+          p.use_smooth_softmax);
 
       // Output results to shared-memory
       int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1179,8 @@ struct AttentionKernel {
       int max_col,
       bool is_first,
       typename WarpIteratorC::TensorCoord const& tile_offset,
-      float scaling) {
+      float scaling,
+      bool use_smooth_softmax) {
     /* Iterates on the accumulator and corresponding position on result matrix
 
     (1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1271,7 @@ struct AttentionKernel {
       accum_t mi_row, total_row;
       LambdaIterator::iterateRows(
           lane_offset,
-          [&](int accum_m) { mi_row = mi[accum_m]; },
+          [&](int accum_m) { mi_row = mi[accum_m];},
           [&](int accum_m, int accum_n, int idx) {
             frag[idx] =
                 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1308,7 @@ struct AttentionKernel {
         for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
           total_row += addition_storage[id + kQueriesPerBlock * i];
         }
-        s_prime[id] = total_row;
+        s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
       }
     }
 
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..676ba768 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
 #include "cutlass/numeric_types.h"
 
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
 
 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
 #include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
   CUTLASS_HOST_DEVICE
   half_t operator()(half_t const &lhs) const {
 #if defined(__CUDA_ARCH__)
+#if (__CUDA_ARCH__ >= 530)
     auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
     return reinterpret_cast<half_t const &>(result);
+#else
+    return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
 #else
     return half_t(1.f / std::sqrt(half_t::convert(lhs)));
 #endif