pytorch

Форк
0
/
Storage.cpp 
765 строк · 24.0 Кб
1
#include <torch/csrc/python_headers.h>
2
#ifdef _MSC_VER
3
#include <c10/util/win32-headers.h>
4
#endif
5
#include <structmember.h>
6

7
#include <ATen/mps/MPSDevice.h>
8
#include <c10/core/CPUAllocator.h>
9
#include <c10/core/RefcountedDeleter.h>
10
#include <libshm.h>
11
#include <torch/csrc/CudaIPCTypes.h>
12
#include <torch/csrc/Device.h>
13
#include <torch/csrc/DynamicTypes.h>
14
#include <torch/csrc/StorageMethods.h>
15
#include <torch/csrc/StorageSharing.h>
16
#include <torch/csrc/THP.h>
17
#include <torch/csrc/autograd/utils/wrap_outputs.h>
18
#include <torch/csrc/copy_utils.h>
19
#include <torch/csrc/utils/pyobject_preservation.h>
20
#include <torch/csrc/utils/python_arg_parser.h>
21

22
#include <c10/util/intrusive_ptr.h>
23
#include <fmt/format.h>
24

25
template <>
26
void THPPointer<c10::StorageImpl>::free() {
27
  if (ptr) {
28
    c10::raw::intrusive_ptr::decref(ptr);
29
  }
30
}
31

32
// Python class used by THPStorage_Wrap for newly wrapped storages. It stays
// nullptr until THPStorage_postInit looks up `UntypedStorage` on the module.
PyTypeObject* THPStorageClass = nullptr;

// Creates (or, when allowed, reuses) a Python object of class `type` that
// wraps `_storage`. Returns a new reference.
//
// - `type` must be StorageBase or a subclass of it.
// - If the StorageImpl already has a PyObject for this interpreter, that
//   object is returned through THPStorage_Wrap. This happens only when
//   `allow_preexisting_pyobj` is true and the existing object's type is
//   `type` or a subclass of it; otherwise an error is raised.
// - Otherwise a new object is allocated with `type->tp_alloc` and takes
//   ownership of `_storage`. Outside hermetic mode, the object is also
//   recorded in the StorageImpl's PyObjectSlot using `status`.
PyObject* THPStorage_NewWithStorage(
    PyTypeObject* type,
    c10::Storage _storage,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj) {
  TORCH_CHECK(
      PyType_IsSubtype(type, &THPStorageType),
      "Creating a Storage subclass from a class that does not inherit from ",
      "Storage is not possible. Make sure your class inherits from Storage.");

  c10::impl::PyObjectSlot* slot = _storage.unsafeGetStorageImpl()->pyobj_slot();
  auto existing =
      slot->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  PyObject* existing_obj = existing.has_value() ? *existing : nullptr;

  if (existing_obj != nullptr) {
    const char* existing_name = existing_obj->ob_type->tp_name;
    TORCH_CHECK(
        allow_preexisting_pyobj,
        "Creating a new Storage subclass ",
        type->tp_name,
        " but the raw Storage object is already associated to a python object ",
        "of type ",
        existing_name);
    PyTypeObject* existing_type = Py_TYPE(existing_obj);
    TORCH_CHECK(
        existing_type == type || PyType_IsSubtype(existing_type, type),
        "Creating a new Storage subclass ",
        type->tp_name,
        " but the raw Storage object is already associated to a python object ",
        "of type ",
        existing_name,
        " which is not a subclass of the "
        "requested type");
    return THPStorage_Wrap(std::move(_storage));
  }

  PyObject* obj = type->tp_alloc(type, 0);
  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");

  auto* self = reinterpret_cast<THPStorage*>(obj);
  // tp_alloc does not run C++ constructors, so the MaybeOwned member is
  // placement-constructed before it is assigned.
  new (&self->cdata) c10::MaybeOwned<c10::Storage>();
  self->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));

  const bool hermetic = c10::impl::HermeticPyObjectTLS::get_state();
  self->is_hermetic = hermetic;
  if (!hermetic) {
    THPStorage_Unpack(self).unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
        getPyInterpreter(), obj, status);
  }
  return obj;
}

