pytorch

Форк
0
/
tensor_qschemes.cpp 
45 строк · 1.3 Кб
1
#include <torch/csrc/utils/tensor_qschemes.h>
2

3
#include <c10/core/QScheme.h>
4
#include <c10/util/irange.h>
5
#include <torch/csrc/DynamicTypes.h>
6
#include <torch/csrc/Exceptions.h>
7
#include <torch/csrc/QScheme.h>
8

9
#include <torch/csrc/python_headers.h>
10
#include <torch/csrc/utils/object_ptr.h>
11

12
namespace torch {
13
namespace utils {
14

15
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
16
static PyObject* thp_qscheme_array[at::COMPILE_TIME_NUM_QSCHEMES];
17

18
void initializeQSchemes() {
19
  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
20
  if (!torch_module) {
21
    throw python_error();
22
  }
23

24
  for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) {
25
    auto qscheme = static_cast<at::QScheme>(i);
26
    PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme));
27
    thp_qscheme_array[static_cast<int>(qscheme)] = qscheme_obj;
28
    Py_INCREF(qscheme_obj);
29
    if (PyModule_AddObject(
30
            torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) {
31
      throw python_error();
32
    }
33
  }
34
}
35

36
PyObject* getTHPQScheme(at::QScheme qscheme) {
37
  auto qscheme_ = thp_qscheme_array[static_cast<int>(qscheme)];
38
  if (!qscheme_) {
39
    throw std::invalid_argument("unsupported QScheme");
40
  }
41
  return qscheme_;
42
}
43

44
} // namespace utils
45
} // namespace torch
46

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.