onnxruntime

Форк
0
/
cutlass_3.5.0.patch 
100 строк · 3.5 Кб
1
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
2
index 4c80f549..5ad610c8 100644
3
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
4
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
5
@@ -189,6 +189,7 @@ struct AttentionKernel {
6

7
     // Scale
8
     accum_t scale = 0.0;
9
+    accum_t softcap = 0.0;
10

11
     // Dimensions/strides
12
     int32_t head_dim = 0;
13
@@ -221,6 +222,8 @@ struct AttentionKernel {
14
     int32_t num_batches = 0;
15
     int32_t num_heads = 0;
16

17
+    bool use_smooth_softmax = false;
18
+
19
     // dropout
20
     bool use_dropout = false;
21
     unsigned long long dropout_batch_head_rng_offset = 0;
22
@@ -818,6 +821,15 @@ struct AttentionKernel {
23
         accum =
24
             cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
25
       }
26
+
27
+      // apply softcap if applicable
28
+      if (p.softcap > 0.0) {
29
+        accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(1.0 / p.softcap, accum);
30
+        for (int i = 0; i < accum.size(); ++i) {
31
+          accum[i] = cutlass::fast_tanh(accum[i]);
32
+        }
33
+        accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.softcap, accum);
34
+      }
35

36
       // apply attention bias if applicable
37
       if (kSupportsBias && p.attn_bias_ptr != nullptr) {
38
@@ -897,7 +909,8 @@ struct AttentionKernel {
39
           p.num_keys - iter_key_start,
40
           iter_key_start == 0,
41
           iteratorC_tile_offset,
42
-          kSupportsBias ? 1.0f : p.scale);
43
+          kSupportsBias ? 1.0f : p.scale,
44
+          p.use_smooth_softmax);
45

46
       // Output results to shared-memory
47
       int warp_idx_mn_0 = my_warp_id %
48
@@ -1166,7 +1179,8 @@ struct AttentionKernel {
49
       int max_col,
50
       bool is_first,
51
       typename WarpIteratorC::TensorCoord const& tile_offset,
52
-      float scaling) {
53
+      float scaling,
54
+      bool use_smooth_softmax) {
55
     /* Iterates on the accumulator and corresponding position on result matrix
56

57
     (1) Update `mi[r]` to the max value of the row `r`
58
@@ -1257,7 +1271,7 @@ struct AttentionKernel {
59
       accum_t mi_row, total_row;
60
       LambdaIterator::iterateRows(
61
           lane_offset,
62
-          [&](int accum_m) { mi_row = mi[accum_m]; },
63
+          [&](int accum_m) { mi_row = mi[accum_m];},
64
           [&](int accum_m, int accum_n, int idx) {
65
             frag[idx] =
66
                 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
67
@@ -1294,7 +1308,7 @@ struct AttentionKernel {
68
       for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
69
         total_row += addition_storage[id + kQueriesPerBlock * i];
70
       }
71
-      s_prime[id] = total_row;
72
+      s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
73
     }
74
   }
75

76
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
77
index 964d2ff3..676ba768 100644
78
--- a/include/cutlass/functional.h
79
+++ b/include/cutlass/functional.h
80
@@ -39,6 +39,7 @@
81
 #include "cutlass/numeric_types.h"
82

83
 #include <cuda_runtime.h>
84
+#include <cuda_fp16.h>
85

86
 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
87
 #include <mma.h>
88
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
89
   CUTLASS_HOST_DEVICE
90
   half_t operator()(half_t const &lhs) const {
91
 #if defined(__CUDA_ARCH__)
92
+#if (__CUDA_ARCH__ >= 530)
93
     auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
94
     return reinterpret_cast<half_t const &>(result);
95
+#else
96
+    return half_t::convert((rsqrtf(half_t::convert(lhs))));
97
+#endif
98
 #else
99
     return half_t(1.f / std::sqrt(half_t::convert(lhs)));
100
 #endif
101

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.