// Wraps the c10::Storage with a storage PyObject
90
PyObject* THPStorage_Wrap(c10::Storage storage) {
91
  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
92
  if (c10::impl::HermeticPyObjectTLS::get_state()) {
93
    return THPStorage_NewWithStorage(
94
        THPStorageClass,
95
        std::move(storage),
96
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
97
  }
98
  c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
99

100
  // If the StorageImpl has a PyObject that is managed by a different
101
  // interpreter than the current one, create a new StorageImpl that points to
102
  // the same data and then create the Python storage from that.
103
  // NOTE: This is only supposed to happen in MultiPy
104
  if (pyobj_slot->has_pyobj_nonhermetic() &&
105
      !pyobj_slot->check_interpreter(getPyInterpreter())) {
106
    return THPStorage_NewWithStorage(
107
        THPStorageClass,
108
        c10::newStorageImplFromRefcountedDataPtr(storage),
109
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
110
  }
111
  c10::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
112
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
113
  c10::impl::PyInterpreterStatus status =
114
      c10::impl::PyInterpreterStatus::TAGGED_BY_US;
115
  if (maybe_pyobj.has_value()) {
116
    auto obj = *maybe_pyobj;
117
    if (obj) {
118
      TORCH_CHECK(
119
          THPStorage_Check(obj),
120
          "Expected a storage type, but got ",
121
          Py_TYPE(obj)->tp_name);
122

123
      if (pyobj_slot->owns_pyobj()) {
124
        pyobj_slot->set_owns_pyobj(false);
125
        reinterpret_cast<THPStorage*>(obj)->cdata =
126
            c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
127
        return obj;
128
      } else {
129
        Py_INCREF(obj);
130
        return obj;
131
      }
132
    }
133
    status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
134
  } else {
135
    if (storage.use_count() <= 1) {
136
      status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
137
    } else {
138
      status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
139
    }
140
  }
141
  return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
142
}
143

144
// Decides whether a storage PyObject about to be deallocated should instead
// be kept alive. All of the following must hold:
// - it owns (does not borrow) its c10::Storage;
// - it is not hermetic;
// - the StorageImpl's PyObjectSlot points back at this object;
// - other references to the storage still exist (use_count > 1).
static bool THPStorage_isPreservable(THPStorage* self) {
  if (self->cdata.unsafeIsBorrowed() || self->is_hermetic) {
    return false;
  }

  const auto& storage = THPStorage_Unpack(self);
  auto slot_obj = storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/true);
  const bool slot_points_here =
      slot_obj == c10::make_optional(reinterpret_cast<PyObject*>(self));

  return slot_points_here && storage.use_count() > 1;
}

// Called from dealloc. If the object is preservable, hands ownership of the
// PyObject to the StorageImpl's PyObjectSlot instead of freeing it:
// - the slot is marked as owning the PyObject;
// - the PyObject gets one extra Python reference;
// - the PyObject's cdata becomes a borrow of the storage.
// Returns true when the object was preserved; the caller must then stop
// deallocating.
static bool THPStorage_tryPreserve(THPStorage* self) {
  if (!THPStorage_isPreservable(self)) {
    return false;
  }

  const auto& storage = THPStorage_Unpack(self);
  c10::StorageImpl* impl = storage.unsafeGetStorageImpl();
  c10::impl::PyObjectSlot* slot = impl->pyobj_slot();

  auto maybe_pyobj =
      slot->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/true);
  // The PyObjectSlot could be set here, but it should already have been set
  // when the storage PyObject was created.
  TORCH_INTERNAL_ASSERT(
      maybe_pyobj.has_value(),
      "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");

  PyObject* pyobj = *maybe_pyobj;

  TORCH_CHECK(
      THPStorage_Check(pyobj),
      "Expected a storage type, but got ",
      Py_TYPE(pyobj)->tp_name);

  TORCH_INTERNAL_ASSERT(
      static_cast<void*>(pyobj) == static_cast<void*>(self),
      "Python storage and the PyObject in the internal PyObjectSlot are not at the same address");

  TORCH_INTERNAL_ASSERT(!slot->owns_pyobj());

  slot->set_owns_pyobj(true);
  Py_INCREF(self);

  self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
  return true;
}

