pytorch

Форк
0
/
byte_order.cpp 
485 строк · 12.0 Кб
1
#include <c10/util/BFloat16.h>
2
#include <c10/util/irange.h>
3
#include <torch/csrc/utils/byte_order.h>
4

5
#include <cstring>
6
#include <vector>
7

8
#if defined(_MSC_VER)
9
#include <stdlib.h>
10
#endif
11

12
namespace {
13

14
/// Reverses, in place, the byte order of the 16-bit value stored at `ptr`.
/// The value is moved in and out with memcpy, so `ptr` need not be aligned.
static inline void swapBytes16(void* ptr) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint16_t value;
  memcpy(&value, ptr, sizeof(value));
#if defined(_MSC_VER) && !defined(_DEBUG)
  value = _byteswap_ushort(value);
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
  value = __builtin_bswap16(value);
#else
  // Portable fallback: exchange the high and low bytes.
  value = static_cast<uint16_t>((value >> 8) | (value << 8));
#endif
  memcpy(ptr, &value, sizeof(value));
}
29

30
/// Reverses, in place, the byte order of the 32-bit value stored at `ptr`.
/// The value is moved in and out with memcpy, so `ptr` need not be aligned.
static inline void swapBytes32(void* ptr) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint32_t value;
  memcpy(&value, ptr, sizeof(value));
#if defined(_MSC_VER) && !defined(_DEBUG)
  value = _byteswap_ulong(value);
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
  value = __builtin_bswap32(value);
#else
  // Portable fallback: move every byte to its mirrored position.
  value = ((value & 0x000000FFu) << 24) | ((value & 0x0000FF00u) << 8) |
      ((value >> 8) & 0x0000FF00u) | (value >> 24);
#endif
  memcpy(ptr, &value, sizeof(value));
}
47

48
/// Reverses, in place, the byte order of the 64-bit value stored at `ptr`.
/// The value is moved in and out with memcpy, so `ptr` need not be aligned.
static inline void swapBytes64(void* ptr) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint64_t value;
  memcpy(&value, ptr, sizeof(value));
  // Unlike the 16/32-bit helpers, this MSVC branch is not excluded in _DEBUG
  // builds.
#if defined(_MSC_VER)
  value = _byteswap_uint64(value);
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
  value = __builtin_bswap64(value);
#else
  // Portable fallback: peel bytes off the low end of `value` and push them
  // onto the low end of `reversed`, which mirrors their order.
  uint64_t reversed = 0;
  for (int byte = 0; byte < 8; ++byte) {
    reversed = (reversed << 8) | (value & 0xFFu);
    value >>= 8;
  }
  value = reversed;
#endif
  memcpy(ptr, &value, sizeof(value));
}
71

72
/// Reads a 16-bit value, in host byte order, from a possibly unaligned byte
/// pointer.
static inline uint16_t decodeUInt16(const uint8_t* data) {
  uint16_t value{};
  memcpy(&value, data, sizeof(value));
  return value;
}
78

79
static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
80
  uint16_t output = decodeUInt16(data);
81
  swapBytes16(&output);
82
  return output;
83
}
84

85
/// Reads a 32-bit value, in host byte order, from a possibly unaligned byte
/// pointer.
static inline uint32_t decodeUInt32(const uint8_t* data) {
  uint32_t value{};
  memcpy(&value, data, sizeof(value));
  return value;
}
91

92
static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
93
  uint32_t output = decodeUInt32(data);
94
  swapBytes32(&output);
95
  return output;
96
}
97

98
/// Reads a 64-bit value, in host byte order, from a possibly unaligned byte
/// pointer.
static inline uint64_t decodeUInt64(const uint8_t* data) {
  uint64_t value{};
  memcpy(&value, data, sizeof(value));
  return value;
}
104

105
static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) {
106
  uint64_t output = decodeUInt64(data);
107
  swapBytes64(&output);
108
  return output;
109
}
110

111
} // anonymous namespace
112

113
namespace torch {
114
namespace utils {
115

116
/// Detects the host byte order at runtime.
///
/// The integer 1 is stored and its lowest-addressed byte is inspected. That
/// byte is non-zero only on little-endian hosts.
THPByteOrder THP_nativeByteOrder() {
  const uint32_t probe = 1;
  uint8_t lowest_byte = 0;
  memcpy(&lowest_byte, &probe, sizeof(lowest_byte));
  if (lowest_byte != 0) {
    return THP_LITTLE_ENDIAN;
  }
  return THP_BIG_ENDIAN;
}
120

121
void THP_decodeInt16Buffer(
122
    int16_t* dst,
123
    const uint8_t* src,
124
    bool do_byte_swap,
125
    size_t len) {
126
  for (const auto i : c10::irange(len)) {
127
    dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src)
128
                                    : decodeUInt16(src));
129
    src += sizeof(int16_t);
130
  }
131
}
132

