pytorch

Форк
0
/
python_tensor.cpp 
466 строк · 15.5 Кб
1
#include <torch/csrc/tensor/python_tensor.h>
2

3
#include <pybind11/pybind11.h>
4
#include <structmember.h>
5
#include <torch/csrc/utils/pybind.h>
6

7
#include <torch/csrc/Dtype.h>
8
#include <torch/csrc/DynamicTypes.h>
9
#include <torch/csrc/Exceptions.h>
10
#include <torch/csrc/Layout.h>
11
#include <torch/csrc/autograd/generated/VariableType.h>
12
#include <torch/csrc/autograd/python_variable.h>
13
#include <torch/csrc/autograd/utils/wrap_outputs.h>
14
#include <torch/csrc/autograd/variable.h>
15
#include <torch/csrc/utils/cuda_enabled.h>
16
#include <torch/csrc/utils/device_lazy_init.h>
17
#include <torch/csrc/utils/python_strings.h>
18
#include <torch/csrc/utils/tensor_new.h>
19
#include <torch/csrc/utils/tensor_types.h>
20

21
#include <ATen/ATen.h>
22

23
#include <algorithm>
#include <sstream>
24
#include <string>
25
#include <type_traits>
26
#include <vector>
27

28
namespace torch::tensors {
29

30
using namespace at;
31
using namespace torch::autograd;
32

33
struct PyTensorType {
34
  PyTypeObject py_type;
35
  THPDtype* dtype;
36
  THPLayout* layout;
37
  bool is_cuda;
38
  bool is_xpu;
39
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
40
  char name[64];
41
  int backend;
42
  int scalar_type;
43

44
  Backend get_backend() const {
45
    return static_cast<Backend>(backend);
46
  }
47

48
  DispatchKey get_dispatch_key() const {
49
    return backendToDispatchKey(static_cast<Backend>(backend));
50
  }
51

52
  ScalarType get_scalar_type() const {
53
    return static_cast<ScalarType>(scalar_type);
54
  }
55
};
56

57
static_assert(
58
    std::is_standard_layout_v<PyTensorType>,
59
    "PyTensorType must be standard layout");
60

61
static Backend default_backend = Backend::CPU;
62

63
static void py_bind_tensor_types(
64
    const std::vector<PyTensorType*>& tensor_types);
65

66
// tp_new slot shared by all legacy type objects, e.g. torch.FloatTensor(...).
// `type` is the legacy type object being called; its stored dispatch key and
// scalar type are forwarded, together with the Python args/kwargs, to
// torch::utils::legacy_tensor_ctor, and the result is wrapped as a Python
// tensor. CUDA types are rejected when cuda_enabled() is false, and emit a
// one-time deprecation warning otherwise.
static PyObject* Tensor_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto& legacy_type = *reinterpret_cast<PyTensorType*>(type);
  const bool wants_cuda = legacy_type.is_cuda;

  TORCH_CHECK_TYPE(
      !wants_cuda || torch::utils::cuda_enabled(),
      "type ",
      legacy_type.name,
      " not available. Torch not compiled with CUDA enabled.");
  if (wants_cuda) {
    TORCH_WARN_ONCE(
        "The torch.cuda.*DtypeTensor constructors are no longer recommended. "
        "It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors.");
  }

  const auto dispatch_key = legacy_type.get_dispatch_key();
  const auto scalar_type = legacy_type.get_scalar_type();
  return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(
      dispatch_key, scalar_type, args, kwargs));
  END_HANDLE_TH_ERRORS
}
89

90
// TODO: Deprecate this instancecheck entirely.  It's here to make
91
// instanceof(t, torch.FloatTensor) work, but we are not going to keep
92
// adding torch.QuantizedIntTensor classes for every new tensor type
93
// we add...
94
static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) {
95
  HANDLE_TH_ERRORS
96
  auto self = (PyTensorType*)_self;
97
  if (THPVariable_Check(arg)) {
98
    const auto& var = THPVariable_Unpack(arg);
99
    // NB: This is a little unfortunate, in that if I do an isinstance check
100
    // against torch.cuda.FloatTensor, this will immediately initialize CUDA.
101
    // I originally thought that it would not be possible for aten_type_ to
102
    // be nullptr if you had a tensor of some type, in which case you can
103
    // skip initializing aten_type(), but TestAutograd.test_type_conversions
104
    // seems to violate this property (for whatever reason.)
105
    //
106
    // TODO: Stop using legacyExtractDispatchKey here (probably need to build
107
    // in instanceof checking to Tensor class itself)
108
    if (legacyExtractDispatchKey(var.key_set()) == self->get_dispatch_key() &&
109
        var.scalar_type() == static_cast<ScalarType>(self->scalar_type)) {
110
      Py_RETURN_TRUE;
111
    }
112
  }
113
  Py_RETURN_FALSE;
114
  END_HANDLE_TH_ERRORS
115
}
116