// tp_dealloc installed on every Python subclass of StorageBase by
// THPStorageMetaType_init. The structure appears to mirror CPython's own
// subclass dealloc (verify when upgrading CPython).
//
// Steps:
// 1. If THPStorage_tryPreserve succeeds, the object is kept alive and nothing
//    else happens.
// 2. Otherwise runs tp_finalize / tp_del, returning early if either
//    resurrects the object.
// 3. Clears weakrefs, __slots__ and __dict__.
// 4. Destroys the C++ `cdata` member, frees the memory, and drops the
//    reference the instance held on its heap type.
static void THPStorage_subclass_dealloc(PyObject* self) {
  THPStorage* _self = (THPStorage*)self;

  // Other references to the StorageImpl still exist: keep this PyObject alive.
  if (THPStorage_tryPreserve(_self)) {
    return;
  }

  // Some subclass of StorageBase could be GC-tracked objects even
  // though the base class is not
  auto* type = Py_TYPE(self);
  if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
    PyObject_GC_UnTrack(self);
  }

  bool has_finalizer = type->tp_finalize || type->tp_del;

  if (type->tp_finalize) {
    PyObject_GC_Track(self);
    if (PyObject_CallFinalizerFromDealloc(self) < 0) {
      // The finalizer has resurrected the PyObject and there is a new Python
      // reference to it, so we can just stop deallocating. Read about
      // resurrection from `__del__` here:
      // https://docs.python.org/3/reference/datamodel.html#object.__del__
      return;
    }
    PyObject_GC_UnTrack(self);
  }

  // base test is unnecessary as THPStorage does not set this
  if (type->tp_weaklistoffset) {
    PyObject_ClearWeakRefs(self);
  }

  if (type->tp_del) {
    PyObject_GC_Track(self);
    type->tp_del(self);
    if (Py_REFCNT(self) > 0) {
      // Resurrected (see above comment about resurrection from `__del__`)
      return;
    }
    PyObject_GC_UnTrack(self);
  }

  if (has_finalizer) {
    /* New weakrefs could be created during the finalizer call.
       If this occurs, clear them out without calling their
       finalizers since they might rely on part of the object
       being finalized that has already been destroyed. */
    if (type->tp_weaklistoffset) {
      /* Modeled after GET_WEAKREFS_LISTPTR() */
      PyWeakReference** list =
          (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
      while (*list)
        _PyWeakref_ClearRef(*list);
    }
  }

  // Clear slots of every Python-level class between `type` and StorageBase.
  {
    PyTypeObject* base = type;
    while (base != &THPStorageType) {
      if (Py_SIZE(base)) {
        clear_slots(base, self);
      }
      base = base->tp_base;
      TORCH_INTERNAL_ASSERT(base);
    }
  }

  // Clear __dict__
  if (C10_LIKELY(type->tp_dictoffset)) {
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr != nullptr) {
      PyObject* dict = *dictptr;
      if (dict != nullptr) {
        Py_DECREF(dict);
        *dictptr = nullptr;
      }
    }
  }

  TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);

  // cdata was placement-constructed, so it must be destroyed explicitly
  // before the memory is released.
  _self->cdata.~MaybeOwned<c10::Storage>();
  Py_TYPE(_self)->tp_free(self);

  // Instances of heap types hold a reference to their type; release it.
  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  Py_DECREF(type);
}