133
void THP_decodeInt32Buffer(
134
    int32_t* dst,
135
    const uint8_t* src,
136
    bool do_byte_swap,
137
    size_t len) {
138
  for (const auto i : c10::irange(len)) {
139
    dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src)
140
                                    : decodeUInt32(src));
141
    src += sizeof(int32_t);
142
  }
143
}
144

145
void THP_decodeInt64Buffer(
146
    int64_t* dst,
147
    const uint8_t* src,
148
    bool do_byte_swap,
149
    size_t len) {
150
  for (const auto i : c10::irange(len)) {
151
    dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src)
152
                                    : decodeUInt64(src));
153
    src += sizeof(int64_t);
154
  }
155
}
156

157
void THP_decodeHalfBuffer(
158
    c10::Half* dst,
159
    const uint8_t* src,
160
    bool do_byte_swap,
161
    size_t len) {
162
  for (const auto i : c10::irange(len)) {
163
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
164
    union {
165
      uint16_t x;
166
      c10::Half f;
167
    };
168
    x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
169
    dst[i] = f;
170
    src += sizeof(uint16_t);
171
  }
172
}
173

174
void THP_decodeBFloat16Buffer(
175
    at::BFloat16* dst,
176
    const uint8_t* src,
177
    bool do_byte_swap,
178
    size_t len) {
179
  for (const auto i : c10::irange(len)) {
180
    uint16_t x =
181
        (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
182
    std::memcpy(&dst[i], &x, sizeof(dst[i]));
183
    src += sizeof(uint16_t);
184
  }
185
}
186

187
void THP_decodeBoolBuffer(
188
    bool* dst,
189
    const uint8_t* src,
190
    bool do_byte_swap,
191
    size_t len) {
192
  for (const auto i : c10::irange(len)) {
193
    dst[i] = (int)src[i] != 0 ? true : false;
194
  }
195
}
196

197
void THP_decodeFloatBuffer(
198
    float* dst,
199
    const uint8_t* src,
200
    bool do_byte_swap,
201
    size_t len) {
202
  for (const auto i : c10::irange(len)) {
203
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
204
    union {
205
      uint32_t x;
206
      float f;
207
    };
208
    x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
209
    dst[i] = f;
210
    src += sizeof(float);
211
  }
212
}
213

214
void THP_decodeDoubleBuffer(
215
    double* dst,
216
    const uint8_t* src,
217
    bool do_byte_swap,
218
    size_t len) {
219
  for (const auto i : c10::irange(len)) {
220
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
221
    union {
222
      uint64_t x;
223
      double d;
224
    };
225
    x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
226
    dst[i] = d;
227
    src += sizeof(double);
228
  }
229
}
230

231
void THP_decodeComplexFloatBuffer(
232
    c10::complex<float>* dst,
233
    const uint8_t* src,
234
    bool do_byte_swap,
235
    size_t len) {
236
  for (const auto i : c10::irange(len)) {
237
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
238
    union {
239
      uint32_t x;
240
      float re;
241
    };
242
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
243
    union {
244
      uint32_t y;
245
      float im;
246
    };
247

248
    x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
249
    src += sizeof(float);
250
    y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
251
    src += sizeof(float);
252

253
    dst[i] = c10::complex<float>(re, im);
254
  }
255
}
256

257
void THP_decodeComplexDoubleBuffer(
258
    c10::complex<double>* dst,
259
    const uint8_t* src,
260
    bool do_byte_swap,
261
    size_t len) {
262
  for (const auto i : c10::irange(len)) {
263
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
264
    union {
265
      uint64_t x;
266
      double re;
267
    };
268
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
269
    union {
270
      uint64_t y;
271
      double im;
272
    };
273
    static_assert(sizeof(uint64_t) == sizeof(double));
274

275
    x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
276
    src += sizeof(double);
277
    y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
278
    src += sizeof(double);
279

280
    dst[i] = c10::complex<double>(re, im);
281
  }
282
}
283

284
void THP_decodeInt16Buffer(
285
    int16_t* dst,
286
    const uint8_t* src,
287
    THPByteOrder order,
288
    size_t len) {
289
  THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
290
}
291

292
void THP_decodeInt32Buffer(
293
    int32_t* dst,
294
    const uint8_t* src,
295
    THPByteOrder order,
296
    size_t len) {
297
  THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len);
298
}
299

300
void THP_decodeInt64Buffer(
301
    int64_t* dst,
302
    const uint8_t* src,
303
    THPByteOrder order,
304
    size_t len) {
305
  THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len);
306
}
307

308
void THP_decodeHalfBuffer(
309
    c10::Half* dst,
310
    const uint8_t* src,
311
    THPByteOrder order,
312
    size_t len) {
313
  THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len);
314
}
315

316
void THP_decodeBFloat16Buffer(
317
    at::BFloat16* dst,
318
    const uint8_t* src,
319
    THPByteOrder order,
320
    size_t len) {
321
  THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
322
}
323

324
void THP_decodeBoolBuffer(
325
    bool* dst,
326
    const uint8_t* src,
327
    THPByteOrder order,
328
    size_t len) {
329
  THP_decodeBoolBuffer(dst, src, (order != THP_nativeByteOrder()), len);
330
}
331