117
// Property getters installed on the metaclass (see metaclass_properties), so
// they are read as attributes of the legacy type objects themselves, e.g.
// torch.FloatTensor.dtype. `self` is the legacy type object viewed as a
// PyTensorType; all values come from the fields filled in by set_type().

// torch.<T>Tensor.dtype: the THPDtype stored for this type.
static PyObject* Tensor_dtype(PyTensorType* self, void* unused) {
  return torch::autograd::utils::wrap(self->dtype);
}

// torch.<T>Tensor.layout: the THPLayout stored for this type.
static PyObject* Tensor_layout(PyTensorType* self, void* unused) {
  return torch::autograd::utils::wrap(self->layout);
}

// torch.<T>Tensor.is_cuda
static PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) {
  return PyBool_FromLong(self->is_cuda ? 1 : 0);
}

// torch.<T>Tensor.is_xpu
static PyObject* Tensor_is_xpu(PyTensorType* self, void* unused) {
  return PyBool_FromLong(self->is_xpu ? 1 : 0);
}

// torch.<T>Tensor.is_sparse: True only for the sparse COO layout, matching
// torch.Tensor.is_sparse. (Previously any non-strided layout — including
// SparseCsr and Mkldnn — reported True, which could make is_sparse and
// is_sparse_csr true at the same time.)
static PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) {
  return PyBool_FromLong(self->layout->layout == at::Layout::Sparse ? 1 : 0);
}

// torch.<T>Tensor.is_sparse_csr
static PyObject* Tensor_is_sparse_csr(PyTensorType* self, void* unused) {
  return PyBool_FromLong(
      self->layout->layout == at::Layout::SparseCsr ? 1 : 0);
}
156

157
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
158
static struct PyMethodDef metaclass_methods[] = {
159
    {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr},
160
    {nullptr}};
161

162
typedef PyObject* (*getter)(PyObject*, void*);
163