// Builds the StorageImpl that backs a new Python storage.
//
// The incoming value of `allocator` is always overwritten:
// - `allocator_opt` set: its integer value is reinterpreted as a
//   c10::Allocator*.
// - `device_opt` set: the allocator for that device type is used, and a
//   device guard on that device is active for the rest of this call. If a
//   custom StorageImpl constructor is registered for the device type, it is
//   used instead of constructing a plain c10::StorageImpl.
// - neither set: the default CPU allocator is used.
// Unknown device types raise an error.
c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
    c10::StorageImpl::use_byte_size_t use_byte_size,
    c10::SymInt size_bytes,
    c10::Allocator* allocator,
    bool resizable,
    c10::optional<int64_t> allocator_opt,
    c10::optional<at::Device> device_opt) {
  at::OptionalDeviceGuard device_guard;
  // Non-null only when a custom StorageImpl constructor exists for the device.
  // A directly passed allocator can only produce a plain c10::StorageImpl;
  // creating a StorageImpl subclass requires passing a device.
  c10::StorageImplCreateHelper custom_create = nullptr;

  if (allocator_opt.has_value()) {
    // NOLINTNEXTLINE(performance-no-int-to-ptr)
    allocator = reinterpret_cast<c10::Allocator*>(*allocator_opt);
  } else if (device_opt.has_value()) {
    const at::Device device = *device_opt;
    // Only this branch can yield a non-CPU device, so only here can the
    // StorageImpl constructor be overridden.
    custom_create = c10::GetStorageImplCreate(device.type());
    switch (device.type()) {
      case at::DeviceType::CPU:
        allocator = c10::GetDefaultCPUAllocator();
        break;
#ifdef USE_CUDA
      case at::DeviceType::CUDA:
        at::globalContext().lazyInitCUDA();
        allocator = c10::cuda::CUDACachingAllocator::get();
        break;
#endif
#ifdef USE_MPS
      case at::DeviceType::MPS:
        allocator = at::mps::GetMPSAllocator();
        break;
#endif
      case at::DeviceType::XPU:
      case at::DeviceType::HPU:
      case at::DeviceType::Meta:
        allocator = c10::GetAllocator(device.type());
        break;
      case at::DeviceType::PrivateUse1:
        at::globalContext().lazyInitPrivateUse1();
        allocator = c10::GetAllocator(device.type());
        break;
      default:
        TORCH_CHECK(
            false,
            THPStorageStr,
            "(): Storage device not recognized: ",
            device.type());
    }
    device_guard.reset_device(device);
  } else {
    allocator = c10::GetDefaultCPUAllocator();
  }

  if (custom_create == nullptr) {
    return c10::make_intrusive<c10::StorageImpl>(
        use_byte_size, std::move(size_bytes), allocator, resizable);
  }
  return custom_create(use_byte_size, std::move(size_bytes), allocator, resizable);
}

