pytorch

Форк
0
/
fp16_fma_slow.cc 
540 строк · 14.1 Кб
1
#include <immintrin.h>
2
#include "fp16_fma.h"
3

4
namespace fp16_fma {
5

6
// SoftFloat-style scalar types used by this binary16 emulation.
// "Fast" integer type: only a minimum width is implied, as in SoftFloat.
typedef int int16;
// Explicitly signed: plain `char` has implementation-defined signedness, but a
// type named int8 is expected to behave as a signed 8-bit integer.
typedef signed char int8;
typedef unsigned short int bits16;
typedef unsigned int bits32;
typedef signed char Word8;
typedef unsigned char UWord8;
typedef signed short Word16;
typedef unsigned short UWord16;
typedef signed int Word32;
typedef unsigned int UWord32;
typedef long long Word64;
typedef unsigned long long UWord64;
// Raw IEEE 754 binary16 bit pattern: sign (bit 15), exponent (bits 14..10),
// fraction (bits 9..0).
typedef unsigned short float16;
typedef signed int sbits32;
typedef signed short int sbits16;

// Boolean stored in a byte (zero / non-zero).
typedef char flag;

// Whole expansion parenthesized so the cast cannot bind to neighbouring
// operators (e.g. `sizeof MAX_U16`).
#define MAX_U32 ((UWord32)0xffffffffL)
#define MAX_U16 ((UWord16)0xffff)
// Mask with the low `w` bits set, computed in type `typ`.
#define BITMASK_T(typ, w) (((typ)1 << (w)) - 1)
// Value (0 or 1) of bit `n` of `x`.
#define TESTBIT(x, n) (((x) >> (n)) & 1)

// Quiet-NaN encodings; float16_default_nan is returned for invalid operations.
#define float16_default_nan 0x7E00
#define float16_default_nan_pos 0x7E00
#define float16_default_nan_neg 0xFE00

// Sticky exception flags (float_flag_* bits) accumulated by this file.
int8 float_exception_flags = 0;

enum {
  float_round_nearest_even = 0,
  float_round_down = 1,
  float_round_up = 2,
  float_round_to_zero = 3
};

// Active rounding mode; one of the float_round_* values.
int8 float_rounding_mode = float_round_nearest_even;
enum { float_tininess_after_rounding = 0, float_tininess_before_rounding = 1 };
// Selects whether underflow tininess is detected before or after rounding.
int float_detect_tininess = float_tininess_after_rounding;

// Returns the 10-bit fraction field (bits 9..0) of the binary16 pattern `a`.
inline bits16 extractFloat16Frac(float16 a) {
  const bits16 kFractionMask = 0x03FF;
  return static_cast<bits16>(a & kFractionMask);
}

// Returns the biased 5-bit exponent field (bits 14..10) of `a`.
inline int16 extractFloat16Exp(float16 a) {
  const int exponentField = (a & 0x7C00) >> 10;
  return exponentField;
}

// Returns 1 when the sign bit (bit 15) of `a` is set, otherwise 0.
inline flag extractFloat16Sign(float16 a) {
  return (a & 0x8000) != 0;
}

// Returns 1 when `a` is a quiet NaN, i.e. the exponent is all ones and the
// most significant fraction bit (bit 9) is set. The sign bit is ignored.
flag float16_is_quiet_nan(float16 a) {
  const float16 kQuietNaNBits = 0x7E00; // exponent 0x1F plus fraction bit 9
  return (a & kQuietNaNBits) == kQuietNaNBits;
}

// Returns 1 when `a` is a signaling NaN: exponent all ones, quiet bit (bit 9)
// clear, and a non-zero payload in fraction bits 8..0.
flag float16_is_signaling_nan(float16 a) {
  const bool exponentOnesQuietClear = (a & 0x7E00) == 0x7C00;
  const bool payloadNonZero = (a & 0x01FF) != 0;
  return exponentOnesQuietClear && payloadNonZero;
}

// Sticky exception flags; each is a distinct bit OR-ed into
// float_exception_flags.
enum {
  float_flag_inexact = 1 << 0,
  float_flag_divbyzero = 1 << 1,
  float_flag_underflow = 1 << 2,
  float_flag_overflow = 1 << 3,
  float_flag_invalid = 1 << 4
};

// Adds the float_flag_* bits in `flags` to the sticky float_exception_flags.
void float_raise(int8 flags) {
  float_exception_flags = static_cast<int8>(float_exception_flags | flags);
}

int pickNaNMulAdd(
78
    flag aIsQNaN,
79
    flag aIsSNaN,
80
    flag bIsQNaN,
81
    flag bIsSNaN,
82
    flag cIsQNaN,
83
    flag cIsSNaN,
84
    flag infzero) {
85
  if (infzero) {
86
    float_raise(float_flag_invalid);
87
    return 2;
88
  }
89

90
  if (cIsSNaN || cIsQNaN) {
91
    return 2;
92
  } else if (bIsSNaN || bIsQNaN) {
93
    return 1;
94
  } else {
95
    return 0;
96
  }
97
}
98

99
// Assembles a binary16 bit pattern from sign, exponent and significand.
// The fields are combined with `+`, not `|`, deliberately: a significand that
// carries into bit 10 (e.g. a leading 1 left in place by the rounding code)
// increments the exponent field.
inline float16 packFloat16(flag zSign, int16 zExp, bits16 zSig) {
  const int signField = static_cast<bits16>(zSign) << 15;
  const int exponentField = static_cast<bits16>(zExp) << 10;
  return static_cast<float16>(signField + exponentField + zSig);
}

// Produces the NaN result of a fused multiply-add whose inputs include a NaN.
// Raises invalid if any operand is a signaling NaN, lets pickNaNMulAdd choose
// an operand (it also handles the `infzero` case), and returns that operand
// with its quiet bit (bit 9) forced on.
float16
propagateFloat16MulAddNaN(float16 a, float16 b, float16 c, flag infzero) {
  const flag aQuiet = float16_is_quiet_nan(a);
  const flag aSignaling = float16_is_signaling_nan(a);
  const flag bQuiet = float16_is_quiet_nan(b);
  const flag bSignaling = float16_is_signaling_nan(b);
  const flag cQuiet = float16_is_quiet_nan(c);
  const flag cSignaling = float16_is_signaling_nan(c);

  if (aSignaling || bSignaling || cSignaling) {
    float_raise(float_flag_invalid);
  }

  const int chosen = pickNaNMulAdd(
      aQuiet, aSignaling, bQuiet, bSignaling, cQuiet, cSignaling, infzero);

  const float16 kQuietBit = 1 << 9;
  if (chosen == 0) {
    return a | kQuietBit;
  }
  if (chosen == 1) {
    return b | kQuietBit;
  }
  if (chosen == 2) {
    return c | kQuietBit;
  }
  return float16_default_nan;
}

// Stores `a` shifted right by `count` bits in *zPtr. Any 1 bits shifted out
// are OR-ed into bit 0 ("jamming"), so inexactness is not lost. A count of 32
// or more yields 1 for non-zero `a` and 0 otherwise.
inline void shift32RightJamming(bits32 a, int16 count, bits32* zPtr) {
  if (count >= 32) {
    *zPtr = (a != 0);
    return;
  }
  if (count == 0) {
    *zPtr = a;
    return;
  }
  const bits32 lostBits = a << ((-count) & 31);
  *zPtr = (a >> count) | (lostBits != 0);
}

// 16-bit variant of shift32RightJamming: shifts `a` right by `count`, jamming
// any discarded 1 bits into bit 0. A count of 16 or more yields (a != 0).
void shift16RightJamming(bits16 a, int16 count, bits16* zPtr) {
  bits16 shifted;
  if (count >= 16) {
    shifted = (a != 0);
  } else if (count != 0) {
    const int lostBits = (a << ((-count) & 15)) & 0xffff;
    shifted = static_cast<bits16>((a >> count) | (lostBits != 0));
  } else {
    shifted = a;
  }
  *zPtr = shifted;
}

// Translates the 2-bit rounding-mode field in the low bits of `fcr` into a
// float_round_* value: 0 -> nearest-even (0), 1 -> toward zero (3),
// 2 -> float_round_up (2), 3 -> float_round_down (1).
Word8 GetRound(Word32 fcr) {
  static const Word8 kFcrToSoftfloat[4] = {
      0, // float_round_nearest_even
      3, // float_round_to_zero
      2, // float_round_up
      1, // float_round_down
  };
  return kFcrToSoftfloat[fcr & 0x3];
}

// Translates exception bits 7..11 of the status word `fsr` into the
// float_flag_* bits used by this file:
//   bit 7 -> inexact, bit 8 -> underflow, bit 9 -> overflow,
//   bit 10 -> divide-by-zero, bit 11 -> invalid.
// Named enumerators are used so the result always matches the flag values
// that float_raise() and the rest of this file use.
Word8 GetException(Word32 fsr) {
  Word8 res = 0;
  if (TESTBIT(fsr, 7) == 1) {
    res |= float_flag_inexact;
  }
  if (TESTBIT(fsr, 8) == 1) {
    res |= float_flag_underflow;
  }
  if (TESTBIT(fsr, 9) == 1) {
    res |= float_flag_overflow;
  }
  if (TESTBIT(fsr, 10) == 1) {
    res |= float_flag_divbyzero;
  }
  if (TESTBIT(fsr, 11) == 1) {
    res |= float_flag_invalid;
  }
  return res;
}

// Rounds and packs a binary16 result.
//
// `zSig` carries 4 extra low-order rounding bits below the 10-bit fraction.
// Callers pass it with its leading 1 at bit 14 and `zExp` one less than the
// final biased exponent: after the >> 4 below, that leading 1 sits at bit 10,
// and packFloat16 adds it into the exponent field.
// Rounding follows float_rounding_mode. On overflow the result is infinity, or
// the largest finite value when the mode truncates for this sign. Results
// below the normal range are denormalized with a sticky shift. Overflow,
// underflow and inexact are raised in float_exception_flags as needed.
float16 roundAndPackFloat16(flag zSign, int16 zExp, bits16 zSig) {
  const int8 mode = float_rounding_mode;
  const flag nearestEven = (mode == float_round_nearest_even);

  // Amount added to the 4 rounding bits before they are discarded:
  // half an ulp for nearest-even, all ones to round away from zero,
  // zero to truncate.
  int8 increment;
  if (nearestEven) {
    increment = 0x8;
  } else if (mode == float_round_to_zero) {
    increment = 0;
  } else if (zSign ? (mode == float_round_up) : (mode == float_round_down)) {
    // Directed rounding points toward zero for this sign: truncate.
    increment = 0;
  } else {
    increment = 0xF;
  }

  int8 lowBits = zSig & 0xF;

  // The unsigned comparison also catches negative exponents, which wrap to
  // large values.
  if ((bits16)zExp >= 0x1D) {
    const bool overflows = (zExp > 0x1D) ||
        (zExp == 0x1D && (sbits16)(zSig + increment) < 0);
    if (overflows) {
      float_raise(float_flag_overflow | float_flag_inexact);
      // Infinity; minus one gives the largest finite magnitude when truncating.
      return packFloat16(zSign, 0x1F, 0) - (increment == 0);
    }
    if (zExp < 0) {
      const flag tiny =
          (float_detect_tininess == float_tininess_before_rounding) ||
          (zExp < -1) || (zSig + increment < 0x8000);
      shift16RightJamming(zSig, -zExp, &zSig);
      zExp = 0;
      lowBits = zSig & 0xF;
      if (tiny && lowBits) {
        float_raise(float_flag_underflow);
      }
    }
  }

  if (lowBits) {
    float_raise(float_flag_inexact);
  }
  zSig = (zSig + increment) >> 4;
  // Exact tie under nearest-even: clear the lsb so the result is even.
  if (nearestEven && lowBits == 0x8) {
    zSig &= ~1u;
  }
  if (zSig == 0) {
    zExp = 0;
  }
  return packFloat16(zSign, zExp, zSig);
}

// Returns the number of leading zero bits in the 32-bit value `a`
// (32 when `a` is zero). Binary search: each step tests whether the upper
// half of the remaining window is empty and, if so, shifts it out.
int8 countLeadingZeros32(bits32 a) {
  if (a == 0) {
    return 32;
  }
  int8 zeros = 0;
  if (a <= 0x0000FFFFu) {
    zeros += 16;
    a <<= 16;
  }
  if (a <= 0x00FFFFFFu) {
    zeros += 8;
    a <<= 8;
  }
  if (a <= 0x0FFFFFFFu) {
    zeros += 4;
    a <<= 4;
  }
  if (a <= 0x3FFFFFFFu) {
    zeros += 2;
    a <<= 2;
  }
  if (a <= 0x7FFFFFFFu) {
    zeros += 1;
  }
  return zeros;
}

// Normalizes a subnormal significand `aSig` (callers pass a non-zero 10-bit
// fraction). It shifts `aSig` left until its leading 1 reaches bit 10, the
// implicit-bit position, and stores the matching exponent (1 - shift).
void normalizeFloat16Subnormal(bits16 aSig, int16* zExpPtr, bits16* zSigPtr) {
  // The 32-bit leading-zero count minus the 16 unused high bits and the
  // 5 exponent bits is the distance from the leading 1 to bit 10.
  const int8 shift = countLeadingZeros32(static_cast<bits32>(aSig)) - 21;
  *zSigPtr = aSig << shift;
  *zExpPtr = 1 - shift;
}

// Fused multiply-add on binary16 bit patterns: returns (a * b) + c rounded
// once, or -(a * b) + c when `negate_product` is non-zero.
//
// Special cases:
//   - NaN operands propagate via propagateFloat16MulAddNaN.
//   - 0 * inf, and inf + (-inf), raise invalid and return the default NaN.
//   - Exact-zero results take their sign from the operands and
//     float_rounding_mode.
// Otherwise the exact product is formed in 32 bits with its leading 1 at
// bit 30. It is aligned with c using sticky (jamming) right shifts, then
// added or subtracted, and the sum is passed to roundAndPackFloat16.
// Reads float_rounding_mode and may raise flags in float_exception_flags.
float16 float16_muladd(float16 a, float16 b, float16 c, flag negate_product) {
  const flag aSign = extractFloat16Sign(a);
  int16 aExp = extractFloat16Exp(a);
  bits16 aSig = extractFloat16Frac(a);

  const flag bSign = extractFloat16Sign(b);
  int16 bExp = extractFloat16Exp(b);
  bits16 bSig = extractFloat16Frac(b);

  const flag cSign = extractFloat16Sign(c);
  int16 cExp = extractFloat16Exp(c);
  bits16 cSig = extractFloat16Frac(c);

  const bool aIsZero = (aExp == 0) && (aSig == 0);
  const bool bIsZero = (bExp == 0) && (bSig == 0);
  const bool aIsInf = (aExp == 0x1f) && (aSig == 0);
  const bool bIsInf = (bExp == 0x1f) && (bSig == 0);
  // fusedMultiplyAdd(0, inf, c) or fusedMultiplyAdd(inf, 0, c).
  const flag infzero = (aIsZero && bIsInf) || (aIsInf && bIsZero);

  // Any NaN input: propagate a NaN. IEEE 754 (7.2) leaves it
  // implementation-defined whether (0, inf, qNaN) also signals invalid, so
  // infzero is forwarded to the NaN-selection routine.
  const bool aIsNaN = (aExp == 0x1f) && (aSig != 0);
  const bool bIsNaN = (bExp == 0x1f) && (bSig != 0);
  const bool cIsNaN = (cExp == 0x1f) && (cSig != 0);
  if (aIsNaN || bIsNaN || cIsNaN) {
    return propagateFloat16MulAddNaN(a, b, c, infzero);
  }

  // Sign of the (possibly negated) product.
  const flag pSign =
      static_cast<flag>(aSign ^ bSign ^ (negate_product ? 1 : 0));

  // 0 * inf with a non-NaN addend is an invalid operation.
  if (infzero) {
    float_raise(float_flag_invalid);
    return float16_default_nan;
  }

  const bool pInf = (aExp == 0x1f) || (bExp == 0x1f);
  const bool pZero = aIsZero || bIsZero;

  // Infinite addend.
  if (cExp == 0x1f) {
    if (pInf && (pSign ^ cSign)) {
      // inf + (-inf) is invalid.
      float_raise(float_flag_invalid);
      return float16_default_nan;
    }
    return packFloat16(cSign, 0x1f, 0);
  }

  // Infinite product with a finite addend.
  if (pInf) {
    return packFloat16(pSign, 0x1f, 0);
  }

  // Zero product.
  if (pZero) {
    if (cExp == 0 && cSig == 0) {
      // (+/-0) + (+/-0): equal signs keep that sign; opposite signs give +0,
      // or -0 when rounding toward -inf.
      flag zeroSign;
      if (pSign == cSign) {
        zeroSign = pSign;
      } else {
        zeroSign = (float_rounding_mode == float_round_down) ? 1 : 0;
      }
      return packFloat16(zeroSign, 0, 0);
    }
    // Zero plus something non-zero is exactly that something.
    return c;
  }

  if (aExp == 0) {
    normalizeFloat16Subnormal(aSig, &aExp, &aSig);
  }
  if (bExp == 0) {
    normalizeFloat16Subnormal(bSig, &bExp, &bSig);
  }

  // Exact product. Subtracting 0xe (one less than the bias 0xf) yields the
  // true exponent, i.e. one more than the form roundAndPackFloat16 takes.
  int16 pExp = aExp + bExp - 0xe;
  const bits16 aMant = static_cast<bits16>((aSig | 0x0400) << 4); // 1 at bit 14
  const bits16 bMant = static_cast<bits16>((bSig | 0x0400) << 5); // 1 at bit 15
  bits32 pSig32 = static_cast<bits32>(aMant) * bMant; // 1 at bit 29 or 30
  if ((pSig32 & 0x40000000u) == 0) {
    pSig32 <<= 1;
    --pExp;
  }
  // pSig32 now has its leading 1 at bit 30.

  flag zSign = pSign;

  if (cExp == 0) {
    if (cSig == 0) {
      // Exact-zero addend: the result is just the rounded product.
      shift32RightJamming(pSig32, 16, &pSig32);
      return roundAndPackFloat16(zSign, pExp - 1, static_cast<bits16>(pSig32));
    }
    normalizeFloat16Subnormal(cSig, &cExp, &cSig);
  }

  // Addend with its leading 1 at bit 30, matching the product.
  bits32 cSig32 = (static_cast<bits32>(cSig) << 20) | 0x40000000u;
  const int16 expDiff = pExp - cExp;
  int16 zExp;
  bits32 zSig32;

  if (pSign == cSign) {
    // Magnitudes add: align the smaller-exponent operand, then sum.
    zExp = (expDiff > 0) ? pExp : cExp;
    if (expDiff > 0) {
      shift32RightJamming(cSig32, expDiff, &cSig32);
    } else if (expDiff < 0) {
      shift32RightJamming(pSig32, -expDiff, &pSig32);
    }
    zSig32 = pSig32 + cSig32;
    // Bring the leading 1 back to bit 30: undo a carry into bit 31, otherwise
    // switch zExp to the "one less" form roundAndPackFloat16 expects.
    if (zSig32 & 0x80000000u) {
      shift32RightJamming(zSig32, 1, &zSig32);
    } else {
      --zExp;
    }
  } else {
    // Magnitudes subtract: take larger minus smaller, flipping the sign when
    // the addend dominates.
    if (expDiff > 0) {
      shift32RightJamming(cSig32, expDiff, &cSig32);
      zSig32 = pSig32 - cSig32;
      zExp = pExp;
    } else if (expDiff < 0) {
      shift32RightJamming(pSig32, -expDiff, &pSig32);
      zSig32 = cSig32 - pSig32;
      zExp = cExp;
      zSign ^= 1;
    } else {
      zExp = pExp;
      if (pSig32 > cSig32) {
        zSig32 = pSig32 - cSig32;
      } else if (cSig32 > pSig32) {
        zSig32 = cSig32 - pSig32;
        zSign ^= 1;
      } else {
        // Exact cancellation: +0, or -0 when rounding toward -inf.
        return packFloat16(
            (float_rounding_mode == float_round_down) ? 1 : 0, 0, 0);
      }
    }
    --zExp;
    // Renormalize so the leading 1 is back at bit 30.
    const int shift = countLeadingZeros32(zSig32) - 1;
    zSig32 <<= shift;
    zExp -= shift;
  }

  // Move the leading 1 to bit 14, keeping 4 rounding bits plus sticky.
  shift32RightJamming(zSig32, 16, &zSig32);
  return roundAndPackFloat16(zSign, zExp, static_cast<bits16>(zSig32));
}

void fp_mac_h(
479
    Word16 d0,
480
    Word16 d1,
481
    Word16 d2,
482
    Word32 negate_product,
483
    Word32 fcr,
484
    Word32 fsr_i,
485
    Word16* res,
486
    Word32* fsr_o) {
487
  // Extract rounding mode from FCR/FSR to softfloat
488
  float_rounding_mode = GetRound(fcr);
489
  float_exception_flags = GetException(fsr_i);
490
  // Call softfloat lib
491
  *res = float16_muladd(d1, d2, d0, negate_product);
492
  //*fsr_o =  PutException(float_exception_flags, fsr_i);
493
}
494

495
void fma16(
496
    const Word16 input,
497
    const Word16 a,
498
    const Word16 b,
499
    const Word32 fcr,
500
    const Word32 fsr_i,
501
    Word16* result,
502
    Word32* fsr_o) {
503
  Word16 res;
504
  Word32 fsr = 0;
505
  // Call fp utility
506
  fp_mac_h(b, input, a, 0, fcr, fsr_i, &res, &fsr);
507
  // Output result
508
  *fsr_o = fsr;
509
  *result = res;
510
}
511

512
float fake_fma_fp16_slow(float v1, float v2, float v3) {
513
  uint32_t fcr_val = 0;
514
  uint32_t fsr_val = 0x00000F80;
515
  uint32_t exception_flags = 0;
516

517
  uint16_t hv1, hv2, hv3, hresult;
518
  hv1 = _cvtss_sh(v1, 0);
519
  hv2 = _cvtss_sh(v2, 0);
520
  hv3 = _cvtss_sh(v3, 0);
521

522
  fma16(
523
      *reinterpret_cast<Word16*>(&hv1),
524
      *reinterpret_cast<Word16*>(&hv2),
525
      *reinterpret_cast<Word16*>(&hv3),
526
      *reinterpret_cast<Word32*>(&fcr_val),
527
      *reinterpret_cast<Word32*>(&fsr_val),
528
      reinterpret_cast<Word16*>(&hresult),
529
      reinterpret_cast<Word32*>(&exception_flags));
530

531
  return _cvtsh_ss(hresult);
532
}
533

534
void fake_fma_fp16_slow(int N, const float* A, const float* B, float* Out) {
535
  for (int n = 0; n < N; n++) {
536
    Out[n] = fake_fma_fp16_slow(A[n], B[n], Out[n]);
537
  }
538
}
539

540
} // namespace fp16_fma
541

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.