164
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
165
static struct PyGetSetDef metaclass_properties[] = {
166
    {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr},
167
    {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr},
168
    {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr},
169
    {"is_xpu", (getter)Tensor_is_xpu, nullptr, nullptr, nullptr},
170
    {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr},
171
    {"is_sparse_csr", (getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr},
172
    {nullptr}};
173

174
static PyTypeObject metaclass = {
175
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.tensortype", /* tp_name */
176
    sizeof(PyTypeObject) /* tp_basicsize */
177
};
178

179
static void py_initialize_metaclass(PyTypeObject& metaclass) {
180
  // NOLINTNEXTLINE(misc-redundant-expression)
181
  metaclass.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
182
  metaclass.tp_methods = metaclass_methods;
183
  metaclass.tp_getset = metaclass_properties;
184
  metaclass.tp_base = &PyType_Type;
185
  if (PyType_Ready(&metaclass) < 0) {
186
    throw python_error();
187
  }
188
}
189

190
// Template copied into every PyTensorType::py_type by
// py_initialize_tensor_type(). Its type is the metaclass above, which gives
// each legacy type object __instancecheck__ and the dtype/layout/is_*
// properties. tp_name is left null here and filled in per type.
static PyTypeObject tensor_type_prototype = {
    PyVarObject_HEAD_INIT(&metaclass, 0)
    nullptr, /* tp_name */
    sizeof(PyTensorType), /* tp_basicsize */
};
194

195
static void py_initialize_tensor_type(
196
    PyTypeObject& type,
197
    const char* name,
198
    PyObject* tp_dict) {
199
  // NOTE: we don't use the typical static declaration of PyTypeObject because
200
  // we need to initialize as many types as there are VariableType instances.
201
  // We copy the basic object fields from a prototype definition and initialize
202
  // the remaining fields below.
203
  memcpy(&type, &tensor_type_prototype, sizeof(PyTypeObject));
204
  // Subclassing from torch.<ScalarType>Tensor isn't supported.
205
  // (Py_TPFLAGS_BASETYPE omitted). Subclassing torch.Tensor still allowed.
206
  type.tp_flags = Py_TPFLAGS_DEFAULT;
207
  type.tp_name = name;
208
  type.tp_new = Tensor_new;
209
  if (PyType_Ready(&type) < 0) {
210
    throw python_error();
211
  }
212
  if (PyDict_Merge(type.tp_dict, tp_dict, 0) < 0) {
213
    throw python_error();
214
  }
215
}
216

217
// Builds the fully qualified name of a legacy type:
// "<backend module>.<ScalarType>Tensor", e.g. the module string returned by
// backend_to_string() followed by ".FloatTensor".
static std::string get_name(Backend backend, ScalarType scalarType) {
  std::string qualified = torch::utils::backend_to_string(backend);
  qualified += '.';
  qualified += toString(scalarType);
  qualified += "Tensor";
  return qualified;
}
223

224
// Looks up the Python storage class "<ScalarType>Storage" in the module named
// by backend_to_string(backend) and returns a new reference to it.
// Throws python_error if the module cannot be imported, and a TypeError if
// the attribute is missing.
static THPObjectPtr get_storage_obj(Backend backend, ScalarType dtype) {
  THPObjectPtr module_obj(
      PyImport_ImportModule(torch::utils::backend_to_string(backend)));
  if (!module_obj) {
    throw python_error();
  }

  const std::string storage_name = std::string(toString(dtype)) + "Storage";
  THPObjectPtr storage(
      PyObject_GetAttrString(module_obj.get(), storage_name.c_str()));
  TORCH_CHECK_TYPE(
      storage.get(), "couldn't find storage object ", storage_name);
  return storage;
}
237

238
static void set_type(
239
    PyTensorType& type_obj,
240
    Backend backend,
241
    ScalarType scalarType) {
242
  // This field is lazily initialized from backend and scalar_type
243
  type_obj.backend = static_cast<int>(backend);
244
  type_obj.scalar_type = static_cast<int>(scalarType);
245
  type_obj.layout = torch::getTHPLayout(layout_from_backend(backend));
246
  type_obj.dtype = torch::getTHPDtype(scalarType);
247
  type_obj.is_cuda =
248
      (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
249
  type_obj.is_xpu =
250
      (backend == at::Backend::XPU || backend == at::Backend::SparseXPU);
251
}
252

253
// Copies `name` into the fixed-size PyTensorType::name buffer, truncating
// when it does not fit. The buffer is always NUL-terminated afterwards.
static void set_name(PyTensorType& type_obj, const std::string& name) {
  char* const dst = type_obj.name;
  const size_t capacity = sizeof(type_obj.name);
  strncpy(dst, name.c_str(), capacity);
  // strncpy does not terminate the destination when the source is too long.
  dst[capacity - 1] = '\0';
}
258

259
// Returns a new dict holding the attributes of torch.Tensor and of its direct
// base class. Used to copy methods onto each legacy type object, so that e.g.
// torch.FloatTensor.add resolves.
//
// The dicts are merged without overriding. torch.Tensor's own attributes are
// merged first, so they win over same-named attributes of the base class.
// Throws python_error on CPython failures. Throws TypeError if torch.Tensor
// is not a type object, and a c10::Error if it has no base type.
static THPObjectPtr get_tensor_dict() {
  // Named torch_module (not `torch`) so the local does not shadow the
  // torch namespace inside this function.
  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!torch_module)
    throw python_error();

  auto tensor_class =
      THPObjectPtr(PyObject_GetAttrString(torch_module.get(), "Tensor"));
  if (!tensor_class)
    throw python_error();

  // The object is reinterpreted as a PyTypeObject below. Verify that it
  // really is one before reading tp_base/tp_dict; otherwise those reads are
  // undefined behavior.
  TORCH_CHECK_TYPE(
      PyType_Check(tensor_class.get()), "torch.Tensor is not a type object");
  auto tensor_type = (PyTypeObject*)tensor_class.get();
  TORCH_CHECK(tensor_type->tp_base, "missing base type for Tensor");

  auto res = THPObjectPtr(PyDict_New());
  if (!res)
    throw python_error();

  if (PyDict_Merge(res.get(), tensor_type->tp_dict, /*override=*/0) < 0) {
    throw python_error();
  }
  if (PyDict_Merge(res.get(), tensor_type->tp_base->tp_dict, /*override=*/0) <
      0) {
    throw python_error();
  }

  return res;
}
284

285
// A note about the lifetime of the various PyTensorType: normally
286
// PyTypeObject instances are statically allocated, but we want to create them
287
// dynamically at init time, because their exact number depends on
288
// torch::utils::all_declared_types(). The memory for each PyTensorType is
289
// allocated by initialize_aten_types() and never freed: technically it's a
290
// leak, but it's not a problem since we want them to be alive for the whole
291
// time of the process anyway.
292
//
293
// An alternative is to use a std::vector<PyTensorType> instead, and let
294
// std::vector to manage the lifetime of its items. This is problematic
295
// though, because it means that the memory of PyTensorType is deallocated at
296
// some point during the exit: if by chance we have another global destructor
297
// and/or atexit() function which tries to access the PyTensorTypes, we risk
298
// an use-after-free error. This happens for example if we embed CPython and
299
// call Py_Finalize inside an atexit() function which was registered before
300
// importing torch.
301
static std::vector<PyTensorType*> tensor_types;
302

303
// Points torch.Storage at the "<ScalarType>Storage" class of the given
// backend's module. Throws python_error (or the TypeError raised by
// get_storage_obj) on failure.
static void set_default_storage_type(Backend backend, ScalarType dtype) {
  THPObjectPtr storage = get_storage_obj(backend, dtype);

  THPObjectPtr torch_module(PyImport_ImportModule("torch"));
  if (!torch_module) {
    throw python_error();
  }

  const int set_status =
      PyObject_SetAttrString(torch_module.get(), "Storage", storage.get());
  if (set_status != 0) {
    throw python_error();
  }
}
314

315
static void set_default_tensor_type(
316
    c10::optional<Backend> backend,
317
    c10::optional<ScalarType> dtype) {
318
  if (backend.has_value()) {
319
    TORCH_CHECK_TYPE(
320
        *backend != Backend::Undefined, "default type cannot be undefined");
321
    TORCH_CHECK_TYPE(
322
        !isSparse(*backend),
323
        "only dense types are supported as the default type");
324
  }
325
  if (dtype.has_value()) {
326
    TORCH_CHECK_TYPE(
327
        at::isFloatingType(*dtype),
328
        "only floating-point types are supported as the default type");
329
  }
330

331
  // Try setting default storage in python first as it's the only operation that
332
  // can fail
333
  set_default_storage_type(
334
      backend.value_or(default_backend),
335
      dtype.value_or(at::get_default_dtype_as_scalartype()));
336

337
  if (dtype.has_value()) {
338
    at::set_default_dtype(scalarTypeToTypeMeta(*dtype));
339
  }
340
  if (backend.has_value()) {
341
    default_backend = *backend;
342
  }
343
}
344

345
// Allocates one PyTensorType per (backend, scalar type) pair returned by
// all_declared_types() and fills in its type info and name. The objects are
// intentionally never freed (see the note on tensor_types). Finally, the
// default tensor type is reset to CPU / Float.
static void initialize_aten_types(std::vector<PyTensorType*>& tensor_types) {
  // Includes CUDA types even when PyTorch is not built with CUDA.
  auto declared_types = torch::utils::all_declared_types();
  tensor_types.resize(declared_types.size());

  size_t slot = 0;
  for (const auto& [backend, scalar_type] : declared_types) {
    PyTensorType* entry = new PyTensorType();
    tensor_types[slot++] = entry;
    set_type(*entry, backend, scalar_type);
    set_name(*entry, get_name(backend, scalar_type));
  }

  set_default_tensor_type(Backend::CPU, ScalarType::Float);
}
361

362
void initialize_python_bindings() {
363
  // Initialize the at::Type* pointers, name, and properties of the PyTensorType
364
  // vector. After this call, the vector must not be resized.
365
  initialize_aten_types(tensor_types);
366

367
  // Initialize the Python metaclass for the torch.FloatTensor, etc. types.
368
  // The metaclass handles __instancecheck__ checks and binds the dtype property
369
  // on the type objects.
370
  py_initialize_metaclass(metaclass);
371

372
  // Get the tp_dict of the Variable class. We copy function definitions
373
  // onto each Tensor type object so that they can be accessed via e.g.
374
  // `torch.FloatTensor.add`.
375
  auto tensor_dict = get_tensor_dict();
376

377
  // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor,
378
  // etc.
379
  for (auto& tensor_type : tensor_types) {
380
    py_initialize_tensor_type(
381
        tensor_type->py_type, tensor_type->name, tensor_dict.get());
382
  }
383

384
  // Add the type objects to their corresponding modules. e.g. torch.FloatTensor
385
  // is added to the `torch` module as `FloatTensor`. Also add all the type
386
  // objects to the set torch._tensor_classes.
387
  py_bind_tensor_types(tensor_types);
388
}
389

390
static void py_bind_tensor_types(
391
    const std::vector<PyTensorType*>& tensor_types) {
392
  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
393
  if (!torch_module)
394
    throw python_error();
395

396
  auto tensor_classes = THPObjectPtr(
397
      PyObject_GetAttrString(torch_module.get(), "_tensor_classes"));
398
  if (!tensor_classes)
399
    throw python_error();
400

401
  for (auto& tensor_type : tensor_types) {
402
    auto name = std::string(tensor_type->name);
403
    auto idx = name.rfind('.');
404
    auto type_name = name.substr(idx + 1);
405
    auto module_name = name.substr(0, idx);
406

407
    auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
408
    if (!module_obj)
409
      throw python_error();
410

411
    PyObject* type_obj = (PyObject*)tensor_type;
412
    Py_INCREF(type_obj);
413
    if (PyModule_AddObject(module_obj.get(), type_name.c_str(), type_obj) < 0) {
414
      throw python_error();
415
    }
416
    if (PySet_Add(tensor_classes.get(), type_obj) < 0) {
417
      throw python_error();
418
    }
419
  }
420
}
421

422
// Returns true iff `obj` is one of the legacy type objects registered in
// tensor_types (identity comparison, not isinstance). This is a linear scan,
// which is fine for the small, fixed number of declared types.
static bool PyTensorType_Check(PyObject* obj) {
  return std::any_of(
      tensor_types.begin(), tensor_types.end(), [obj](PyTensorType* entry) {
        return reinterpret_cast<PyObject*>(entry) == obj;
      });
}
429

430
// Backs the deprecated torch.set_default_tensor_type(). `obj` must be one of
// the legacy type objects. CUDA types are rejected when cuda_enabled() is
// false. Further validation (dense backend, floating-point dtype) happens in
// set_default_tensor_type(). Raises TypeError on invalid input.
void py_set_default_tensor_type(PyObject* obj) {
  TORCH_WARN_ONCE(
      "torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, "
      "please use torch.set_default_dtype() and torch.set_default_device() as alternatives.");
  TORCH_CHECK_TYPE(
      PyTensorType_Check(obj),
      "invalid type object: only floating-point types are supported as the default type");

  auto* const legacy_type = reinterpret_cast<PyTensorType*>(obj);
  const bool usable = !legacy_type->is_cuda || torch::utils::cuda_enabled();
  TORCH_CHECK_TYPE(
      usable,
      "type ",
      legacy_type->name,
      " not available. Torch not compiled with CUDA enabled.");

  set_default_tensor_type(
      legacy_type->get_backend(), legacy_type->get_scalar_type());
}
445

446
// Backs torch.set_default_dtype(). `obj` must be a torch.dtype. Only the
// default dtype changes; the default backend is left as is.
void py_set_default_dtype(PyObject* obj) {
  TORCH_CHECK_TYPE(
      THPDtype_Check(obj),
      "invalid dtype object: only floating-point types are supported as the default type");
  const auto* const dtype_obj = reinterpret_cast<THPDtype*>(obj);
  set_default_tensor_type(/*backend=*/c10::nullopt, dtype_obj->scalar_type);
}
453

454
c10::DispatchKey get_default_dispatch_key() {
455
  return backendToDispatchKey(default_backend);
456
}
457

458
at::Device get_default_device() {
459
  return at::Device(c10::backendToDeviceType(default_backend));
460
}
461

462
ScalarType get_default_scalar_type() {
463
  return get_default_dtype_as_scalartype();
464
}
465

466
} // namespace torch::tensors
467

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.