static PyObject* THPStorage_pynew(
360
    PyTypeObject* type,
361
    PyObject* args,
362
    PyObject* kwargs) {
363
  HANDLE_TH_ERRORS
364
  TORCH_CHECK(
365
      type != &THPStorageType,
366
      "Cannot directly construct StorageBase; subclass it and then construct that");
367
  static torch::PythonArgParser parser({
368
      THPStorageStr "(*, int64_t allocator=None, Device device=None)",
369
      THPStorageStr
370
      "(int64_t size, *, int64_t allocator=None, Device device=None)",
371
      THPStorageStr
372
      "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
373
  });
374
  torch::ParsedArgs<3> parsed_args;
375
  auto r = parser.parse(args, kwargs, parsed_args);
376

377
  int allocator_arg_idx = 0;
378
  int device_arg_idx = 1;
379

380
  if (r.idx > 0) {
381
    allocator_arg_idx = 1;
382
    device_arg_idx = 2;
383
  }
384

385
  c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
386
  c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
387

388
  TORCH_CHECK(
389
      !allocator_opt.has_value() || !device_opt.has_value(),
390
      THPStorageStr,
391
      "(): only one or neither of 'allocator' or 'device' can ",
392
      "be given, but not both");
393

394
  PyObject* self = nullptr;
395
  c10::Allocator* allocator = nullptr;
396

397
  // torch.Storage(*, ...)
398
  if (r.idx == 0) {
399
    self = THPStorage_NewWithStorage(
400
        type,
401
        make_storage_impl(
402
            c10::StorageImpl::use_byte_size_t(),
403
            0,
404
            allocator,
405
            /*resizable=*/true,
406
            allocator_opt,
407
            device_opt),
408
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
409

410
    // torch.Storage(size, *, ...)
411
  } else if (r.idx == 1) {
412
    int64_t size = r.toInt64(0);
413
    self = THPStorage_NewWithStorage(
414
        type,
415
        make_storage_impl(
416
            c10::StorageImpl::use_byte_size_t(),
417
            size,
418
            allocator,
419
            /*resizable=*/true,
420
            allocator_opt,
421
            device_opt),
422
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
423

424
    // torch.Storage(sequence, *, ...)
425
  } else if (r.idx == 2) {
426
    PyObject* sequence = r.pyobject(0);
427
    Py_ssize_t length = PySequence_Length(sequence);
428
    TORCH_CHECK(
429
        PySequence_Check(sequence),
430
        THPStorageStr,
431
        "(): Expected a sequence type, but got ",
432
        THPUtils_typename(sequence));
433
    TORCH_CHECK(
434
        length >= 0,
435
        THPStorageStr,
436
        "(): Could not obtain the length of sequence of type ",
437
        THPUtils_typename(sequence));
438
    self = THPStorage_NewWithStorage(
439
        type,
440
        make_storage_impl(
441
            c10::StorageImpl::use_byte_size_t(),
442
            length,
443
            allocator,
444
            /*resizable=*/true,
445
            allocator_opt,
446
            device_opt),
447
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
448
    THPObjectPtr item;
449
    try {
450
      const auto& storage = THPStorage_Unpack(self);
451
      for (Py_ssize_t i = 0; i < length; i++) {
452
        item = PySequence_GetItem(sequence, i);
453
        uint8_t value = THPByteUtils_unpackReal(item.get());
454
        if (allocator == c10::GetDefaultCPUAllocator()) {
455
          static_cast<uint8_t*>(storage.mutable_data())[i] = value;
456
        } else {
457
          // TODO: this might be slow - consider batched updates?
458
          storage_set(storage, i, value);
459
        }
460
      }
461
    } catch (const std::exception& e) {
462
      TORCH_CHECK(
463
          THPStorageStr "(): tried to construct a storage from a sequence (",
464
          THPUtils_typename(sequence),
465
          "), ",
466
          "but one of the items was of type ",
467
          THPUtils_typename(item.get()),
468
          " instead of int");
469
      return nullptr;
470
    }
471
  }
472
  return self;
473
  Py_RETURN_NONE;
474
  END_HANDLE_TH_ERRORS
475
}
476

477
// mp_length slot: returns the storage size in bytes, or -1 with a Python
// error set on failure.
static Py_ssize_t THPStorage_length(THPStorage* self) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto nbytes = THPStorage_Unpack(self).nbytes();
  return static_cast<Py_ssize_t>(nbytes);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// mp_subscript slot for StorageBase.
// - Integer index: returns the byte at that position as a Python int.
//   Negative indices count from the end. Out-of-range indices raise
//   IndexError.
// - Slice with step 1: returns a new, non-resizable storage of the same
//   Python type that aliases the sliced bytes. It keeps the parent
//   StorageImpl alive through an extra intrusive refcount that its DataPtr
//   deleter releases. Any other step raises an error.
// - Any other index type raises TypeError.
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  const int64_t len = static_cast<int64_t>(storage.nbytes());

  /* Integer index */
  if (THPUtils_checkLong(index)) {
    const int64_t requested = THPUtils_unpackLong(index);
    // Python-style wrap-around for negative indices.
    const int64_t nindex = requested < 0 ? requested + len : requested;
    if (nindex < 0 || nindex >= len) {
      // Report the index the user passed, not the wrapped one.
      PyErr_SetString(
          PyExc_IndexError,
          fmt::format(
              "index {} out of range for storage of size {}", requested, len)
              .c_str());
      return nullptr;
    }
    uint8_t value = storage_get(storage, nindex);
    return THPByteUtils_newReal(value);
  }

  /* Slice index */
  if (PySlice_Check(index)) {
    Py_ssize_t start = 0;
    Py_ssize_t stop = 0;
    Py_ssize_t step = 0;
    if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
      return nullptr;
    }
    const Py_ssize_t slicelength =
        PySlice_AdjustIndices(len, &start, &stop, step);
    TORCH_CHECK(
        step == 1,
        "Trying to slice with a step of ",
        step,
        ", but only a step of "
        "1 is supported");

    auto data = static_cast<uint8_t*>(storage.mutable_data());

    at::StorageImpl* old_storage_impl = storage.unsafeGetStorageImpl();
    // The slice aliases the parent's memory. Take a reference on the parent
    // StorageImpl; the DataPtr deleter below gives it back.
    c10::raw::intrusive_ptr::incref(old_storage_impl);
    auto new_storage_impl = c10::make_intrusive<at::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
#ifdef THQUANTIZED
        slicelength * sizeof(quantized_t),
#else
        slicelength,
#endif
        at::DataPtr(
            static_cast<void*>(data + start),
            old_storage_impl,
            [](void* s) {
              c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
            },
            old_storage_impl->device()),
        old_storage_impl->allocator(),
        /* resizable */ false);

    return THPStorage_NewWithStorage(
        Py_TYPE(self),
        std::move(new_storage_impl),
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  }

  PyErr_Format(
      PyExc_TypeError,
      "can't index a " THPStorageStr " with %s",
      THPUtils_typename(index));
  return nullptr;
  END_HANDLE_TH_ERRORS
}

// mp_ass_subscript slot for StorageBase. Returns 0 on success, or -1 with a
// Python error set.
// - `value` must be an int-like accepted by THPByteUtils_checkReal. A NULL
//   `value` (item deletion, `del s[i]`) is rejected.
// - Integer index: writes one byte. Negative indices count from the end, as
//   in THPStorage_get; bounds are checked by storage_set.
// - Slice with step 1: writes `value` to every byte in the slice. Any other
//   step raises an error.
// - Any other index type raises an error.
static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  // CPython passes value == NULL for `del storage[index]`.
  TORCH_CHECK(
      value != nullptr, "can't delete items of a " THPStorageStr " object");
  TORCH_CHECK(
      THPByteUtils_checkReal(value),
      "can only set storage content with a int types, but got ",
      THPUtils_typename(value),
      " instead");

  const uint8_t rvalue = THPByteUtils_unpackReal(value);
  const auto& storage = THPStorage_Unpack(self);
  const Py_ssize_t len = static_cast<Py_ssize_t>(storage.nbytes());

  if (THPUtils_checkLong(index)) {
    int64_t nindex = THPUtils_unpackLong(index);
    if (nindex < 0) {
      nindex += len;
    }
    storage_set(storage, nindex, rvalue);
    return 0;
  }

  if (PySlice_Check(index)) {
    Py_ssize_t start = 0;
    Py_ssize_t stop = 0;
    Py_ssize_t step = 0;
    if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
      return -1;
    }
    PySlice_AdjustIndices(len, &start, &stop, step);
    TORCH_CHECK(
        step == 1,
        "Trying to slice with a step of ",
        step,
        ", but only a step of "
        "1 is supported");
    // TODO: check the bounds only once
    // TODO: fill?
    for (Py_ssize_t i = start; i < stop; ++i) {
      storage_set(storage, i, rvalue);
    }
    return 0;
  }

  TORCH_CHECK(
      false, "can't index a " THPStorageStr " with ", THPUtils_typename(index));
  return -1;
  END_HANDLE_TH_ERRORS_RET(-1)
}

