// fp16_fma.cc — emulated half-precision (fp16) fused multiply-add.
// (GitVerse page-scrape header removed; it was not part of the source.)
#include "fp16_fma.h"
#include <immintrin.h>
#include <cmath>
#include <cstdint>

namespace fake_fp16 {
7

8
// Compute fp16 FMA using fp16
9
// Out = FMA (A, B, Out)
10
//
11
// Algorithm:
12
//  Do an FMA in fp64
13
//  Since fp16 has 10 bits of mantissa and fp64 has 52, zero out
14
//   42 bits.
15
//  Extract the exponent.
16
//  If the exponent ends up in the subnormal range, shift out
17
//  only 42 - (14 + exponent).
18
//  Compute the bounce value as a value that is big enough to
19
//  push all the digits except for the required ones in fp16,
20
//  the objective is to push digits to let the machine do rounding.
21
//  Add 42 or the computed number (in case of denormals) to the exponent.
22
//  For negative numbers set the highest bit of the mantissa to 1.
23
void fma_fp16(int N, const float* A, const float* B, float* Out) {
24
  constexpr int blockSize = 4;
25
  constexpr uint64_t mask = 0x7ff0000000000000;
26
  constexpr uint64_t shift_bits = 52;
27
  constexpr uint64_t offset = 1023;
28
  constexpr uint64_t dbl_threehalf = 0x3ff8000000000000;
29

30
  uint64_t expo_bouncer;
31

32
  // It can be proven than in the absence of intermediate overflow
33
  // the desired numerical result can be obtained even with the
34
  // possibility of a double rounding, as follow.
35
  //    round-to-fp16-precision(   (double)A * (double)B + (double)C  )
36
  // This statement is not proved here; but we explain how to round a fp64
37
  // number into fp16 precision using the technique of a "Bouncer"
38
  // Suppose a numerical value in fp64 has exponent value of E
39
  // If -14 <= E <= 15 (the fp16 exponent value for normalized number),
40
  // the lsb of this value in fp16 precision is 2^(E-10).
41
  // Now consider this fp64 number Bouncer which is 2^(52+(E-10)) * 3/2
42
  // The lsb of Bouncer is (by design) 2^(E-10). Because Bouncer is
43
  // is very much bigger than the fp16 value, denoted by say x,
44
  //          2^(52+(E-10)) < Bouncer + x < 2^(53+(E-10))
45
  // Thus TMP := Bouncer + x  in double precision forces x to be rounded off
46
  // at the lsb position of 2^(E-10).
47
  // Consequently, the subtraction yields the desired result
48
  //          x_fp16_precision := TMP - Bouncer;
49
  // If E < -14, we are dealing with the subnormal number range, there the lsb
50
  // of fp16 precision is FIXED at 2^(-24) (definition of fp16).
51
  // Hence the Bouncer is set at 2^(52-24) = 2^(28)
52

53
  int n = 0;
54
  for (; n + blockSize < N; n += blockSize) {
55
    __m256d mA = _mm256_cvtps_pd(_mm_loadu_ps(A + n));
56
    __m256d mB = _mm256_cvtps_pd(_mm_loadu_ps(B + n));
57
    __m256d mOut = _mm256_cvtps_pd(_mm_loadu_ps(Out + n));
58

59
    mOut = _mm256_fmadd_pd(mA, mB, mOut);
60

61
    __m256i mExpv =
62
        _mm256_and_si256(_mm256_castpd_si256(mOut), _mm256_set1_epi64x(mask));
63
    mExpv = _mm256_srli_epi64(mExpv, shift_bits);
64
    mExpv = _mm256_sub_epi64(mExpv, _mm256_set1_epi64x(offset));
65

66
    __m256i cmp = _mm256_cmpgt_epi64(_mm256_set1_epi64x(-14), mExpv);
67

68
    __m256i mExpoBouncer = _mm256_and_si256(cmp, _mm256_set1_epi64x(28));
69
    mExpoBouncer = _mm256_or_si256(
70
        mExpoBouncer,
71
        _mm256_andnot_si256(
72
            cmp, _mm256_add_epi64(_mm256_set1_epi64x(42), mExpv)));
73

74
    __m256i mBouncer = _mm256_add_epi64(
75
        _mm256_set1_epi64x(dbl_threehalf),
76
        _mm256_slli_epi64(mExpoBouncer, shift_bits));
77

78
    mOut = _mm256_sub_pd(
79
        _mm256_add_pd(_mm256_castsi256_pd(mBouncer), mOut),
80
        _mm256_castsi256_pd(mBouncer));
81

82
    _mm_storeu_ps(Out + n, _mm256_cvtpd_ps(mOut));
83
  }
84
  // Epilogue
85
  for (; n < N; n++) {
86
    typedef union {
87
      uint64_t I;
88
      double F;
89
    } flint64;
90

91
    flint64 A_, B_, Out_, Bouncer;
92
    A_.F = A[n];
93
    B_.F = B[n];
94
    Out_.F = Out[n];
95

96
    // This is FMA in FP64
97
    Out_.F = std::fma(A_.F, B_.F, Out_.F);
98

99
    // We now round Out_.F to fp16 precision using a Bouncer
100

101
    // First, figure out the exponent value E of Out_.F
102
    int64_t expv = ((Out_.I & mask) >> shift_bits) - offset;
103

104
    // Second: create the Bouncer. To do that, we
105
    // first compute its exponent and then add that exponent value
106
    // to the exponent field of the constant 3/2.
107
    if (expv < -14) {
108
      expo_bouncer = 28;
109
    } else {
110
      expo_bouncer = 42 + expv;
111
    }
112
    Bouncer.I = dbl_threehalf + (expo_bouncer << shift_bits);
113

114
    // This is rounding to fp16 precision; add and subtract Bouncer
115
    Out_.F = (Bouncer.F + Out_.F) - Bouncer.F;
116
    Out[n] = Out_.F;
117
  }
118
}
119

120
float fmafp32_avx_emulation(float v1, float v2, float v3) {
121
  __m256 v1Vec = _mm256_set1_ps(v1);
122
  __m256 v2Vec = _mm256_set1_ps(v2);
123
  __m256 v3Vec = _mm256_set1_ps(v3);
124
  __m256 resVec = _mm256_fmadd_ps(v1Vec, v2Vec, v3Vec);
125
  float *result = (float *)&resVec;
126
  return *result;
127
}
128

129
} // namespace fake_fp16
// (GitVerse cookie-banner text removed; it was page-scrape residue, not source.)