1
#include <torch/csrc/utils/pybind.h>
2
#include <torch/csrc/utils/python_arg_parser.h>
3
#include <torch/csrc/utils/python_symnode.h>
8
// Converts a Python object into a c10::SymInt, storing the result in `value`.
//
// Inputs are tried in this order:
//   1. An object recognised by torch::is_symint. Its `node` attribute is cast
//      straight to c10::SymNode when it already is a c10::SymNodeImpl
//      instance, and is wrapped in a PythonSymNodeImpl otherwise.
//   2. A one-element tensor whose dtype is integral or bool; the element is
//      read with item() and converted with toSymInt().
//   3. Any object accepted by THPUtils_checkIndex.
//
// The second (unnamed) parameter is not consulted.
// Returns true when `value` was set and false when no case matched.
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (torch::is_symint(src)) {
    auto node = src.attr("node");
    if (py::isinstance<c10::SymNodeImpl>(node)) {
      // The node is already a c10::SymNodeImpl: use it as is.
      value = c10::SymInt(py::cast<c10::SymNode>(node));
      return true;
    }

    // Any other node object is held through a PythonSymNodeImpl wrapper.
    value = c10::SymInt(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
    return true;
  }

  auto raw_obj = src.ptr();

  if (THPVariable_Check(raw_obj)) {
    auto& var = THPVariable_Unpack(raw_obj);
    if (var.numel() == 1 &&
        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
      auto scalar = var.item();
      // The dtype guard above deliberately admits bool, so this sanity check
      // has to admit it as well; checking with include-bool == false made a
      // one-element bool tensor fail an internal assertion.
      TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ true));
      value = scalar.toSymInt();
      return true;
    }
  }

  if (THPUtils_checkIndex(raw_obj)) {
    value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
    return true;
  }
  return false;
}
41
// Converts a c10::SymInt into a Python object.
//
// Symbolic values are returned as an instance of the class provided by
// torch::get_symint_class(): when the underlying node is a
// PythonSymNodeImpl its Python object is handed to that class directly,
// otherwise the node is first converted with py::cast. Non-symbolic values
// are cast from the integer obtained via maybe_as_int().
//
// The policy and parent arguments are ignored. The returned handle has been
// release()d from its owning py::object.
py::handle type_caster<c10::SymInt>::cast(
    const c10::SymInt& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (si.is_symbolic()) {
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
        si.toSymNodeImplUnowned());
    if (py_node) {
      // Return the Python directly (unwrap)
      return torch::get_symint_class()(py_node->getPyObj()).release();
    }

    // Wrap the C++ into Python
    auto inner = py::cast(si.toSymNode());
    if (!inner) {
      // py::cast can hand back an empty object with a Python error pending;
      // surface that error instead of passing a null argument on to the
      // SymInt class call below.
      throw py::error_already_set();
    }
    return torch::get_symint_class()(inner).release();
  }

  auto m = si.maybe_as_int();
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  return py::cast(*m).release();
}
66
// Converts a Python object into a c10::SymFloat, storing the result in
// `value`.
//
// An object recognised by torch::is_symfloat has its `node` attribute wrapped
// in a PythonSymNodeImpl. Anything else is accepted only if
// THPUtils_checkDouble approves it, in which case it is unpacked with
// THPUtils_unpackDouble.
//
// The second (unnamed) parameter is not consulted.
// Returns true when `value` was set and false otherwise.
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (!torch::is_symfloat(src)) {
    // Non-symbolic input: fall back to the plain double path.
    auto py_obj = src.ptr();
    if (!THPUtils_checkDouble(py_obj)) {
      return false;
    }
    value = c10::SymFloat{THPUtils_unpackDouble(py_obj)};
    return true;
  }

  // Symbolic input: keep the Python node alive behind a PythonSymNodeImpl.
  value = c10::SymFloat(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
  return true;
}
81
// Converts a c10::SymFloat into a Python object.
//
// Non-symbolic values are cast from as_float_unchecked(). Symbolic values
// must be backed by a PythonSymNodeImpl (asserted); the node's Python object
// is passed to the class provided by torch::get_symfloat_class().
//
// The policy and parent arguments are ignored. The returned handle has been
// release()d from its owning py::object.
py::handle type_caster<c10::SymFloat>::cast(
    const c10::SymFloat& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_float_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  TORCH_INTERNAL_ASSERT(py_node);
  return torch::get_symfloat_class()(py_node->getPyObj()).release();
}
96
// Converts a Python object into a c10::SymBool, storing the result in
// `value`.
//
// An object recognised by torch::is_symbool has its `node` attribute wrapped
// in a PythonSymNodeImpl. Anything else is accepted only if
// THPUtils_checkBool approves it, in which case it is unpacked with
// THPUtils_unpackBool.
//
// The second (unnamed) parameter is not consulted.
// Returns true when `value` was set and false otherwise.
bool type_caster<c10::SymBool>::load(py::handle src, bool) {
  if (!torch::is_symbool(src)) {
    // Non-symbolic input: fall back to the plain bool path.
    auto py_obj = src.ptr();
    if (!THPUtils_checkBool(py_obj)) {
      return false;
    }
    value = c10::SymBool{THPUtils_unpackBool(py_obj)};
    return true;
  }

  // Symbolic input: keep the Python node alive behind a PythonSymNodeImpl.
  value = c10::SymBool(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
  return true;
}
111
// Converts a c10::SymBool into a Python object.
//
// When maybe_as_bool() yields a value, that value is cast directly.
// Otherwise the node must be backed by a PythonSymNodeImpl (asserted), and
// its Python object is passed to the class provided by
// torch::get_symbool_class().
//
// The policy and parent arguments are ignored. The returned handle has been
// release()d from its owning py::object.
py::handle type_caster<c10::SymBool>::cast(
    const c10::SymBool& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  auto known = si.maybe_as_bool();
  if (known) {
    return py::cast(*known).release();
  }

  // TODO: generalize this to work with C++ backed class
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  TORCH_INTERNAL_ASSERT(py_node);
  return torch::get_symbool_class()(py_node->getPyObj()).release();
}
126
// Python -> c10::Scalar loading is not implemented: the body consists solely
// of an internal assertion whose condition is always false, so `value` is
// never set and neither argument is inspected.
bool type_caster<c10::Scalar>::load(py::handle src, bool) {
  TORCH_INTERNAL_ASSERT(
      false,
      "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}
131
// Converts a c10::Scalar into a Python object, dispatching on the scalar's
// kind:
//   - integral (bool excluded): SymInt when symbolic, otherwise the UInt64
//     value for UInt64 scalars and the Long value for everything else;
//   - floating point: SymFloat when symbolic, otherwise the double value;
//   - boolean: SymBool when symbolic, otherwise the bool value;
//   - complex: the complex-double value.
// Any other kind fails an internal assertion.
//
// The policy and parent arguments are ignored. The returned handle has been
// release()d from its owning py::object.
py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // We have to be careful here; we cannot unconditionally route through
    // SymInt because integer data from Tensors can easily be MIN_INT or
    // very negative, which conflicts with the allocated range.
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymInt()).release();
    }
    if (scalar.type() == at::ScalarType::UInt64) {
      return py::cast(scalar.toUInt64()).release();
    }
    return py::cast(scalar.toLong()).release();
  }

  if (scalar.isFloatingPoint()) {
    // This isn't strictly necessary but we add it for symmetry
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymFloat()).release();
    }
    return py::cast(scalar.toDouble()).release();
  }

  if (scalar.isBoolean()) {
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymBool()).release();
    }
    return py::cast(scalar.toBool()).release();
  }

  if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  }

  TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
}
168
} // namespace pybind11