// Mapping protocol of StorageBase: len(s), s[key] and s[key] = value.
static PyMappingMethods THPStorage_mappingmethods = {
    /*mp_length=*/reinterpret_cast<lenfunc>(THPStorage_length),
    /*mp_subscript=*/reinterpret_cast<binaryfunc>(THPStorage_get),
    /*mp_ass_subscript=*/reinterpret_cast<objobjargproc>(THPStorage_set),
};

struct THPStorageMeta {
608
  PyHeapTypeObject base;
609
};
610

611
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
612

613
// Metaclass of StorageBase, exposed as `torch._C._StorageMeta`.
// - Its only customization is tp_init (THPStorageMetaType_init).
// - tp_base is replaced with &PyType_Type in THPStorage_init before
//   PyType_Ready runs.
PyTypeObject THPStorageMetaType = {
    PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
    /*tp_name=*/"torch._C._StorageMeta",
    /*tp_basicsize=*/sizeof(THPStorageMeta),
    /*tp_itemsize=*/0,
    /*tp_dealloc=*/nullptr,
    /*tp_vectorcall_offset=*/0,
    /*tp_getattr=*/nullptr,
    /*tp_setattr=*/nullptr,
    /*tp_reserved=*/nullptr,
    /*tp_repr=*/nullptr,
    /*tp_as_number=*/nullptr,
    /*tp_as_sequence=*/nullptr,
    /*tp_as_mapping=*/nullptr,
    /*tp_hash=*/nullptr,
    /*tp_call=*/nullptr,
    /*tp_str=*/nullptr,
    /*tp_getattro=*/nullptr,
    /*tp_setattro=*/nullptr,
    /*tp_as_buffer=*/nullptr,
    // NOLINTNEXTLINE(misc-redundant-expression)
    /*tp_flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    /*tp_doc=*/nullptr,
    /*tp_traverse=*/nullptr,
    /*tp_clear=*/nullptr,
    /*tp_richcompare=*/nullptr,
    /*tp_weaklistoffset=*/0,
    /*tp_iter=*/nullptr,
    /*tp_iternext=*/nullptr,
    /*tp_methods=*/nullptr,
    /*tp_members=*/nullptr,
    /*tp_getset=*/nullptr,
    /*tp_base=*/DEFERRED_ADDRESS(&PyType_Type),
    /*tp_dict=*/nullptr,
    /*tp_descr_get=*/nullptr,
    /*tp_descr_set=*/nullptr,
    /*tp_dictoffset=*/0,
    /*tp_init=*/THPStorageMetaType_init,
    /*tp_alloc=*/nullptr,
    /*tp_new=*/nullptr,
};

