#include <c10/util/Optional.h>
#include <fmt/core.h>
#include <sys/types.h>
#include <torch/csrc/python_headers.h>

#ifndef _MSC_VER
#include <sys/socket.h>
#endif

#include <ATen/ATen.h>
#include <ATen/DLConvertor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LegacyVmapMode.h>
#include <ATen/LinalgBackend.h>
#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/core/Vitals.h>
#include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <libshm.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/THConcat.h>
#include <torch/csrc/utils/pybind.h>
#include <cstdlib>
#include <iostream>
#include <unordered_map>

#include <ATen/ThreadLocalPythonObjects.h>
#include <torch/csrc/DataLoader.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/TypeInfo.h>
#include <torch/csrc/api/include/torch/python/init.h>
#include <torch/csrc/autograd/generated/python_return_types.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_enum_tag.h>
#include <torch/csrc/autograd/python_fft_functions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_legacy_variable.h>
#include <torch/csrc/autograd/python_linalg_functions.h>
#include <torch/csrc/autograd/python_nested_functions.h>
#include <torch/csrc/autograd/python_nn_functions.h>
#include <torch/csrc/autograd/python_sparse_functions.h>
#include <torch/csrc/autograd/python_special_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/multiprocessing/init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/profiler/python/init.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/tensor_qschemes.h>
#include <torch/csrc/utils/verbose.h>

#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <sstream>
#ifdef USE_CUDA
#include <ATen/native/transformers/cuda/sdp_utils.h>
#endif

#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
#include <torch/csrc/distributed/autograd/python_autograd.h>
#include <torch/csrc/distributed/c10d/c10d.h>
#include <torch/csrc/distributed/rpc/rpc.h>
#include <torch/csrc/distributed/rpc/testing/testing.h>
#endif
#endif

#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif

namespace py = pybind11;

PyObject* module;

THPGenerator* THPDefaultCPUGenerator = nullptr;

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  static std::vector<std::string> names;

  THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
  if (!types)
    return nullptr;

  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto num_classes = PySequence_Fast_GET_SIZE(types.get());
  names.reserve(names.size() + num_classes);
  for (Py_ssize_t i = 0; i < num_classes; i++) {
    PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
    TORCH_CHECK(PyType_Check(obj), "expected a PyTypeObject");
    PyTypeObject* type = (PyTypeObject*)obj;

    THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
    if (!module_name)
      return nullptr;
    TORCH_CHECK(
        THPUtils_checkString(module_name.get()),
        "expected __module__ to be a string");
    std::string name = THPUtils_unpackString(module_name.get());
    names.emplace_back(name + "." + type->tp_name);
    type->tp_name = names.back().c_str();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
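
// Note: tp_name only borrows its string, so the qualified names built above
// are interned in the function-local static `names` vector, which keeps them
// alive for the lifetime of the process.
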
//
// Callback for the Python part; used for additional initialization of Python
// classes.
static PyObject* THPModule_initExtension(
    PyObject* _unused,
    PyObject* shm_manager_path) {
  HANDLE_TH_ERRORS
#if !defined(FBCODE_CAFFE2)
  if (torch::get_cpp_stacktraces_enabled() && !torch::get_disable_addr2line()) {
    c10::SetStackTraceFetcher([]() -> std::string {
      auto tb = torch::CapturedTraceback::gather(false, false, true);
      LOG(WARNING)
          << "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
          << std::endl;
      auto s_tbs = torch::symbolize({tb.get()});
      std::stringstream oss;
      oss << "C++ CapturedTraceback:" << std::endl;
      const auto& s_tb = s_tbs.tracebacks.at(0);
      for (auto idx : c10::irange(s_tb.size())) {
        // Skip the first few frames:
        //  #1 torch::CapturedTraceback::gather(bool, bool, bool)
        //  #2 THPModule_initExtension
        //  #3 THPModule_initExtension(_object*, _object*)::{lambda()#1}
        if (idx <= 3) {
          continue;
        }
        auto frame_id = s_tb[idx];
        const auto& frame = s_tbs.all_frames.at(frame_id);
        oss << "#" << idx << " " << frame.funcname << " from " << frame.filename
            << ":" << frame.lineno << std::endl;
      }
      return oss.str();
    });
  }
#endif
  if (!THPUtils_checkString(shm_manager_path)) {
    THPUtils_setError(
        "initialization error - expected bytes/string object as shm_manager_path!");
    return nullptr;
  }
  torch::utils::initializeLayouts();
  torch::utils::initializeMemoryFormats();
  torch::utils::initializeQSchemes();
  torch::utils::initializeDtypes();
  torch::tensors::initialize_python_bindings();
  std::string path = THPUtils_unpackString(shm_manager_path);
  libshm_init(path.c_str());

  auto module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!module)
    throw python_error();

  THPStorage_postInit(module);
  THPAutograd_initFunctions();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// The idea behind these two functions is to make it easy to test if we are
// built with ASAN: they're designed not to crash if ASAN is not enabled, but
// to trigger ASAN if it is enabled.  This lets us run a "canary" test which
// checks if our build environment is misconfigured.
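//
// A minimal Python-side canary sketch (illustrative, hypothetical snippet;
// the real checks live in the test suite):
//
//   import torch
//   torch._C._crash_if_csrc_asan(3)  # out-of-bounds stack write; only an
//                                    # ASAN-instrumented build reports it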

static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_csrc_asan expects an int, but got ",
      THPUtils_typename(arg));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
  volatile char x[3];
  x[THPUtils_unpackInt(arg)] = 0;
  // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
  return THPUtils_packInt32(x[0]);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_csrc_ubsan expects an int, but got ",
      THPUtils_typename(arg));
  int32_t x = THPUtils_unpackInt(arg);
  double y = 1.0 / x;
  return THPUtils_packInt32((int)y);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) {
  // This code should work perfectly fine, as vtables are identical for Foo and
  // Baz unless rtti and ubsan are enabled
  struct Foo {
    virtual int bar() = 0;
    virtual ~Foo() = default;
  };
  struct Baz {
    virtual int bar() {
      return 17;
    }
    virtual ~Baz() = default;
  };
  Baz x{};
  auto y = static_cast<Foo*>(static_cast<void*>(&x));
  auto rc = y->bar();
  return THPUtils_packInt32(rc);
}
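
// Note: under -fsanitize=vptr, the virtual call through `y` above is flagged
// because the dynamic type of the object is Baz, not Foo, even though the two
// vtable layouts happen to line up.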

static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_aten_asan expects an int, "
      "but got ",
      THPUtils_typename(arg));
  return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg)));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_abort(PyObject* module, PyObject* noargs) {
  std::terminate();
  Py_RETURN_NONE;
}

static PyObject* THPModule_crashIfDebugAssertsFail(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_debug_asserts_fail expects an int, but got ",
      THPUtils_typename(arg));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      THPUtils_unpackInt(arg) != 424242,
      "Expect anything but 424242 as an input for debug builds");
  return THPUtils_packInt32(0);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) {
  return THPUtils_packInt32(at::get_num_threads());
}

static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_num_threads expects an int, but got ",
      THPUtils_typename(arg));
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(nthreads > 0, "set_num_threads expects a positive integer");
  at::set_num_threads(nthreads);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getNumInteropThreads(
    PyObject* module,
    PyObject* noargs) {
  return THPUtils_packInt32(at::get_num_interop_threads());
}

static PyObject* THPModule_setNumInteropThreads(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_num_interop_threads expects an int, "
      "but got ",
      THPUtils_typename(arg));
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(
      nthreads > 0, "set_num_interop_threads expects a positive integer");
  at::set_num_interop_threads(nthreads);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_tensor_type(type);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_dtype(dtype);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* a_ = nullptr;
  PyObject* b_ = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &a_, &b_)) {
    return nullptr;
  }

  // Ensure we have Tensors
  TORCH_CHECK(THPVariable_Check(a_));
  TORCH_CHECK(THPVariable_Check(b_));

  THPVariable* a = reinterpret_cast<THPVariable*>(a_);
  THPVariable* b = reinterpret_cast<THPVariable*>(b_);

  TORCH_CHECK(
      a->cdata->use_count() == 1,
      "Expected single reference to a's Tensor object but got ",
      a->cdata->use_count());
  TORCH_CHECK(
      b->cdata->use_count() == 1,
      "Expected single reference to b's Tensor object but got ",
      b->cdata->use_count());
  // weak_use_count() adds 1 if use_count is non-zero
  TORCH_CHECK(
      a->cdata->weak_use_count() == 1,
      "Expected no weakrefs to a's Tensor object but got ",
      a->cdata->weak_use_count() - 1);
  TORCH_CHECK(
      b->cdata->weak_use_count() == 1,
      "Expected no weakrefs to b's Tensor object but got ",
      b->cdata->weak_use_count() - 1);

  // Swap the Tensor Impl
  c10::MaybeOwned<at::Tensor> tmp = a->cdata;

  // The TensorImpls contain PyObjectSlots that have a reference to the PyObject
  // associated with the TensorImpl. Swap this field as well.
  c10::optional<PyObject*> mb_obj_a =
      a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  c10::optional<PyObject*> mb_obj_b =
      b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_INTERNAL_ASSERT(
      mb_obj_a.has_value() && mb_obj_b.has_value(),
      "Both tensors should have PyObjects tagged by the current python interpreter");
  TORCH_CHECK(mb_obj_a.value() == a_);
  TORCH_CHECK(mb_obj_b.value() == b_);

  a->cdata = b->cdata;
  b->cdata = tmp;

  a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
      getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
      getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
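
// Illustrative Python-side usage (hypothetical snippet; the binding is
// registered below as torch._C._swap_tensor_impl). Both tensors must be the
// sole strong reference to their TensorImpl and have no weakrefs:
//
//   a, b = torch.zeros(2), torch.ones(3)
//   torch._C._swap_tensor_impl(a, b)  # a now holds b's impl and vice versa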

PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
  // adds a __doc__ string to a function, similar to numpy's arr_add_docstring
  static std::vector<std::string> all_docs;
  PyObject* obj = nullptr;
  PyObject* doc_obj = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
    return nullptr;
  }

  const char* doc_str = "<invalid string>";
  if (THPUtils_checkString(doc_obj)) {
    all_docs.push_back(THPUtils_unpackString(doc_obj));
    doc_str = all_docs.back().c_str();
  }

  if (Py_TYPE(obj) == &PyCFunction_Type) {
    PyCFunctionObject* f = (PyCFunctionObject*)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "function '%s' already has a docstring",
          f->m_ml->ml_name);
    }
    f->m_ml->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
    PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
    if (m->d_method->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "method '%s' already has a docstring",
          m->d_method->ml_name);
    }
    m->d_method->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
    PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
    if (m->d_getset->doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "attribute '%s' already has a docstring",
          m->d_getset->name);
    }
    m->d_getset->doc = doc_str;
  } else if (Py_TYPE(obj) == &PyType_Type) {
    PyTypeObject* t = (PyTypeObject*)obj;
    if (t->tp_doc) {
      return PyErr_Format(
          PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);
    }
    t->tp_doc = doc_str;
  } else {
    return PyErr_Format(
        PyExc_TypeError,
        "don't know how to add docstring to type '%s'",
        Py_TYPE(obj)->tp_name);
  }

  Py_INCREF(obj);
  return obj;
}
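
// Illustrative usage from Python (this is the mechanism torch/_torch_docs.py
// uses, via the _add_docstr binding registered below):
//
//   torch._C._add_docstr(torch.abs, "abs(input) -> Tensor\n...")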

PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0;
  TORCH_CHECK(num_args == 2, "expected exactly 2 arguments");
  PyObject* arg1 = PyTuple_GET_ITEM(args, 0);
  TORCH_CHECK(THPSize_Check(arg1), "expected a torch.Size as argument 1");
  PyObject* arg2 = PyTuple_GET_ITEM(args, 1);
  TORCH_CHECK(THPSize_Check(arg2), "expected a torch.Size as argument 2");

  auto size1 = THPUtils_unpackLongs(arg1);
  auto size2 = THPUtils_unpackLongs(arg2);
  auto sizes = at::infer_size(size1, size2);
  return THPSize_NewFromSizes(static_cast<int64_t>(sizes.size()), sizes.data());
  END_HANDLE_TH_ERRORS
}
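
// at::infer_size applies the standard broadcasting rules: sizes are aligned
// from the trailing dimension, and each pair must either match or contain a 1.
// For example, _infer_size(torch.Size([2, 1, 4]), torch.Size([3, 1])) yields
// torch.Size([2, 3, 4]).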

static PyObject* THPModule_setBackcompatBroadcastWarn(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_backcompat_broadcast_warn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  setBackCompatBroadcastWarn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getBackcompatBroadcastWarn(
    PyObject* module,
    PyObject* noargs) {
  if (getBackCompatBroadcastWarn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

static PyObject* THPModule_setBackcompatKeepdimWarn(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_backcompat_keepdim_warn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  setBackCompatKeepdimWarn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getBackcompatKeepdimWarn(
    PyObject* module,
    PyObject* noargs) {
  if (getBackCompatKeepdimWarn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) {
#ifdef USE_DISTRIBUTED
  Py_RETURN_TRUE;
#else
  Py_RETURN_FALSE;
#endif
}

static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::show_config());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_cxx_flags());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_parallel_info());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getCpuCapability(
    PyObject* module,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_cpu_capability());
  END_HANDLE_TH_ERRORS
}

void DLPack_Capsule_Destructor(PyObject* data) {
  if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) {
    // early out, see DLPack spec: if a consuming library sets the capsule
    // name to something else, they own it and we don't need to do anything
    return;
  }
  HANDLE_TH_ERRORS
  // This incurs the overhead of the validity checks again, but this case is
  // rare since consuming libraries should rename the capsule according to the
  // spec. Note that this cannot set a Python error (we checked validity
  // above), so we don't need to handle Python error state here.
  DLManagedTensor* dlMTensor =
      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
  // The dlMTensor has not been consumed; call the deleter ourselves.
  // The DLPack spec mentions that the deleter may be NULL, but the deleter
  // from `at::toDLPack` is never NULL, so there is no need for an additional
  // check here.
  dlMTensor->deleter(dlMTensor);
  END_HANDLE_TH_ERRORS_RET()
}
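
// Sketch of the DLPack capsule protocol this destructor participates in (the
// consumer side below is illustrative, not part of this file): the producer
// wraps a DLManagedTensor in a capsule named "dltensor"; a consumer that takes
// ownership renames the capsule to "used_dltensor" so this destructor becomes
// a no-op:
//
//   DLManagedTensor* t =
//       (DLManagedTensor*)PyCapsule_GetPointer(capsule, "dltensor");
//   // ... take ownership of t ...
//   PyCapsule_SetName(capsule, "used_dltensor");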

PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
  DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
  return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
  using namespace torch::autograd;
  HANDLE_TH_ERRORS
  auto tensor = torch::utils::tensor_fromDLPack(data);
  return THPVariable_Wrap(tensor);
  END_HANDLE_TH_ERRORS
}

PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  size_t frames_to_skip = 0;
  size_t maximum_number_of_frames = 0;
  if (!PyArg_ParseTuple(
          args, "LL", &frames_to_skip, &maximum_number_of_frames)) {
    return nullptr;
  }
  return THPUtils_packString(
      c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_rename_privateuse1_backend(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkString(arg),
      "_rename_privateuse1_backend expects a str, but got ",
      THPUtils_typename(arg));
  const std::string backend_name = THPUtils_unpackString(arg);
  c10::register_privateuse1_backend(backend_name);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_get_privateuse1_backend_name(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::get_privateuse1_backend());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_tf32_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowTF32CuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().allowTF32CuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setFloat32MatmulPrecision(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkString(arg),
      "set_float32_matmul_precision expects a str, "
      "but got ",
      THPUtils_typename(arg));
  std::string s = THPUtils_unpackString(arg);
  at::globalContext().setFloat32MatmulPrecision(s);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_float32MatmulPrecision(
    PyObject* _unused,
    PyObject* noargs) {
  std::string s = "highest";
  auto p = at::globalContext().float32MatmulPrecision();
  if (p == at::Float32MatmulPrecision::HIGH) {
    s = "high";
  } else if (p == at::Float32MatmulPrecision::MEDIUM) {
    s = "medium";
  }
  return THPUtils_packString(s);
}
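
// The Python-facing wrappers are torch.get_float32_matmul_precision() and
// torch.set_float32_matmul_precision(), which accept and return the strings
// "highest", "high", and "medium" mapped from the enum above.
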
PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_flash expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseFlash(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledFlashSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_mem_efficient expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseMemEfficient(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMemEfficientSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_math expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseMath(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMathSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledCuDNNSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledCuDNNSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledCuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_mkldnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledMkldnn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMkldnn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_deterministic_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setDeterministicCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().deterministicCuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicAlgorithms(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"});
  torch::ParsedArgs<2> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  bool mode = r.toBool(0);
  bool warn_only = r.toBool(1);
  at::globalContext().setDeterministicAlgorithms(mode, warn_only);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicAlgorithms(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicAlgorithms()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_deterministicAlgorithmsWarnOnly(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicAlgorithmsWarnOnly()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicFillUninitializedMemory(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg), "expected a bool, but got ", THPUtils_typename(arg));
  at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicFillUninitializedMemory(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicFillUninitializedMemory())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setUserEnabledNNPACK(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_NNPACK expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledNNPACK(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledNNPACK(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledNNPACK())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "_set_warnAlways expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  c10::WarningUtils::set_warnAlways(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) {
  if (c10::WarningUtils::get_warnAlways()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

// Used only for testing C++ to Python warning translations.
PyObject* THPModule_warn(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN("Test message for TORCH_WARN");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Used only for testing C++ to Python warning translations.
PyObject* THPModule_warnDeprecation(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION("Test message for TORCH_WARN_DEPRECATION");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_benchmark_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setBenchmarkCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().benchmarkCuDNN()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_tf32_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowTF32CuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().allowTF32CuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowFP16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_fp16_reduction_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowFP16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowFP16ReductionCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowBF16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_bf16_reduction_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowBF16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowBF16ReductionCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "flush_denormal expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  if (!at::globalContext().setFlushDenormal(arg == Py_True)) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  auto scalar_type = torch::tensors::get_default_scalar_type();
  auto dtype = (PyObject*)torch::getTHPDtype(scalar_type);
  Py_INCREF(dtype);
  return dtype;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::DeviceTypeName(
      dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()),
      /*lower_case=*/true));
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_qengine expects an int, "
      "but got ",
      THPUtils_typename(arg));
  auto qengine = THPUtils_unpackLong(arg);
  at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) {
  return THPUtils_packInt64(
      static_cast<int64_t>(at::globalContext().qEngine()));
}

PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) {
  auto qengines = at::globalContext().supportedQEngines();
  auto list =
      THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(qengines.size())));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(qengines.size())) {
    PyObject* i64 = THPUtils_packInt64(static_cast<int64_t>(qengines[i]));
    if (!i64)
      return nullptr;
    PyList_SET_ITEM(list.get(), i, i64);
  }
  return list.release();
}

PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().isXNNPACKAvailable())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setCheckSparseTensorInvariants(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_check_sparse_tensor_invariants expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_checkSparseTensorInvariants(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().checkSparseTensorInvariants())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  bool isTHPFunction = THPFunction_Check(arg);
  bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
  TORCH_CHECK(
      isTHPFunction || isTHPCppFunction,
      "_will_engine_execute_node expects a grad_fn, "
      "but got ",
      THPUtils_typename(arg));
  const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
  TORCH_CHECK(
      exec_info,
      "_will_engine_execute_node should only be called during the backward pass");
  torch::autograd::Node* node = nullptr;
  std::shared_ptr<torch::autograd::Node> node_sp;
  if (isTHPFunction) {
    node_sp = ((THPFunction*)arg)->cdata.lock();
    node = node_sp.get();
  } else {
    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
  }
  const auto nodes_in_graph =
      torch::autograd::get_current_graph_task_nodes_in_graph();
  bool ret = nodes_in_graph->find(node) != nodes_in_graph->end();
  if (ret && !exec_info->empty()) {
    auto it = exec_info->find(node);
    if (it == exec_info->end() || !it->second.should_execute()) {
      ret = false;
    } else {
      TORCH_CHECK(
          !(node->topological_nr() == 0 && it->second.captures_),
          "A leaf node was passed to _will_engine_execute_node but we are "
          "currently running autograd.grad(). This is currently not supported.");
    }
  }
  if (ret) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
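
// Illustrative Python-side usage (hypothetical snippet; the binding is
// registered below as torch._C._will_engine_execute_node and must be called
// while a backward pass is running, e.g. from a hook):
//
//   def hook(grad):
//       print(torch._C._will_engine_execute_node(t.grad_fn))
//   t.register_hook(hook)
//   loss.backward()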

PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::vector<torch::autograd::Node*> nodes =
      torch::autograd::get_current_graph_task_execution_order();
  TORCH_CHECK(
      !nodes.empty(),
      "_current_graph_task_execution_order should only be called during the backward pass");
  auto list = THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(nodes.size())));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(nodes.size())) {
    // This node is guaranteed to be alive since the backward is still running
    PyObject* pyobj_node =
        torch::autograd::functionToPyObject(nodes[i]->getptr());
    PyList_SET_ITEM(list.get(), i, pyobj_node);
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentNode(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return torch::autograd::functionToPyObject(
      torch::autograd::get_current_node());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().setDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_unsetDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().unsetDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_vmapmode_increment_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_vmapmode_decrement_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_set_display_vmap_fallback_warnings_mode(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "enabled must be a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (at::globalContext().areVmapFallbackWarningsEnabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyMethodDef TorchMethods[] = { // NOLINT
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
    {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
    {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr},
    {"_init_names", THPModule_initNames, METH_O, nullptr},
    {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
    {"_set_default_tensor_type",
     THPModule_setDefaultTensorType,
     METH_O,
     nullptr},
    {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr},
    {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr},
    {"_abort", THPModule_abort, METH_NOARGS, nullptr},
    {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr},
    {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr},
    {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr},
    {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr},
    {"_crash_if_debug_asserts_fail",
     THPModule_crashIfDebugAssertsFail,
     METH_O,
     nullptr},
    {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr},
    {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
    {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
    {"_get_cpu_capability", THPModule_getCpuCapability, METH_NOARGS, nullptr},
    {"_set_backcompat_broadcast_warn",
     THPModule_setBackcompatBroadcastWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_broadcast_warn",
     THPModule_getBackcompatBroadcastWarn,
     METH_NOARGS,
     nullptr},
    {"_set_backcompat_keepdim_warn",
     THPModule_setBackcompatKeepdimWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_keepdim_warn",
     THPModule_getBackcompatKeepdimWarn,
     METH_NOARGS,
     nullptr},
    {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr},
    {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr},
    {"get_num_interop_threads",
     THPModule_getNumInteropThreads,
     METH_NOARGS,
     nullptr},
    {"set_num_interop_threads",
     THPModule_setNumInteropThreads,
     METH_O,
     nullptr},
    {"_get_flash_sdp_enabled",
     THPModule_userEnabledFlashSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
    {"_get_mem_efficient_sdp_enabled",
     userEnabledMemEfficientSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_mem_efficient",
     THPModule_setSDPUseMemEfficient,
     METH_O,
     nullptr},
    {"_get_math_sdp_enabled",
     THPModule_userEnabledMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    {"_get_cudnn_sdp_enabled",
     THPModule_userEnabledCuDNNSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_cudnn", THPModule_setSDPUseCuDNN, METH_O, nullptr},
    {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr},
    {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
    {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
    {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
    {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
    {"_get_cudnn_deterministic",
     THPModule_deterministicCuDNN,
     METH_NOARGS,
     nullptr},
    {"_set_cudnn_deterministic",
     THPModule_setDeterministicCuDNN,
     METH_O,
     nullptr},
    {"_get_deterministic_algorithms",
     THPModule_deterministicAlgorithms,
     METH_NOARGS,
     nullptr},
    {"_get_deterministic_algorithms_warn_only",
     THPModule_deterministicAlgorithmsWarnOnly,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_algorithms",
     castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_deterministic_fill_uninitialized_memory",
     THPModule_deterministicFillUninitializedMemory,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_fill_uninitialized_memory",
     THPModule_setDeterministicFillUninitializedMemory,
     METH_O,
     nullptr},
    {"_get_nnpack_enabled", THPModule_userEnabledNNPACK, METH_NOARGS, nullptr},
    {"_set_nnpack_enabled", THPModule_setUserEnabledNNPACK, METH_O, nullptr},
    {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
    {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
    {"_warn", THPModule_warn, METH_NOARGS, nullptr},
    {"_warn_deprecation", THPModule_warnDeprecation, METH_NOARGS, nullptr},
    {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
    {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
    {"_get_float32_matmul_precision",
     THPModule_float32MatmulPrecision,
     METH_NOARGS,
     nullptr},
    {"_set_float32_matmul_precision",
     THPModule_setFloat32MatmulPrecision,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_allowBF16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_vmapmode_increment_nesting",
     THPModule_vmapmode_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"_vmapmode_decrement_nesting",
     THPModule_vmapmode_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"_debug_only_display_vmap_fallback_warnings",
     THPModule_set_display_vmap_fallback_warnings_mode,
     METH_O,
     nullptr},
    {"_debug_only_are_vmap_fallback_warnings_enabled",
     THPModule_are_vmap_fallback_warnings_enabled,
     METH_NOARGS,
     nullptr},
    {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
    {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
    {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
    {"_rename_privateuse1_backend",
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"_get_privateuse1_backend_name",
     THModule_get_privateuse1_backend_name,
     METH_NOARGS,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
    {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr},
    {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
    {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
    {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
    {"_set_check_sparse_tensor_invariants",
     THPModule_setCheckSparseTensorInvariants,
     METH_O,
     nullptr},
    {"_check_sparse_tensor_invariants",
     THPModule_checkSparseTensorInvariants,
     METH_NOARGS,
     nullptr},
    {"_will_engine_execute_node",
     THPModule_willEngineExecuteNode,
     METH_O,
     nullptr},
    {"_current_graph_task_execution_order",
     THPModule_getCurrentGraphTaskExecutionOrder,
     METH_NOARGS,
     nullptr},
    {"_current_graph_task_id",
     THPModule_getCurrentGraphTaskId,
     METH_NOARGS,
     nullptr},
    {"_current_autograd_node", THPModule_getCurrentNode, METH_NOARGS, nullptr},
    {"_set_default_mobile_cpu_allocator",
     THPModule_setDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_unset_default_mobile_cpu_allocator",
     THPModule_unsetDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_is_torch_function_enabled",
     THPModule_isEnabledTorchFunction,
     METH_NOARGS,
     nullptr},
    {"_disabled_torch_function_impl",
     THPModule_disable_torch_function,
     METH_VARARGS,
     nullptr},
    {"_disabled_torch_dispatch_impl",
     THPModule_disable_torch_dispatch,
     METH_VARARGS,
     nullptr},
    {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
    {"_has_torch_function_unary",
     THPModule_has_torch_function_unary,
     METH_O,
     nullptr},
    {"_has_torch_function_variadic",
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    {nullptr, nullptr, 0, nullptr}};
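
// Naming convention in the table above: entries prefixed with an underscore
// are private torch._C helpers, while the unprefixed ones (get_num_threads,
// set_flush_denormal, get_default_dtype, ...) are re-exported as public
// torch.* functions.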

void THCPStream_init(PyObject* module);
void THCPEvent_init(PyObject* module);
void THCPGraph_init(PyObject* module);

#ifdef USE_CUDA
PyMethodDef* THCPModule_methods();
namespace torch::cuda {
void initModule(PyObject* module);
} // namespace torch::cuda
#endif

#ifdef USE_XPU
PyMethodDef* THXPModule_methods();
void THXPStream_init(PyObject* module);
void THXPEvent_init(PyObject* module);
namespace torch::xpu {
void initModule(PyObject* module);
} // namespace torch::xpu
#endif

#ifdef USE_ITT
namespace torch::profiler {
void initIttBindings(PyObject* module);
} // namespace torch::profiler
#endif

static std::vector<PyMethodDef> methods;

// In Python we can't use the C10_LOG_API_USAGE_ONCE trick.
// Guaranteed to be invoked from Python under the GIL, so no locking on the
// set is needed.
static void LogAPIUsageOnceFromPython(const std::string& event) {
  static std::unordered_set<std::string> seen;
  if (!seen.count(event)) {
    seen.insert(event);
    c10::LogAPIUsage(event);
  }
}
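
// Illustrative Python-side call (the binding is registered in initModule()
// below as _log_api_usage_once):
//
//   torch._C._log_api_usage_once("my_project.feature")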

static void LogAPIUsageMetadataFromPython(
    const std::string& event,
    const std::map<std::string, std::string>& metadata_map) {
  c10::LogAPIUsageMetadata(event, metadata_map);
}

// Weak reference to tensor, used to test a tensor isn't leaked
class WeakTensorRef {
  c10::weak_intrusive_ptr<c10::TensorImpl> weakref_;

 public:
  WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {}

  bool expired() {
    return weakref_.expired();
  }
};
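
// Sketch of the intended test-side usage (assuming a pybind11 binding for
// WeakTensorRef is registered elsewhere in this file): construct a
// WeakTensorRef from a tensor, drop the last Python reference, then assert
// that expired() returns true.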
1511

1512
extern "C" C10_EXPORT PyObject* initModule();
1513
// separate decl and defn for msvc error C2491
1514
PyObject* initModule() {
1515
  HANDLE_TH_ERRORS
1516

1517
  c10::initLogging();
1518
  c10::set_terminate_handler();
1519
  at::internal::lazy_init_num_threads();
1520

1521
  C10_LOG_API_USAGE_ONCE("torch.python.import");
1522

1523
#define ASSERT_TRUE(cmd) \
1524
  if (!(cmd))            \
1525
  return nullptr
1526

1527
  THPUtils_addPyMethodDefs(methods, TorchMethods);
1528
  THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
1529
  THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
1530
  THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
1531
  THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
1532
#ifdef USE_CUDA
1533
  THPUtils_addPyMethodDefs(methods, THCPModule_methods());
1534
#endif
1535
#ifdef USE_XPU
1536
  THPUtils_addPyMethodDefs(methods, THXPModule_methods());
1537
#endif
1538
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
1539
  THPUtils_addPyMethodDefs(
1540
      methods, torch::distributed::c10d::python_functions());
1541
#ifndef _WIN32
1542
  THPUtils_addPyMethodDefs(
1543
      methods, torch::distributed::rpc::python_functions());
1544
  THPUtils_addPyMethodDefs(
1545
      methods, torch::distributed::autograd::python_functions());
1546
  THPUtils_addPyMethodDefs(
1547
      methods, torch::distributed::rpc::testing::python_functions());
1548
#endif
1549
#endif

  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()};
  module = PyModule_Create(&torchmodule);
  ASSERT_TRUE(module);
  ASSERT_TRUE(THPGenerator_init(module));
  ASSERT_TRUE(THPException_init(module));
  THPSize_init(module);
  THPDtype_init(module);
  THPDTypeInfo_init(module);
  THPLayout_init(module);
  THPMemoryFormat_init(module);
  THPQScheme_init(module);
  THPDevice_init(module);
  THPStream_init(module);
  ASSERT_TRUE(THPVariable_initModule(module));
  ASSERT_TRUE(THPFunction_initModule(module));
  ASSERT_TRUE(THPEngine_initModule(module));
  // NOTE: We need to be able to access OperatorExportTypes from ONNX for use in
  // the export side of JIT, so this ONNX init needs to appear before the JIT
  // init.
  torch::onnx::initONNXBindings(module);
  torch::autograd::initEnumTag(module);
  torch::jit::initJITBindings(module);
  torch::monitor::initMonitorBindings(module);
  torch::impl::dispatch::initDispatchBindings(module);
  torch::dynamo::initDynamoBindings(module);
  torch::functorch::impl::initFuncTorchBindings(module);
  torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
  torch::autograd::initReturnTypes(module);
  torch::autograd::initNNFunctions(module);
  torch::autograd::initFFTFunctions(module);
  torch::autograd::initLinalgFunctions(module);
  torch::autograd::initNestedFunctions(module);
  torch::autograd::initSparseFunctions(module);
  torch::autograd::initSpecialFunctions(module);
  torch::autograd::init_legacy_variable(module);
  torch::profiler::initPythonBindings(module);
  torch::python::init_bindings(module);
  torch::lazy::initLazyBindings(module);
  torch::inductor::initAOTIRunnerBindings(module);
#ifdef USE_ITT
  torch::profiler::initIttBindings(module);
#endif
#ifdef USE_CUDA
  torch::cuda::initModule(module);
#endif
#ifdef USE_XPU
  torch::xpu::initModule(module);
#endif
  torch::cpu::initModule(module);
  torch::initVerboseBindings(module);
  ASSERT_TRUE(THPStorage_init(module));

#ifdef USE_CUDA
  // This will only initialize the base classes and attach them to the library
  // namespace. They won't be ready for real usage until the cuda module is
  // imported, which completes the process (but it defines Python classes
  // before calling back into C, so these lines have to execute first).
  THCPStream_init(module);
  THCPEvent_init(module);
  THCPGraph_init(module);
#endif

#ifdef USE_XPU
  THXPStream_init(module);
  THXPEvent_init(module);
#endif

  auto set_module_attr =
      [&](const char* name, PyObject* v, bool incref = true) {
        // PyModule_AddObject steals a reference, so incref first for objects
        // we don't own outright (e.g. the Py_True/Py_False singletons below).
        if (incref) {
          Py_INCREF(v);
        }

        int ret = PyModule_AddObject(module, name, v);
        if (ret != 0) {
          Py_DECREF(v);
        }

        return ret == 0;
      };

#if defined(USE_CUDNN) || defined(USE_ROCM)
  PyObject* has_cudnn = Py_True;
#else
  PyObject* has_cudnn = Py_False;
#endif
  ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
  PyObject* has_spectral = Py_True;
#else
  PyObject* has_spectral = Py_False;
#endif
  ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));

  // force ATen to initialize because it handles
  // setting up TH Errors so that they throw C++ exceptions
  at::init();

  // Automatically translate errors thrown from pybind11 functions
  py::register_exception_translator([](std::exception_ptr e) { // NOLINT
    try {
      if (e) {
        std::rethrow_exception(e);
      }
    }
    CATCH_TH_ERRORS()
  });

  auto py_module = py::reinterpret_borrow<py::module>(module);
  py_module.def("_demangle", &c10::demangle);
  py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);
  py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython);

  py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled);
  py_module.def(
      "set_vital",
      [](const std::string& vital,
         const std::string& attr,
         const std::string& value) {
        return at::vitals::VitalsAPI.setVital(vital, attr, value);
      });
  py_module.def(
      "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); });

  py_module.def(
      "init_num_threads",
      torch::wrap_pybind_function(at::init_num_threads),
      R"(
init_num_threads()

Initializes the number of parallel threads used on the current thread.

Call this whenever a new thread is created in order to propagate values from
:func:`torch.set_num_threads` onto the new thread.
)");

  ASSERT_TRUE(
      set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
  ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
  ASSERT_TRUE(
      set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));

  py_module.def("_valgrind_supported_platform", []() {
#if defined(USE_VALGRIND)
    return true;
#else
    return false;
#endif
  });

  py_module.def("_valgrind_toggle", []() {
#if defined(USE_VALGRIND)
    CALLGRIND_TOGGLE_COLLECT;
#else
    TORCH_CHECK(false, "Valgrind is not supported.");
#endif
  });

  py_module.def("_valgrind_toggle_and_dump_stats", []() {
#if defined(USE_VALGRIND)
    // NB: If we don't toggle collect around dump stats, callgrind_annotate
    //     won't process the results correctly. Specifically,
    //     `callgrind_annotate --inclusive=no` will be almost completely empty.
    CALLGRIND_TOGGLE_COLLECT;
    CALLGRIND_DUMP_STATS;
#else
    TORCH_CHECK(false, "Valgrind is not supported.");
#endif
  });
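  // Usage sketch (illustrative): under
  // `valgrind --tool=callgrind --collect-atstart=no`, bracket the region of
  // interest so only it is collected:
  //   torch._C._valgrind_toggle()
  //   run_workload()  # hypothetical function being profiled
  //   torch._C._valgrind_toggle_and_dump_stats()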

  py::class_<WeakTensorRef>(py_module, "_WeakTensorRef")
      .def(py::init([](py::object tensor) {
        return WeakTensorRef(THPVariable_Unpack(tensor.ptr()));
      }))
      .def("expired", &WeakTensorRef::expired);

  py::enum_<at::native::ConvBackend>(py_module, "_ConvBackend")
      .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d)
      .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d)
      .value("Cudnn", at::native::ConvBackend::Cudnn)
      .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose)
      .value("Empty", at::native::ConvBackend::Empty)
      .value("Miopen", at::native::ConvBackend::Miopen)
      .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise)
      .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose)
      .value("Mkldnn", at::native::ConvBackend::Mkldnn)
      .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty)
      .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial)
      .value("Overrideable", at::native::ConvBackend::Overrideable)
      .value("Slow2d", at::native::ConvBackend::Slow2d)
      .value("Slow3d", at::native::ConvBackend::Slow3d)
      .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d)
      .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d)
      .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d)
      .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d)
      .value(
          "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise)
      .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d)
      .value("Mps", at::native::ConvBackend::Mps)
      .value("MpsTranspose", at::native::ConvBackend::MpsTranspose);

  py_module.def(
      "_select_conv_backend",
      [](const at::Tensor& input,
         const at::Tensor& weight,
         const c10::optional<at::Tensor>& bias_opt,
         at::SymIntArrayRef stride_,
         at::SymIntArrayRef padding_,
         at::SymIntArrayRef dilation_,
         bool transposed_,
         at::SymIntArrayRef output_padding_,
         c10::SymInt groups_) {
        return at::native::select_conv_backend(
            input,
            weight,
            bias_opt,
            stride_,
            padding_,
            dilation_,
            transposed_,
            output_padding_,
            std::move(groups_),
            c10::nullopt);
      },
      py::arg("input"),
      py::arg("weight"),
      py::arg("bias"),
      py::arg("stride"),
      py::arg("padding"),
      py::arg("dilation"),
      py::arg("transposed"),
      py::arg("output_padding"),
      py::arg("groups"));

  // Overload for bias_sizes_opt/backward. TODO: figure out a default value.
  py_module.def(
      "_select_conv_backend",
      [](const at::Tensor& input,
         const at::Tensor& weight,
         const c10::optional<at::Tensor>& bias,
         at::SymIntArrayRef stride_,
         at::SymIntArrayRef padding_,
         at::SymIntArrayRef dilation_,
         bool transposed_,
         at::SymIntArrayRef output_padding_,
         c10::SymInt groups_,
         c10::optional<std::vector<c10::SymInt>> bias_sizes_opt) {
        c10::OptionalArrayRef<c10::SymInt> ref = c10::nullopt;
        if (bias_sizes_opt) {
          ref = (*bias_sizes_opt);
        }
        return at::native::select_conv_backend(
            input,
            weight,
            bias,
            stride_,
            padding_,
            dilation_,
            transposed_,
            output_padding_,
            std::move(groups_),
            ref);
      },
      py::arg("input"),
      py::arg("weight"),
      py::arg("bias"),
      py::arg("stride"),
      py::arg("padding"),
      py::arg("dilation"),
      py::arg("transposed"),
      py::arg("output_padding"),
      py::arg("groups"),
      py::arg("bias_sizes"));

  py_module.def(
      "_conv_determine_backend_memory_format",
      at::native::_determine_backend_memory_format);
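  // Usage sketch (illustrative): query which convolution backend would be
  // chosen for given inputs, without running the convolution:
  //   x = torch.randn(1, 3, 8, 8)
  //   w = torch.randn(4, 3, 3, 3)
  //   backend = torch._C._select_conv_backend(
  //       x, w, None, stride=[1, 1], padding=[0, 0], dilation=[1, 1],
  //       transposed=False, output_padding=[0, 0], groups=1)
  //   assert isinstance(backend, torch._C._ConvBackend)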

  ////////////////////////////////////////////////////////////////////////////////
  // Scaled Dot Product Attention utilities
  ////////////////////////////////////////////////////////////////////////////////
  py::class_<sdp::sdp_params>(py_module, "_SDPAParams")
      .def(py::init([](at::Tensor const& query,
                       at::Tensor const& key,
                       at::Tensor const& value,
                       c10::optional<at::Tensor> attn_mask,
                       double dropout,
                       bool is_causal) {
        return sdp::sdp_params{
            query, key, value, std::move(attn_mask), dropout, is_causal};
      }))
      .def_readonly("query", &sdp::sdp_params::query)
      .def_readonly("key", &sdp::sdp_params::key)
      .def_readonly("value", &sdp::sdp_params::value)
      .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
      .def_readonly("dropout", &sdp::sdp_params::dropout)
      .def_readonly("is_causal", &sdp::sdp_params::is_causal);

  py::enum_<sdp::SDPBackend>(
      py_module,
      "_SDPBackend",
      "An enum-like class that contains the different backends for scaled dot product attention.\n\n.. warning:: This class is in beta and subject to change.\n\n"
      "This backend class is designed to be used with the sdpa_kernel context manager. "
      "See :func:`torch.nn.attention.sdpa_kernel` for more details.")
      .value("ERROR", sdp::SDPBackend::error)
      .value("MATH", sdp::SDPBackend::math)
      .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)
      .value("EFFICIENT_ATTENTION", sdp::SDPBackend::efficient_attention)
      .value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention);

  py_module.def(
      "_can_use_flash_attention",
      [](const sdp::sdp_params& params, bool debug) {
#ifdef USE_CUDA
        return sdp::can_use_flash_attention(params, debug);
#else
        return false;
#endif
      });
  py_module.def(
      "_can_use_mem_efficient_attention",
      [](const sdp::sdp_params& params, bool debug) {
#ifdef USE_CUDA
        return sdp::can_use_mem_efficient_attention(params, debug);
#else
        return false;
#endif
      });
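  // Usage sketch (illustrative, assuming a CUDA build): check flash-attention
  // eligibility for a set of inputs; with debug=True the rejection reason is
  // logged.
  //   q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
  //   params = torch._C._SDPAParams(q, k, v, None, 0.0, False)
  //   ok = torch._C._can_use_flash_attention(params, True)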

  py::enum_<at::LinalgBackend>(py_module, "_LinalgBackend")
      .value("Default", at::LinalgBackend::Default)
      .value("Cusolver", at::LinalgBackend::Cusolver)
      .value("Magma", at::LinalgBackend::Magma);

  py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) {
    at::globalContext().setLinalgPreferredBackend(b);
  });
  py_module.def("_get_linalg_preferred_backend", []() {
    return at::globalContext().linalgPreferredBackend();
  });
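  // Usage sketch (illustrative; the public entry point for this is
  // torch.backends.cuda.preferred_linalg_library):
  //   torch._C._set_linalg_preferred_backend(torch._C._LinalgBackend.Cusolver)
  //   assert (torch._C._get_linalg_preferred_backend()
  //           == torch._C._LinalgBackend.Cusolver)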

  py_module.def(
      "_construct_storage_from_data_pointer",
      [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
        return c10::Storage(
            c10::Storage::use_byte_size_t(),
            size_bytes,
            // NOLINTNEXTLINE(performance-no-int-to-ptr)
            at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
      });
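  // Usage sketch (illustrative): wrap an existing allocation in a Storage.
  // The DataPtr constructed above is non-owning, so the caller must keep the
  // underlying memory alive for the lifetime of the Storage:
  //   s = torch._C._construct_storage_from_data_pointer(
  //       t.data_ptr(), t.device, t.nbytes)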

  py_module.def(
      "_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
        at::impl::ThreadLocalPythonObjects::get_state().set(
            key,
            std::make_shared<c10::SafePyObject>(arg.ptr(), getPyInterpreter()));
      });

  py_module.def("_get_obj_in_tls", [](const std::string& key) -> py::handle {
    auto safe_pyobject =
        at::impl::ThreadLocalPythonObjects::get_state().get(key);
    auto obj = safe_pyobject->ptr(getPyInterpreter());
    return py::handle(obj);
  });

  py_module.def("_is_key_in_tls", [](const std::string& key) -> bool {
    return at::impl::ThreadLocalPythonObjects::get_state().contains(key);
  });
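  // Usage sketch (illustrative; the key name is made up). Objects are stashed
  // in thread-local state, so a value set on one thread is not visible on
  // another:
  //   torch._C._stash_obj_in_tls("example_key", {"any": "object"})
  //   if torch._C._is_key_in_tls("example_key"):
  //       obj = torch._C._get_obj_in_tls("example_key")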

#ifdef USE_CUDA
  PyObject* has_cuda = Py_True;
#else
  PyObject* has_cuda = Py_False;
#endif

#ifdef USE_MPS
  PyObject* has_mps = Py_True;
#else
  PyObject* has_mps = Py_False;
#endif

#ifdef USE_XPU
  PyObject* has_xpu = Py_True;
#else
  PyObject* has_xpu = Py_False;
#endif

  ASSERT_TRUE(set_module_attr("_has_cuda", has_cuda));
  ASSERT_TRUE(
      set_module_attr("_has_magma", at::hasMAGMA() ? Py_True : Py_False));
  ASSERT_TRUE(set_module_attr("_has_mps", has_mps));
  ASSERT_TRUE(set_module_attr("_has_xpu", has_xpu));
  ASSERT_TRUE(
      set_module_attr("_has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));

#ifdef _GLIBCXX_USE_CXX11_ABI
  ASSERT_TRUE(set_module_attr(
      "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
#else
  ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
#endif

// See note [Pybind11 ABI constants]
#define SET_STR_DEFINE(name) \
  ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name)))

#ifdef PYBIND11_COMPILER_TYPE
  SET_STR_DEFINE(PYBIND11_COMPILER_TYPE);
#else
  ASSERT_TRUE(
      set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None));
#endif

#ifdef PYBIND11_STDLIB
  SET_STR_DEFINE(PYBIND11_STDLIB);
#else
  ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None));
#endif

#ifdef PYBIND11_BUILD_ABI
  SET_STR_DEFINE(PYBIND11_BUILD_ABI);
#else
  ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None));
#endif
#undef SET_STR_DEFINE

  py_module.def(
      "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); });
  py_module.def(
      "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); });
  py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata);
  py_module.def(
      "_set_tensor_metadata",
      static_cast<void (*)(
          const at::Tensor&, std::unordered_map<std::string, bool>)>(
          torch::jit::setTensorMetadata));
  py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
    return toString(x.key_set());
  });
  py_module.def(
      "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });

  py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
    if (meta_in_tls) {
      local_keyset.included_ = local_keyset.included_ | key_set;
    } else {
      local_keyset.included_ =
          local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
    }
    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
  });

  py_module.def("_meta_in_tls_dispatch_include", []() {
    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
    return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
  });
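  // Usage sketch (illustrative): force the Meta dispatch key into the
  // thread-local include set, verify it, then restore:
  //   torch._C._set_meta_in_tls_dispatch_include(True)
  //   assert torch._C._meta_in_tls_dispatch_include()
  //   torch._C._set_meta_in_tls_dispatch_include(False)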

  py_module.def("_dump_local_tls_set", []() {
    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
    std::cout << "Included: " << toString(local_keyset.included_) << "\n";
    std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
  });

  py_module.def(
      "_should_allow_numbers_as_tensors", [](const std::string& name) {
        return torch::should_allow_numbers_as_tensors(name);
      });

  // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
  // Currently I see the second item of the key is displayed as
  // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
  // I thought adding an appropriate type_caster of `at::ScalarType` to
  // `torch/csrc/pybind.h` would solve this, but it caused a segmentation fault
  // in my environment.
  using _DeviceDtypeKey = std::pair<at::Device, std::string>;
  // A custom hasher is necessary to make the unordered_map compilable for
  // Windows debug targets: `at::native::ParamsHash` only works on structs with
  // standard layout, and std::string isn't one in Visual C++ debug builds,
  // which one can easily verify by running something like:
  //   #define _DEBUG
  //   #include <type_traits>
  //   #include <string>
  //   static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
  // If the above condition is not met, VC++ raises a very cryptic compilation
  // error. See
  // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
  // more detail.
  struct _DeviceDtypeHasher {
    std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
      static at::native::ParamsHash<at::Device> device_hasher;
      static std::hash<std::string> string_hasher;
      return device_hasher(k.first) ^ string_hasher(k.second);
    }
  };
  using _FlatMap = std::unordered_map<
      _DeviceDtypeKey,
      at::native::TensorsAndIndicesT,
      _DeviceDtypeHasher>;
  py_module.def(
      "_group_tensors_by_device_and_dtype",
      [](const std::vector<std::vector<c10::optional<at::Tensor>>>&
             nested_tensorlist,
         const bool with_indices) {
        _FlatMap map;
        for (const auto& iter :
             at::native::_group_tensors_by_first_tensors_device_and_dtype(
                 nested_tensorlist, with_indices)) {
          const auto scalar_type_name =
              torch::utils::getDtypeNames(iter.first.second).first;
          map.insert({{iter.first.first, scalar_type_name}, iter.second});
        }
        return map;
      });
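  // Usage sketch (illustrative; `params`/`grads` are hypothetical lists of
  // tensors). Keys are (device, dtype-name) pairs, since the scalar type is
  // mapped to its string name above:
  //   grouped = torch._C._group_tensors_by_device_and_dtype(
  //       [params, grads], True)
  //   for (device, dtype_name), group in grouped.items():
  //       ...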

  py_module.def(
      "_storage_address",
      [](const at::Tensor& tensor) {
        return reinterpret_cast<std::intptr_t>(
            tensor.storage().unsafeGetStorageImpl());
      },
      "Gets the memory address of the Tensor's StorageImpl.");

  py_module.def(
      "_data_address",
      [](const at::Tensor& tensor) {
        return reinterpret_cast<std::intptr_t>(tensor.storage().data());
      },
      "Gets the memory address of the Tensor's data pointer.");

  py_module.def(
      "_is_cow_tensor",
      [](const at::Tensor& tensor) {
        return c10::impl::cow::is_cow_data_ptr(tensor.storage().data_ptr());
      },
      "Checks if a tensor's data pointer is COW (copy-on-write).");

  const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
  THPDefaultCPUGenerator =
      (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
  // This reference is meant to be given away, so no need to incref here.
  ASSERT_TRUE(set_module_attr(
      "default_generator",
      (PyObject*)THPDefaultCPUGenerator,
      /* incref= */ false));
  ASSERT_TRUE(set_module_attr(
      "DisableTorchFunctionSubclass",
      (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
      /* incref= */ false));
  ASSERT_TRUE(set_module_attr(
      "DisableTorchFunction",
      (PyObject*)THPModule_DisableTorchFunctionType(),
      /* incref= */ false));
  torch::set_disabled_torch_function_impl(
      PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
  ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr);
  torch::set_disabled_torch_dispatch_impl(
      PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl"));
  ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr);
  return module;
  END_HANDLE_TH_ERRORS
}

// Checks that the _C shared library isn't initialized multiple times. This
// can happen if the same csrc files are compiled into multiple shared
// libraries.
inline void pytorch_duplicate_guard() {
  static int initialized = 0;
  if (initialized) {
    fmt::print(stderr, "pytorch: _C shared library re-initialized\n");
    abort();
  }
  initialized = 1;
}

struct call_duplicate_guard {
  call_duplicate_guard() {
    pytorch_duplicate_guard();
  }
};

static call_duplicate_guard _call_duplicate_guard;