pytorch
1#include "fp16_fma.h"2#include <immintrin.h>3#include <cmath>4#include <cstdint>5
6namespace fake_fp16 {7
// Rounds the fp64 value `x` to fp16 precision with the "bouncer" technique
// described in the comment on fma_fp16 below: 11 significant bits, or a fixed
// lsb of 2^-24 when the exponent is below the smallest fp16 normal exponent.
// Zeros, infinities and NaNs are returned unchanged, and a value that rounds
// to zero keeps the sign of x.
static double round_double_to_fp16_precision(double x) {
  if (x == 0.0 || !std::isfinite(x)) {
    return x;
  }
  // Unbiased binary exponent E of x, i.e. 2^E <= |x| < 2^(E+1).
  const int exponent = std::ilogb(x);
  // fp16 lsb is 2^(E-10) for E >= -14 and fixed at 2^-24 below that; putting
  // it at the last fp64 mantissa bit needs 2^(52 + lsb_exponent).
  const int bouncer_exponent = (exponent < -14) ? 28 : 42 + exponent;
  const double bouncer = std::ldexp(1.5, bouncer_exponent);
  // The addition rounds x at the fp16 lsb; the subtraction is exact.
  // copysign restores -0, which the add/subtract pair would turn into +0.
  return std::copysign((bouncer + x) - bouncer, x);
}

// Emulates an fp16 fused multiply-add on fp32 storage, element-wise:
//
//   Out[i] = round_to_fp16_precision(fma(A[i], B[i], Out[i])),  0 <= i < N
//
// A, B and Out must each point to at least N floats; Out is both the addend
// and the destination. N <= 0 is a no-op.
//
// Algorithm:
//   1. Compute the FMA in fp64 (exact product, a single rounding).
//   2. Round that fp64 value to fp16 precision with a "bouncer" and store it
//      back as a float.
// The original authors state that, in the absence of intermediate overflow,
// rounding first to fp64 and then to fp16 precision yields the desired fp16
// result despite the double rounding; that claim is not proven here.
//
// Bouncer rounding: let E be the exponent of the fp64 value x, so that
// 2^E <= |x| < 2^(E+1). For E >= -14 (fp16 normal exponents) the lsb of x in
// fp16 precision is 2^(E-10); for E < -14 (fp16 subnormal range) it is fixed
// at 2^-24. Bouncer = 1.5 * 2^(52 + lsb_exponent) is an fp64 number whose own
// lsb is exactly that fp16 lsb. Since |x| is tiny compared to Bouncer,
// Bouncer + x stays inside [2^(52 + lsb_exponent), 2^(53 + lsb_exponent)) for
// either sign of x (this is why the factor is 3/2 rather than 1), so the fp64
// addition rounds x at the fp16 lsb and (Bouncer + x) - Bouncer recovers the
// rounded value. The bouncer exponent is thus 42 + E for E >= -14, else 28.
//
// The AVX2/FMA path is compiled only when the target enables those
// instruction sets; otherwise (and for the tail) the scalar loop is used. Both
// paths perform the same arithmetic.
//
// NOTE: only precision is reduced. The fp16 exponent range is not enforced:
// results above the fp16 maximum (65504) are not turned into infinity; they
// are rounded to 11 significant bits (E > 15 uses the same 42 + E bouncer).
// NOTE: the trick needs strict IEEE fp64 arithmetic in round-to-nearest mode;
// value-changing optimizations such as -ffast-math could fold it away.
void fma_fp16(int N, const float* A, const float* B, float* Out) {
  int n = 0;

#if (defined(__AVX2__) && defined(__FMA__)) || \
    (defined(_MSC_VER) && defined(__AVX2__))
  constexpr int kBlockSize = 4;
  const __m256i exponent_field = _mm256_set1_epi64x(0x7ff0000000000000LL);
  const __m256i exponent_bias = _mm256_set1_epi64x(1023);
  const __m256i min_normal_exponent = _mm256_set1_epi64x(-14);
  const __m256i subnormal_bouncer_exponent = _mm256_set1_epi64x(28);
  const __m256i normal_bouncer_offset = _mm256_set1_epi64x(42);
  const __m256i three_halves_bits = _mm256_set1_epi64x(0x3ff8000000000000LL);
  const __m256d sign_bit = _mm256_set1_pd(-0.0);

  // `N - n` cannot overflow for n >= 0, unlike `n + kBlockSize`.
  for (; N - n >= kBlockSize; n += kBlockSize) {
    const __m256d a = _mm256_cvtps_pd(_mm_loadu_ps(A + n));
    const __m256d b = _mm256_cvtps_pd(_mm_loadu_ps(B + n));
    const __m256d c = _mm256_cvtps_pd(_mm_loadu_ps(Out + n));
    const __m256d x = _mm256_fmadd_pd(a, b, c);

    // Unbiased exponent E, read from the fp64 exponent field.
    __m256i exponent =
        _mm256_and_si256(_mm256_castpd_si256(x), exponent_field);
    exponent = _mm256_sub_epi64(_mm256_srli_epi64(exponent, 52), exponent_bias);

    // Bouncer exponent: 28 where E < -14, otherwise 42 + E.
    const __m256i below_normal =
        _mm256_cmpgt_epi64(min_normal_exponent, exponent);
    const __m256i bouncer_exponent = _mm256_or_si256(
        _mm256_and_si256(below_normal, subnormal_bouncer_exponent),
        _mm256_andnot_si256(
            below_normal, _mm256_add_epi64(normal_bouncer_offset, exponent)));

    // Bouncer = 1.5 * 2^bouncer_exponent, built by adding to the exponent
    // field of the bit pattern of 1.5.
    const __m256d bouncer = _mm256_castsi256_pd(_mm256_add_epi64(
        three_halves_bits, _mm256_slli_epi64(bouncer_exponent, 52)));

    __m256d rounded = _mm256_sub_pd(_mm256_add_pd(bouncer, x), bouncer);
    // Re-apply the sign of x so lanes that rounded to zero keep it (nonzero
    // results already carry it); mirrors std::copysign in the scalar path.
    rounded = _mm256_or_pd(rounded, _mm256_and_pd(x, sign_bit));

    _mm_storeu_ps(Out + n, _mm256_cvtpd_ps(rounded));
  }
#endif

  // Scalar path: the remaining elements, or all of them without AVX2/FMA.
  for (; n < N; ++n) {
    const double x = std::fma(static_cast<double>(A[n]),
                              static_cast<double>(B[n]),
                              static_cast<double>(Out[n]));
    Out[n] = static_cast<float>(round_double_to_fp16_precision(x));
  }
}
119
// Returns v1 * v2 + v3 computed as an fp32 fused multiply-add (a single
// rounding), using the FMA3 instruction when the target enables it.
// NOTE(review): when FMA is not enabled at compile time the std::fma fallback
// is used; it may not honour MXCSR flush-to-zero / denormals-are-zero settings
// the way the hardware instruction does.
float fmafp32_avx_emulation(float v1, float v2, float v3) {
#if defined(__FMA__) || (defined(_MSC_VER) && defined(__AVX2__))
  // Scalar FMA3 on the low lane gives the same result as lane 0 of
  // _mm256_fmadd_ps; _mm_cvtss_f32 extracts it without pointer type punning.
  const __m128 fused =
      _mm_fmadd_ss(_mm_set_ss(v1), _mm_set_ss(v2), _mm_set_ss(v3));
  return _mm_cvtss_f32(fused);
#else
  return std::fma(v1, v2, v3);
#endif
}
128
129} // namespace fake_fp16130