// TODO: implement equality
657
PyTypeObject THPStorageType = {
658
    PyVarObject_HEAD_INIT(
659
        &THPStorageMetaType,
660
        0) "torch._C.StorageBase", /* tp_name */
661
    sizeof(THPStorage), /* tp_basicsize */
662
    0, /* tp_itemsize */
663
    nullptr, /* tp_dealloc */
664
    0, /* tp_vectorcall_offset */
665
    nullptr, /* tp_getattr */
666
    nullptr, /* tp_setattr */
667
    nullptr, /* tp_reserved */
668
    nullptr, /* tp_repr */
669
    nullptr, /* tp_as_number */
670
    nullptr, /* tp_as_sequence */
671
    &THPStorage_mappingmethods, /* tp_as_mapping */
672
    nullptr, /* tp_hash  */
673
    nullptr, /* tp_call */
674
    nullptr, /* tp_str */
675
    nullptr, /* tp_getattro */
676
    nullptr, /* tp_setattro */
677
    nullptr, /* tp_as_buffer */
678
    // NOLINTNEXTLINE(misc-redundant-expression)
679
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
680
    nullptr, /* tp_doc */
681
    nullptr, /* tp_traverse */
682
    nullptr, /* tp_clear */
683
    nullptr, /* tp_richcompare */
684
    0, /* tp_weaklistoffset */
685
    nullptr, /* tp_iter */
686
    nullptr, /* tp_iternext */
687
    nullptr,
688
    /* will be assigned in init */ /* tp_methods */
689
    nullptr,
690
    /* will be assigned in init */ /* tp_members */
691
    nullptr, /* tp_getset */
692
    nullptr, /* tp_base */
693
    nullptr, /* tp_dict */
694
    nullptr, /* tp_descr_get */
695
    nullptr, /* tp_descr_set */
696
    0, /* tp_dictoffset */
697
    nullptr, /* tp_init */
698
    nullptr, /* tp_alloc */
699
    THPStorage_pynew, /* tp_new */
700
};
701

