pytorch

Форк
0
/
pybind.cpp 
168 строк · 5.1 Кб
1
#include <torch/csrc/utils/pybind.h>
2
#include <torch/csrc/utils/python_arg_parser.h>
3
#include <torch/csrc/utils/python_symnode.h>
4

5
namespace pybind11 {
6
namespace detail {
7

8
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
9
  if (torch::is_symint(src)) {
10
    auto node = src.attr("node");
11
    if (py::isinstance<c10::SymNodeImpl>(node)) {
12
      value = c10::SymInt(py::cast<c10::SymNode>(node));
13
      return true;
14
    }
15

16
    value = c10::SymInt(static_cast<c10::SymNode>(
17
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
18
    return true;
19
  }
20

21
  auto raw_obj = src.ptr();
22

23
  if (THPVariable_Check(raw_obj)) {
24
    auto& var = THPVariable_Unpack(raw_obj);
25
    if (var.numel() == 1 &&
26
        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
27
      auto scalar = var.item();
28
      TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
29
      value = scalar.toSymInt();
30
      return true;
31
    }
32
  }
33

34
  if (THPUtils_checkIndex(raw_obj)) {
35
    value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
36
    return true;
37
  }
38
  return false;
39
}
40

41
py::handle type_caster<c10::SymInt>::cast(
42
    const c10::SymInt& si,
43
    return_value_policy /* policy */,
44
    handle /* parent */) {
45
  if (si.is_symbolic()) {
46
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
47
        si.toSymNodeImplUnowned());
48
    if (py_node) {
49
      // Return the Python directly (unwrap)
50
      return torch::get_symint_class()(py_node->getPyObj()).release();
51
    } else {
52
      // Wrap the C++ into Python
53
      auto inner = py::cast(si.toSymNode());
54
      if (!inner) {
55
        throw python_error();
56
      }
57
      return torch::get_symint_class()(inner).release();
58
    }
59
  } else {
60
    auto m = si.maybe_as_int();
61
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
62
    return py::cast(*m).release();
63
  }
64
}
65

66
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
67
  if (torch::is_symfloat(src)) {
68
    value = c10::SymFloat(static_cast<c10::SymNode>(
69
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
70
    return true;
71
  }
72

73
  auto raw_obj = src.ptr();
74
  if (THPUtils_checkDouble(raw_obj)) {
75
    value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
76
    return true;
77
  }
78
  return false;
79
}
80

81
py::handle type_caster<c10::SymFloat>::cast(
82
    const c10::SymFloat& si,
83
    return_value_policy /* policy */,
84
    handle /* parent */) {
85
  if (si.is_symbolic()) {
86
    // TODO: generalize this to work with C++ backed class
87
    auto* py_node =
88
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
89
    TORCH_INTERNAL_ASSERT(py_node);
90
    return torch::get_symfloat_class()(py_node->getPyObj()).release();
91
  } else {
92
    return py::cast(si.as_float_unchecked()).release();
93
  }
94
}
95

96
bool type_caster<c10::SymBool>::load(py::handle src, bool) {
97
  if (torch::is_symbool(src)) {
98
    value = c10::SymBool(static_cast<c10::SymNode>(
99
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
100
    return true;
101
  }
102

103
  auto raw_obj = src.ptr();
104
  if (THPUtils_checkBool(raw_obj)) {
105
    value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
106
    return true;
107
  }
108
  return false;
109
}
110

111
py::handle type_caster<c10::SymBool>::cast(
112
    const c10::SymBool& si,
113
    return_value_policy /* policy */,
114
    handle /* parent */) {
115
  if (auto m = si.maybe_as_bool()) {
116
    return py::cast(*m).release();
117
  } else {
118
    // TODO: generalize this to work with C++ backed class
119
    auto* py_node =
120
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
121
    TORCH_INTERNAL_ASSERT(py_node);
122
    return torch::get_symbool_class()(py_node->getPyObj()).release();
123
  }
124
}
125

126
bool type_caster<c10::Scalar>::load(py::handle src, bool) {
127
  TORCH_INTERNAL_ASSERT(
128
      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
129
}
130

131
py::handle type_caster<c10::Scalar>::cast(
132
    const c10::Scalar& scalar,
133
    return_value_policy /* policy */,
134
    handle /* parent */) {
135
  if (scalar.isIntegral(/*includeBool*/ false)) {
136
    // We have to be careful here; we cannot unconditionally route through
137
    // SymInt because integer data from Tensors can easily be MIN_INT or
138
    // very negative, which conflicts with the allocated range.
139
    if (scalar.isSymbolic()) {
140
      return py::cast(scalar.toSymInt()).release();
141
    } else {
142
      if (scalar.type() == at::ScalarType::UInt64) {
143
        return py::cast(scalar.toUInt64()).release();
144
      } else {
145
        return py::cast(scalar.toLong()).release();
146
      }
147
    }
148
  } else if (scalar.isFloatingPoint()) {
149
    // This isn't strictly necessary but we add it for symmetry
150
    if (scalar.isSymbolic()) {
151
      return py::cast(scalar.toSymFloat()).release();
152
    } else {
153
      return py::cast(scalar.toDouble()).release();
154
    }
155
  } else if (scalar.isBoolean()) {
156
    if (scalar.isSymbolic()) {
157
      return py::cast(scalar.toSymBool()).release();
158
    }
159
    return py::cast(scalar.toBool()).release();
160
  } else if (scalar.isComplex()) {
161
    return py::cast(scalar.toComplexDouble()).release();
162
  } else {
163
    TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
164
  }
165
}
166

167
} // namespace detail
168
} // namespace pybind11
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.