pytorch

Форк
0
/
python_interpreter.cpp 
86 строк · 2.6 Кб
1
#include <torch/csrc/jit/runtime/interpreter.h>
2
#include <torch/csrc/python_headers.h>
3

4
#include <torch/csrc/autograd/edge.h>
5
#include <torch/csrc/autograd/function.h>
6
#include <torch/csrc/autograd/profiler.h>
7
#include <torch/csrc/autograd/variable.h>
8
#include <torch/csrc/jit/ir/ir.h>
9
#include <torch/csrc/jit/python/pybind_utils.h>
10
#include <torch/csrc/jit/python/python_ir.h>
11
#include <torch/csrc/jit/runtime/custom_operator.h>
12
#include <torch/csrc/jit/runtime/graph_executor.h>
13
#include <torch/csrc/jit/runtime/operator.h>
14

15
#include <typeinfo>
16

17
#include <pybind11/pybind11.h>
18
#include <torch/csrc/Exceptions.h>
19
#include <torch/csrc/autograd/python_engine.h>
20
#include <torch/csrc/autograd/python_variable.h>
21
#include <torch/csrc/jit/python/pybind.h>
22
#include <torch/csrc/utils/pybind.h>
23

24
namespace py = pybind11;
25

26
namespace torch::jit {
27

28
namespace {
29

30
// Note: const_cast is used twice below to acquire a handle to a pyobject.
31
Operation createPythonOperation(const Node* op_) {
32
  pybind11::gil_scoped_acquire gil;
33
  const ConcretePythonOp* op = static_cast<const ConcretePythonOp*>(op_);
34
  const py::function func = py::reinterpret_borrow<const py::function>(
35
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
36
      py::handle(const_cast<ConcretePythonOp*>(op)->pyobj.get()));
37

38
  size_t num_inputs = 0;
39
  for (auto arg_type : op->cconv) {
40
    if (arg_type == 'd')
41
      num_inputs++;
42
  }
43

44
  AT_ASSERT(op->outputs().size() == 1);
45

46
  return [=](Stack& stack) {
47
    pybind11::gil_scoped_acquire gil;
48
    py::tuple py_inputs(op->cconv.size());
49
    size_t i = 0;
50
    size_t next_scalar = 0;
51
    size_t next_tensor = 0;
52
    for (auto arg_type : op->cconv) {
53
      if (arg_type == 'c') {
54
        py_inputs[i] = py::reinterpret_borrow<const py::object>(
55
            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
56
            const_cast<ConcretePythonOp*>(op)
57
                ->scalar_args[next_scalar++]
58
                .get());
59
      } else if (arg_type == 'd') {
60
        py_inputs[i] =
61
            toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
62
        next_tensor++;
63
      }
64
      i++;
65
    }
66
    drop(stack, num_inputs);
67
    try {
68
      py::object py_output(func(*py_inputs));
69
      stack.push_back(returnToIValue(op->output()->type(), py_output));
70
    } catch (py::error_already_set& e) {
71
      throw std::runtime_error(e.what());
72
    }
73
  };
74
}
75

76
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
77
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
78
}
79

80
RegisterOperators reg({Operator(
81
    prim::PythonOp,
82
    createPythonOperation,
83
    aliasAnalysisIsSpecialCase())});
84

85
} // namespace
86
} // namespace torch::jit
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.