702
// tp_init of the StorageBase metaclass. Runs the regular type initializer,
// then installs THPStorage_subclass_dealloc as the new class's tp_dealloc.
// Returns 0 on success and -1 (with a Python error set) on failure.
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  const int rc = PyType_Type.tp_init(cls, args, kwargs);
  if (rc < 0) {
    return -1;
  }
  auto* new_type = reinterpret_cast<PyTypeObject*>(cls);
  new_type->tp_dealloc = THPStorage_subclass_dealloc;
  return 0;
}

// Getter for `StorageBase.device`: returns the storage's device as a
// torch.device.
static PyObject* THPStorage_device(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  return THPDevice_New(storage.device());
  END_HANDLE_TH_ERRORS
}

// Getter for `StorageBase._cdata`: returns the address of the underlying
// StorageImpl as a Python int.
PyObject* THPStorage_get_cdata(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
  c10::StorageImpl* impl = THPStorage_Unpack(self).unsafeGetStorageImpl();
  return PyLong_FromVoidPtr(impl);
  END_HANDLE_TH_ERRORS
}

typedef PyObject* (*getter)(PyObject*, void*);
724

725
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
726
static struct PyGetSetDef THPStorage_properties[] = {
727
    {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
728
    {"_cdata", (getter)THPStorage_get_cdata, nullptr, nullptr, nullptr},
729
    {nullptr}};
730

731
// Readies the StorageBase metaclass and StorageBase and registers them on
// `module` as `_StorageMeta` and `StorageBase`. StorageBase receives the
// storage and sharing methods plus the `device` / `_cdata` properties.
// Returns false (with a Python error set) if readying or registering either
// type fails.
bool THPStorage_init(PyObject* module) {
  // Static because tp_methods keeps pointing into this vector's buffer.
  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
  THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());

  THPStorageMetaType.tp_base = &PyType_Type;
  if (PyType_Ready(&THPStorageMetaType) < 0) {
    return false;
  }
  // PyModule_AddObject steals the reference only on success, so undo the
  // INCREF ourselves on failure.
  Py_INCREF(&THPStorageMetaType);
  if (PyModule_AddObject(
          module, "_StorageMeta", (PyObject*)&THPStorageMetaType) < 0) {
    Py_DECREF(&THPStorageMetaType);
    return false;
  }

  THPStorageType.tp_methods = methods.data();
  THPStorageType.tp_getset = THPStorage_properties;
  if (PyType_Ready(&THPStorageType) < 0) {
    return false;
  }
  Py_INCREF(&THPStorageType);
  if (PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType) <
      0) {
    Py_DECREF(&THPStorageType);
    return false;
  }
  return true;
}

// Looks up `UntypedStorage` on `module` and stores it in THPStorageClass.
// Throws python_error if the attribute lookup fails.
void THPStorage_postInit(PyObject* module) {
  PyObject* untyped = PyObject_GetAttrString(module, "UntypedStorage");
  THPStorageClass = reinterpret_cast<PyTypeObject*>(untyped);
  if (THPStorageClass == nullptr) {
    throw python_error();
  }
}

// Raises "Got a null Storage" if `storage` does not wrap a StorageImpl.
void THPStorage_assertNotNull(THPStorage* storage) {
  const c10::StorageImpl* impl =
      THPStorage_Unpack(storage).unsafeGetStorageImpl();
  TORCH_CHECK(impl != nullptr, "Got a null Storage");
}

// PyObject* overload: treats `obj` as a THPStorage and forwards to the typed
// check.
void THPStorage_assertNotNull(PyObject* obj) {
  auto* storage = reinterpret_cast<THPStorage*>(obj);
  THPStorage_assertNotNull(storage);
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.