pytorch / serialization.cpp
#include <torch/csrc/python_headers.h>
#include <system_error>

#include <ATen/ops/from_blob.h>
#include <c10/core/CPUAllocator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/serialization.h>

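// Partial read/write helpers. The `io` template parameter is either a raw
// POSIX file descriptor (int) or a Python file-like object (PyObject*); the
// specializations below dispatch to read()/write() or to the Python I/O
// protocol accordingly.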
template <class io>
Py_ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);

template <class io>
Py_ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);

static Py_ssize_t doPartialPythonReadBuffered(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonReadInto(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonWrite(
    PyObject* fildes,
    void* buf,
    size_t nbytes);

template <>
Py_ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
  return read(fildes, buf, nbytes);
}

template <>
Py_ssize_t doPartialRead<PyObject*>(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  // Try to use fildes.readinto() instead of fildes.read()
  // because it is more memory efficient.
  // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop
  auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
  if (has_readinto) {
    return doPartialPythonReadInto(fildes, buf, nbytes);
  }
  return doPartialPythonReadBuffered(fildes, buf, nbytes);
}

template <>
Py_ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
  return write(fildes, buf, nbytes);
}

template <>
Py_ssize_t doPartialWrite<PyObject*>(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonWrite(fildes, buf, nbytes);
}

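// Returns true if the currently pending Python exception is
// io.UnsupportedOperation.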
static inline bool isUnsupportedOperation() {
  THPObjectPtr io(PyImport_ImportModule("io"));
  if (!io)
    throw python_error();
  THPObjectPtr exception(PyObject_GetAttrString(io, "UnsupportedOperation"));
  if (!exception)
    throw python_error();
  return PyErr_ExceptionMatches(exception.get());
}

// Call Python fildes.read(nbytes) and copy it to buf.
static inline Py_ssize_t doPartialPythonReadBuffered(
    PyObject* fildes,
    void* buf,
    size_t raw_nbytes) {
  // If we request a large amount of data, f.read() will internally try to
  // allocate a buffer of that size.  This is counterproductive, because
  // it's not the buffer we ultimately want to write the data into.  Read
  // less than that and avoid allocating too much extra memory.
  // TODO: Maybe 260 KB is a bit small...
  const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (~260 KB)

  THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes));
  if (!r)
    throw python_error();

  auto size = PyBytes_GET_SIZE(r.get());
  const void* py_buf = PyBytes_AsString(r.get());

  // we read EOF
  if (size == 0) {
    return 0;
  }

  // Slurp it into the buffer we actually want
  memcpy(buf, py_buf, size);

  return size;
}

// Either does fildes.readinto(buf) or fildes.write(buf)
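// The buffer is wrapped in a memoryview so the Python call reads from or
// writes into `buf` directly, without an intermediate bytes copy.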
static inline Py_ssize_t doPartialPythonIO(
    PyObject* fildes,
    void* buf,
    size_t nbytes,
    bool is_read) {
  auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
  THPObjectPtr memview(PyMemoryView_FromMemory(
      reinterpret_cast<char*>(buf), static_cast<Py_ssize_t>(nbytes), rw_flag));
  if (!memview)
    throw python_error();

  std::string method = "write";
  if (is_read) {
    method = "readinto";
  }
  THPObjectPtr r(
      PyObject_CallMethod(fildes, method.c_str(), "O", memview.get()));
  if (r) {
    return PyLong_AsSsize_t(r.get());
  }

  // fildes.readinto can raise UnsupportedOperation, so fall back to
  // fildes.read.
  if (is_read && isUnsupportedOperation()) {
    PyErr_Clear();
    return doPartialPythonReadBuffered(fildes, buf, nbytes);
  }
  throw python_error();
}

// Call Python fildes.readinto(buf)
static Py_ssize_t doPartialPythonReadInto(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
}

// Call Python fildes.write(buf)
static Py_ssize_t doPartialPythonWrite(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
}

// Requires that we read EXACTLY nbytes; fails if we don't.
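// Retries on EINTR and loops over short reads; hitting EOF before nbytes have
// been read is reported as an error.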
template <typename io>
void doRead(io fildes, void* raw_buf, size_t nbytes) {
  char* buf = static_cast<char*>(raw_buf);
  while (nbytes > 0) {
    errno = 0; // doPartialRead may not set errno
    // we read in 1GB blocks to avoid bugs on Mac OS X Lion
    // see https://github.com/pytorch/pytorch/issues/1031 for more details
    Py_ssize_t r =
        doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
    if (r < 0) {
      int err = errno;
      TORCH_INTERNAL_ASSERT(
          err != 0, "read(): impossible! r < 0, but no errno was set");
      TORCH_INTERNAL_ASSERT(
          err != EAGAIN,
          "read(): non-blocking fd ",
          fildes,
          " read EAGAIN; cowardly refusing to spin-wait");
      if (err == EINTR) {
        continue;
      } else {
        AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
      }
    } else if (r == 0) {
      break;
    }
    buf += r;
    // This is guaranteed by POSIX, but I just want to be double-sure
    // to not underflow a signed integer.
    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
    nbytes -= r;
  }
  if (nbytes != 0) {
    AT_ERROR(
        "unexpected EOF, expected ",
        nbytes,
        " more bytes. The file might be corrupted.");
  }
}

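// Requires that we write EXACTLY nbytes; retries on EINTR and loops over
// short writes.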
template <typename io>
void doWrite(io fildes, void* raw_buf, size_t nbytes) {
  char* buf = static_cast<char*>(raw_buf);
  while (nbytes > 0) {
    errno = 0; // doPartialWrite may not set errno
    // we write in 1GB blocks to avoid bugs on Mac OS X Lion
    // see https://github.com/pytorch/pytorch/issues/1031 for more details
    Py_ssize_t r =
        doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
    if (r < 0) {
      int err = errno;
      TORCH_INTERNAL_ASSERT(
          err != 0, "write(): impossible! r < 0, but no errno was set");
      TORCH_INTERNAL_ASSERT(
          err != EAGAIN,
          "write(): non-blocking fd ",
          fildes,
          " got EAGAIN; cowardly refusing to spin-wait");
      if (err == EINTR) {
        continue;
      } else {
        AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
      }
    }
    buf += r;
    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
    nbytes -= r;
  }
}

// save_size is necessary since the old eager format saved storages as
// [size + data], but the v1.5 eager format removes this since size is saved in
// the filesize.
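// Layout written to fd: an optional little-endian int64 element count (only
// when save_size is true), followed by the raw storage bytes, byte-swapped to
// little endian per element on big-endian hosts. Non-CPU storages are copied
// to a CPU tensor first.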
template <class io>
void THPStorage_writeFileRaw(
    c10::StorageImpl* self,
    io fd,
    bool save_size,
    uint64_t element_size) {
  c10::DeviceGuard guard(self->device());
  uint8_t* data{};
  at::Tensor cpu_tensor;
  size_t size_bytes = self->nbytes();
  size_t numel = size_bytes / element_size;
  if (self->device_type() == at::kCPU) {
    // We are using a mutable pointer here because we're ultimately
    // calling into a Python API that requires that, even though it
    // won't mutate the data.
    data = static_cast<uint8_t*>(self->mutable_data());
  } else {
    // Here we use tensor.to() to implement the D2H copy for all non-CPU
    // devices.
    auto device_tensor = at::from_blob(
        self->mutable_data(),
        {static_cast<int64_t>(size_bytes)},
        {1},
        nullptr,
        at::device(self->device()).dtype(c10::kByte),
        {self->device()});
    cpu_tensor = device_tensor.to(at::kCPU);
    data = (uint8_t*)cpu_tensor.data_ptr();
  }
  if (save_size) {
    if (torch::utils::THP_nativeByteOrder() ==
        torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
      doWrite(fd, &numel, sizeof(int64_t));
    else {
      int64_t nsize{}; // convert big endian cpu to little endian storage
      torch::utils::THP_encodeInt64Buffer(
          (uint8_t*)&nsize,
          (const int64_t*)&numel,
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
          1);
      doWrite(fd, &nsize, sizeof(int64_t));
    }
  }
  // fast track for bytes and little endian
  if (element_size == 1 ||
      torch::utils::THP_nativeByteOrder() ==
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
    doWrite(fd, data, size_bytes);
  } else {
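    // Big-endian host: byte-swap to little endian in chunks of up to 5000
    // elements before writing.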
    size_t buffer_size = std::min(numel, (size_t)5000);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    std::unique_ptr<uint8_t[]> le_buffer(
        new uint8_t[buffer_size * element_size]);
    for (size_t i = 0; i < numel; i += buffer_size) {
      size_t to_convert = std::min(numel - i, buffer_size);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (element_size == 2) {
        torch::utils::THP_encodeInt16Buffer(
            (uint8_t*)le_buffer.get(),
            (const int16_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 4) {
        torch::utils::THP_encodeInt32Buffer(
            (uint8_t*)le_buffer.get(),
            (const int32_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 8) {
        torch::utils::THP_encodeInt64Buffer(
            (uint8_t*)le_buffer.get(),
            (const int64_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      }
      doWrite(fd, le_buffer.get(), to_convert * element_size);
    }
  }
}

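// Explicit instantiations for POSIX file descriptors and Python file objects.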
template void THPStorage_writeFileRaw<int>(
    c10::StorageImpl* self,
    int fd,
    bool save_size,
    uint64_t element_size);
template void THPStorage_writeFileRaw<PyObject*>(
    c10::StorageImpl* self,
    PyObject* fd,
    bool save_size,
    uint64_t element_size);

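// Reads the format written by THPStorage_writeFileRaw: an int64 element count
// followed by size * element_size bytes. If `storage` is undefined, a new CPU
// storage of the right size is allocated; otherwise its byte size must match.
// Non-CPU storages are filled through a temporary CPU buffer and copy_().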
template <class io>
c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
    io file,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size) {
  c10::OptionalDeviceGuard guard;
  if (storage.defined()) {
    guard.reset_device(storage->device());
  }
  int64_t size{};
  doRead(file, &size, sizeof(int64_t));
  if (torch::utils::THP_nativeByteOrder() ==
      torch::utils::THPByteOrder::THP_BIG_ENDIAN) {
    int64_t tsize = size; // convert little endian storage to big endian cpu
    torch::utils::THP_decodeInt64Buffer(&size, (const uint8_t*)&tsize, true, 1);
  }
  size_t nbytes = element_size * size;
  if (!storage.defined()) {
    storage = c10::make_intrusive<at::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
        nbytes,
        c10::GetDefaultCPUAllocator(),
        /*resizable=*/true);
  } else {
    size_t _storage_nbytes = storage->nbytes();
    TORCH_CHECK(
        _storage_nbytes == nbytes,
        "storage has wrong byte size: expected ",
        nbytes,
        " got ",
        _storage_nbytes);
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  std::unique_ptr<char[]> cpu_data;

  uint8_t* data{};
  if (storage->device_type() == at::kCPU) {
    data = static_cast<uint8_t*>(storage->mutable_data());
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    cpu_data = std::unique_ptr<char[]>(new char[nbytes]);
    data = (uint8_t*)cpu_data.get();
  }

  // fast track for bytes and little endian
  if (element_size == 1 ||
      torch::utils::THP_nativeByteOrder() ==
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
    doRead(file, data, storage->nbytes());
  } else {
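    // Big-endian host: read little-endian data in chunks of up to 5000
    // elements and byte-swap into the destination buffer.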
    int64_t buffer_size = std::min(size, (int64_t)5000);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    std::unique_ptr<uint8_t[]> le_buffer(
        new uint8_t[buffer_size * element_size]);

    for (int64_t i = 0; i < size; i += buffer_size) {
      size_t to_convert = std::min(size - i, buffer_size);
      doRead(file, le_buffer.get(), element_size * to_convert);

      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (element_size == 2) {
        torch::utils::THP_decodeInt16Buffer(
            (int16_t*)data + i, le_buffer.get(), true, to_convert);
      } else if (element_size == 4) {
        torch::utils::THP_decodeInt32Buffer(
            (int32_t*)data + i, le_buffer.get(), true, to_convert);
      } else if (element_size == 8) {
        torch::utils::THP_decodeInt64Buffer(
            (int64_t*)data + i, le_buffer.get(), true, to_convert);
      }
    }
  }

  if (storage->device_type() != at::kCPU) {
    // Here we use tensor.copy_() to implement the H2D copy for all non-CPU
    // devices.
    auto cpu_tensor = at::from_blob(
        (void*)data,
        {static_cast<int64_t>(nbytes)},
        at::device(at::kCPU).dtype(c10::kByte));
    auto device_tensor = at::from_blob(
        storage->mutable_data(),
        {static_cast<int64_t>(nbytes)},
        {1},
        nullptr,
        at::device(storage->device()).dtype(c10::kByte),
        {storage->device()});
    device_tensor.copy_(cpu_tensor);
  }
  return storage;
}

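// Explicit instantiations of the read path for POSIX file descriptors and
// Python file objects.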
template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<int>(
    int fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);
template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<PyObject*>(
    PyObject* fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);
