pytorch
1#include <torch/csrc/utils/tensor_qschemes.h>
2
3#include <c10/core/QScheme.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/DynamicTypes.h>
6#include <torch/csrc/Exceptions.h>
7#include <torch/csrc/QScheme.h>
8
9#include <torch/csrc/python_headers.h>
10#include <torch/csrc/utils/object_ptr.h>
11
12namespace torch {
13namespace utils {
14
15// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
16static PyObject* thp_qscheme_array[at::COMPILE_TIME_NUM_QSCHEMES];
17
18void initializeQSchemes() {
19auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
20if (!torch_module) {
21throw python_error();
22}
23
24for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) {
25auto qscheme = static_cast<at::QScheme>(i);
26PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme));
27thp_qscheme_array[static_cast<int>(qscheme)] = qscheme_obj;
28Py_INCREF(qscheme_obj);
29if (PyModule_AddObject(
30torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) {
31throw python_error();
32}
33}
34}
35
36PyObject* getTHPQScheme(at::QScheme qscheme) {
37auto qscheme_ = thp_qscheme_array[static_cast<int>(qscheme)];
38if (!qscheme_) {
39throw std::invalid_argument("unsupported QScheme");
40}
41return qscheme_;
42}
43
44} // namespace utils
45} // namespace torch
46