332
void THP_decodeFloatBuffer(
333
    float* dst,
334
    const uint8_t* src,
335
    THPByteOrder order,
336
    size_t len) {
337
  THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
338
}
339

340
void THP_decodeDoubleBuffer(
341
    double* dst,
342
    const uint8_t* src,
343
    THPByteOrder order,
344
    size_t len) {
345
  THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len);
346
}
347

348
void THP_decodeComplexFloatBuffer(
349
    c10::complex<float>* dst,
350
    const uint8_t* src,
351
    THPByteOrder order,
352
    size_t len) {
353
  THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
354
}
355

356
void THP_decodeComplexDoubleBuffer(
357
    c10::complex<double>* dst,
358
    const uint8_t* src,
359
    THPByteOrder order,
360
    size_t len) {
361
  THP_decodeComplexDoubleBuffer(
362
      dst, src, (order != THP_nativeByteOrder()), len);
363
}
364

365
void THP_encodeInt16Buffer(
366
    uint8_t* dst,
367
    const int16_t* src,
368
    THPByteOrder order,
369
    size_t len) {
370
  memcpy(dst, src, sizeof(int16_t) * len);
371
  if (order != THP_nativeByteOrder()) {
372
    for (const auto i : c10::irange(len)) {
373
      (void)i;
374
      swapBytes16(dst);
375
      dst += sizeof(int16_t);
376
    }
377
  }
378
}
379

380
void THP_encodeInt32Buffer(
381
    uint8_t* dst,
382
    const int32_t* src,
383
    THPByteOrder order,
384
    size_t len) {
385
  memcpy(dst, src, sizeof(int32_t) * len);
386
  if (order != THP_nativeByteOrder()) {
387
    for (const auto i : c10::irange(len)) {
388
      (void)i;
389
      swapBytes32(dst);
390
      dst += sizeof(int32_t);
391
    }
392
  }
393
}
394

395
void THP_encodeInt64Buffer(
396
    uint8_t* dst,
397
    const int64_t* src,
398
    THPByteOrder order,
399
    size_t len) {
400
  memcpy(dst, src, sizeof(int64_t) * len);
401
  if (order != THP_nativeByteOrder()) {
402
    for (const auto i : c10::irange(len)) {
403
      (void)i;
404
      swapBytes64(dst);
405
      dst += sizeof(int64_t);
406
    }
407
  }
408
}
409

410
void THP_encodeFloatBuffer(
411
    uint8_t* dst,
412
    const float* src,
413
    THPByteOrder order,
414
    size_t len) {
415
  memcpy(dst, src, sizeof(float) * len);
416
  if (order != THP_nativeByteOrder()) {
417
    for (const auto i : c10::irange(len)) {
418
      (void)i;
419
      swapBytes32(dst);
420
      dst += sizeof(float);
421
    }
422
  }
423
}
424

425
void THP_encodeDoubleBuffer(
426
    uint8_t* dst,
427
    const double* src,
428
    THPByteOrder order,
429
    size_t len) {
430
  memcpy(dst, src, sizeof(double) * len);
431
  if (order != THP_nativeByteOrder()) {
432
    for (const auto i : c10::irange(len)) {
433
      (void)i;
434
      swapBytes64(dst);
435
      dst += sizeof(double);
436
    }
437
  }
438
}
439

440
template <typename T>
441
std::vector<T> complex_to_float(const c10::complex<T>* src, size_t len) {
442
  std::vector<T> new_src;
443
  new_src.reserve(2 * len);
444
  for (const auto i : c10::irange(len)) {
445
    auto elem = src[i];
446
    new_src.emplace_back(elem.real());
447
    new_src.emplace_back(elem.imag());
448
  }
449
  return new_src;
450
}
451

452
void THP_encodeComplexFloatBuffer(
453
    uint8_t* dst,
454
    const c10::complex<float>* src,
455
    THPByteOrder order,
456
    size_t len) {
457
  auto new_src = complex_to_float(src, len);
458
  memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(float) * len);
459
  if (order != THP_nativeByteOrder()) {
460
    for (const auto i : c10::irange(2 * len)) {
461
      (void)i; // Suppress unused variable warning
462
      swapBytes32(dst);
463
      dst += sizeof(float);
464
    }
465
  }
466
}
467

468
void THP_encodeComplexDoubleBuffer(
469
    uint8_t* dst,
470
    const c10::complex<double>* src,
471
    THPByteOrder order,
472
    size_t len) {
473
  auto new_src = complex_to_float(src, len);
474
  memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(double) * len);
475
  if (order != THP_nativeByteOrder()) {
476
    for (const auto i : c10::irange(2 * len)) {
477
      (void)i; // Suppress unused variable warning
478
      swapBytes64(dst);
479
      dst += sizeof(double);
480
    }
481
  }
482
}
483

484
} // namespace utils
485
} // namespace torch
486

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.