1
#include <c10/util/BFloat16.h>
2
#include <c10/util/irange.h>
3
#include <torch/csrc/utils/byte_order.h>
14
static inline void swapBytes16(void* ptr) {
15
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
17
memcpy(&output, ptr, sizeof(uint16_t));
18
#if defined(_MSC_VER) && !defined(_DEBUG)
19
output = _byteswap_ushort(output);
20
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
21
output = __builtin_bswap16(output);
23
uint16_t Hi = output >> 8;
24
uint16_t Lo = output << 8;
27
memcpy(ptr, &output, sizeof(uint16_t));
30
static inline void swapBytes32(void* ptr) {
31
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
33
memcpy(&output, ptr, sizeof(uint32_t));
34
#if defined(_MSC_VER) && !defined(_DEBUG)
35
output = _byteswap_ulong(output);
36
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
37
output = __builtin_bswap32(output);
39
uint32_t Byte0 = output & 0x000000FF;
40
uint32_t Byte1 = output & 0x0000FF00;
41
uint32_t Byte2 = output & 0x00FF0000;
42
uint32_t Byte3 = output & 0xFF000000;
43
output = (Byte0 << 24) | (Byte1 << 8) | (Byte2 >> 8) | (Byte3 >> 24);
45
memcpy(ptr, &output, sizeof(uint32_t));
48
static inline void swapBytes64(void* ptr) {
49
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
51
memcpy(&output, ptr, sizeof(uint64_t));
53
output = _byteswap_uint64(output);
54
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
55
output = __builtin_bswap64(output);
57
uint64_t Byte0 = output & 0x00000000000000FF;
58
uint64_t Byte1 = output & 0x000000000000FF00;
59
uint64_t Byte2 = output & 0x0000000000FF0000;
60
uint64_t Byte3 = output & 0x00000000FF000000;
61
uint64_t Byte4 = output & 0x000000FF00000000;
62
uint64_t Byte5 = output & 0x0000FF0000000000;
63
uint64_t Byte6 = output & 0x00FF000000000000;
64
uint64_t Byte7 = output & 0xFF00000000000000;
65
output = (Byte0 << (7 * 8)) | (Byte1 << (5 * 8)) | (Byte2 << (3 * 8)) |
66
(Byte3 << (1 * 8)) | (Byte7 >> (7 * 8)) | (Byte6 >> (5 * 8)) |
67
(Byte5 >> (3 * 8)) | (Byte4 >> (1 * 8));
69
memcpy(ptr, &output, sizeof(uint64_t));
72
static inline uint16_t decodeUInt16(const uint8_t* data) {
73
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
75
memcpy(&output, data, sizeof(uint16_t));
79
static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
80
uint16_t output = decodeUInt16(data);
85
static inline uint32_t decodeUInt32(const uint8_t* data) {
86
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
88
memcpy(&output, data, sizeof(uint32_t));
92
static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
93
uint32_t output = decodeUInt32(data);
98
static inline uint64_t decodeUInt64(const uint8_t* data) {
99
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
101
memcpy(&output, data, sizeof(uint64_t));
105
static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) {
106
uint64_t output = decodeUInt64(data);
107
swapBytes64(&output);
111
} // anonymous namespace
116
THPByteOrder THP_nativeByteOrder() {
118
return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN;
121
void THP_decodeInt16Buffer(
126
for (const auto i : c10::irange(len)) {
127
dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src)
128
: decodeUInt16(src));
129
src += sizeof(int16_t);
133
void THP_decodeInt32Buffer(
138
for (const auto i : c10::irange(len)) {
139
dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src)
140
: decodeUInt32(src));
141
src += sizeof(int32_t);
145
void THP_decodeInt64Buffer(
150
for (const auto i : c10::irange(len)) {
151
dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src)
152
: decodeUInt64(src));
153
src += sizeof(int64_t);
157
void THP_decodeHalfBuffer(
162
for (const auto i : c10::irange(len)) {
163
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
168
x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
170
src += sizeof(uint16_t);
174
void THP_decodeBFloat16Buffer(
179
for (const auto i : c10::irange(len)) {
181
(do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
182
std::memcpy(&dst[i], &x, sizeof(dst[i]));
183
src += sizeof(uint16_t);
187
void THP_decodeBoolBuffer(
192
for (const auto i : c10::irange(len)) {
193
dst[i] = (int)src[i] != 0 ? true : false;
197
void THP_decodeFloatBuffer(
202
for (const auto i : c10::irange(len)) {
203
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
208
x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
210
src += sizeof(float);
214
void THP_decodeDoubleBuffer(
219
for (const auto i : c10::irange(len)) {
220
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
225
x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
227
src += sizeof(double);
231
void THP_decodeComplexFloatBuffer(
232
c10::complex<float>* dst,
236
for (const auto i : c10::irange(len)) {
237
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
242
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
248
x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
249
src += sizeof(float);
250
y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
251
src += sizeof(float);
253
dst[i] = c10::complex<float>(re, im);
257
void THP_decodeComplexDoubleBuffer(
258
c10::complex<double>* dst,
262
for (const auto i : c10::irange(len)) {
263
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
268
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
273
static_assert(sizeof(uint64_t) == sizeof(double));
275
x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
276
src += sizeof(double);
277
y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
278
src += sizeof(double);
280
dst[i] = c10::complex<double>(re, im);
284
void THP_decodeInt16Buffer(
289
THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
292
void THP_decodeInt32Buffer(
297
THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len);
300
void THP_decodeInt64Buffer(
305
THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len);
308
void THP_decodeHalfBuffer(
313
THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len);
316
void THP_decodeBFloat16Buffer(
321
THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
324
void THP_decodeBoolBuffer(
329
THP_decodeBoolBuffer(dst, src, (order != THP_nativeByteOrder()), len);
332
void THP_decodeFloatBuffer(
337
THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
340
void THP_decodeDoubleBuffer(
345
THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len);
348
void THP_decodeComplexFloatBuffer(
349
c10::complex<float>* dst,
353
THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
356
void THP_decodeComplexDoubleBuffer(
357
c10::complex<double>* dst,
361
THP_decodeComplexDoubleBuffer(
362
dst, src, (order != THP_nativeByteOrder()), len);
365
void THP_encodeInt16Buffer(
370
memcpy(dst, src, sizeof(int16_t) * len);
371
if (order != THP_nativeByteOrder()) {
372
for (const auto i : c10::irange(len)) {
375
dst += sizeof(int16_t);
380
void THP_encodeInt32Buffer(
385
memcpy(dst, src, sizeof(int32_t) * len);
386
if (order != THP_nativeByteOrder()) {
387
for (const auto i : c10::irange(len)) {
390
dst += sizeof(int32_t);
395
void THP_encodeInt64Buffer(
400
memcpy(dst, src, sizeof(int64_t) * len);
401
if (order != THP_nativeByteOrder()) {
402
for (const auto i : c10::irange(len)) {
405
dst += sizeof(int64_t);
410
void THP_encodeFloatBuffer(
415
memcpy(dst, src, sizeof(float) * len);
416
if (order != THP_nativeByteOrder()) {
417
for (const auto i : c10::irange(len)) {
420
dst += sizeof(float);
425
void THP_encodeDoubleBuffer(
430
memcpy(dst, src, sizeof(double) * len);
431
if (order != THP_nativeByteOrder()) {
432
for (const auto i : c10::irange(len)) {
435
dst += sizeof(double);
441
std::vector<T> complex_to_float(const c10::complex<T>* src, size_t len) {
442
std::vector<T> new_src;
443
new_src.reserve(2 * len);
444
for (const auto i : c10::irange(len)) {
446
new_src.emplace_back(elem.real());
447
new_src.emplace_back(elem.imag());
452
void THP_encodeComplexFloatBuffer(
454
const c10::complex<float>* src,
457
auto new_src = complex_to_float(src, len);
458
memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(float) * len);
459
if (order != THP_nativeByteOrder()) {
460
for (const auto i : c10::irange(2 * len)) {
461
(void)i; // Suppress unused variable warning
463
dst += sizeof(float);
468
void THP_encodeComplexDoubleBuffer(
470
const c10::complex<double>* src,
473
auto new_src = complex_to_float(src, len);
474
memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(double) * len);
475
if (order != THP_nativeByteOrder()) {
476
for (const auto i : c10::irange(2 * len)) {
477
(void)i; // Suppress unused variable warning
479
dst += sizeof(double);