1
#include <torch/csrc/jit/ir/graph_utils.h>
2
#include <torch/csrc/jit/python/module_python.h>
3
#include <torch/csrc/jit/python/pybind_utils.h>
4
#include <torch/csrc/jit/python/python_dict.h>
5
#include <torch/csrc/jit/python/python_ivalue.h>
6
#include <torch/csrc/jit/python/python_list.h>
7
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
9
#include <ATen/ScalarOps.h>
11
#include <c10/core/QScheme.h>
12
#include <c10/util/irange.h>
13
#include <torch/csrc/utils/python_arg_parser.h>
19
// Thread-local flag read by toIValue(): when true, bare Python numbers may be
// implicitly converted to 0-dim ("wrapped number") Tensors.
static thread_local bool allow_numbers_as_tensors = false;
21
// RAII guard: saves the current flag value into old_ and installs `enable`.
// NOTE(review): this fragment is truncated (the file's own line numbers are
// interleaved and closing braces are missing) -- code left byte-identical.
ToIValueAllowNumbersAsTensors::ToIValueAllowNumbersAsTensors(bool enable)
22
: old_(allow_numbers_as_tensors) {
23
allow_numbers_as_tensors = enable;
26
// Destructor restores the flag value captured at construction time.
ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() {
27
allow_numbers_as_tensors = old_;
30
// This is a hack to remove instances deleted in C++ from the PyBind cache
31
// C++->Python. We need this because otherwise we may get the old Python object
32
// if C++ creates a new object at the memory location of the deleted object.
33
// `ptr`: raw address of the destroyed C++ instance.
// NOTE(review): truncated fragment -- loop/function closing braces are
// missing from this view; code left byte-identical.
void clear_registered_instances(void* ptr) {
34
auto& registered_instances =
35
pybind11::detail::get_internals().registered_instances;
36
// The cache is a multimap keyed by the raw pointer, so there may be
// several pybind instances registered at this address.
auto range = registered_instances.equal_range(ptr);
37
for (auto it = range.first; it != range.second; ++it) {
38
auto vh = it->second->get_value_and_holder();
39
// Mark the instance as unregistered before erasing its cache entries so
// pybind does not try to deregister it again on destruction.
vh.set_instance_registered(false);
41
registered_instances.erase(ptr);
44
// WARNING: Precondition for this function is that, e.g., you have tested if a
45
// SymIntList is in fact only ints, and if so, you called this with T=int64_t.
46
// This precondition is NOT checked at runtime.
48
// Converts a Python iterable into a c10::List<T>-backed IValue.
// NOTE(review): truncated fragment -- the `template <typename T>` header and
// the declarations of `rs` (the accumulating list) and `elm` (the current
// element, presumably `*it`) are missing from this view -- confirm against
// the upstream file. Code left byte-identical.
IValue listToIValue(py::handle obj) {
50
for (auto it = obj.begin(); it != obj.end(); it++) {
52
rs.push_back(py::cast<T>(elm));
54
// Promises that we have decayed the list appropriately
55
return c10::impl::toList<T>(rs);
58
// Converts a Python object `obj` to a TorchScript IValue of the requested
// static `type`, dispatching on type->kind(). `N` (optional) is only
// consulted in the ListType cases, where a single Python int/float is
// broadcast into a fixed-size list of N copies. Throws py::cast_error when
// the object cannot be converted to the requested type.
// NOTE(review): this fragment is heavily truncated -- the file's original
// line numbers are interleaved as standalone lines, and the gaps in that
// numbering show that many statements (closing braces, TORCH_CHECK calls,
// declarations such as `scalar`, `res`, `enum_holder`) were dropped by the
// extraction. Code left byte-identical; confirm any behavioral question
// against the upstream source.
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
59
switch (type->kind()) {
60
case TypeKind::TensorType: {
61
if (obj.ptr() == Py_None) {
62
// None gets converted to undefined Tensors
63
return autograd::Variable();
65
if (THPVariable_Check(obj.ptr())) {
66
auto var = py::cast<autograd::Variable>(obj);
67
guardAgainstNamedTensor<autograd::Variable>(var);
70
// Numbers are only accepted as Tensors when the thread-local
// allow_numbers_as_tensors flag is set (see guard above).
if (!allow_numbers_as_tensors) {
72
c10::str("Unable to cast ", py::str(obj), " to Tensor"));
74
bool save_symint = false;
76
if (PyBool_Check(obj.ptr())) {
77
scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
78
} else if (THPUtils_checkLong(obj.ptr())) {
79
scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
80
} else if (PyComplex_Check(obj.ptr())) {
81
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
82
} else if (THPUtils_checkDouble(obj.ptr())) {
83
scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
84
} else if (torch::is_symint(py::handle(obj))) {
86
// NOTE(review): 7777777 is a placeholder scalar for symbolic ints;
// presumably the real symbolic value is carried via the
// `_wrapped_number` attribute stored below -- confirm upstream.
scalar = at::Scalar(7777777);
87
} else if (torch::is_symfloat(py::handle(obj))) {
89
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
90
} else if (torch::is_symbool(py::handle(obj))) {
92
scalar = at::Scalar(true);
95
c10::str("Unable to cast ", py::str(obj), " to Tensor"));
97
at::Tensor tensor = at::scalar_to_tensor(scalar);
98
// Mark the tensor as a "wrapped number" so type promotion treats it
// like a Python scalar rather than a real tensor.
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
101
auto py_tensor = py::cast(tensor);
102
if (PyObject_SetAttrString(
103
py_tensor.ptr(), "_wrapped_number", obj.ptr()) < 0) {
104
throw python_error();
111
case TypeKind::StorageType:
112
return py::cast<at::Storage>(obj);
113
case TypeKind::FloatType:
114
if (torch::is_symfloat(py::handle(obj))) {
115
// Symbolic floats are guarded (forced to a concrete value) here.
return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
117
return py::cast<double>(obj);
118
case TypeKind::ComplexType: {
119
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
120
return static_cast<c10::complex<double>>(c_obj);
122
case TypeKind::IntType:
123
// TODO: Properly fake this type
124
if (THPQScheme_Check(obj.ptr())) {
125
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
126
return static_cast<uint8_t>(qscheme->qscheme);
128
// For backwards compatibility
129
if (THPDtype_Check(obj.ptr())) {
130
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
131
return static_cast<int64_t>(dtype->scalar_type);
133
if (THPQScheme_Check(obj.ptr())) {
134
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
135
return static_cast<uint8_t>(qscheme->qscheme);
137
if (THPLayout_Check(obj.ptr())) {
138
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
139
return static_cast<int8_t>(layout->layout);
141
if (THPMemoryFormat_Check(obj.ptr())) {
142
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
143
return static_cast<int8_t>(memory_format->memory_format);
145
if (torch::is_symint(py::handle(obj))) {
146
return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
148
return py::cast<int64_t>(obj);
149
case TypeKind::LayoutType: {
150
if (THPLayout_Check(obj.ptr())) {
151
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
152
return static_cast<int8_t>(layout->layout);
154
// For backwards compatibility
155
return py::cast<int64_t>(obj);
157
case TypeKind::ScalarTypeType: {
158
if (THPDtype_Check(obj.ptr())) {
159
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
160
return static_cast<int64_t>(dtype->scalar_type);
162
// For backwards compatibility
163
return py::cast<int64_t>(obj);
165
case TypeKind::MemoryFormatType: {
166
if (THPMemoryFormat_Check(obj.ptr())) {
167
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
168
return static_cast<int8_t>(memory_format->memory_format);
170
// For backwards compatibility
171
return py::cast<int64_t>(obj);
173
// Sym* types keep the symbolic value when present, otherwise fall back
// to the concrete Python number.
case TypeKind::SymIntType:
174
if (torch::is_symint(obj.ptr())) {
175
return py::cast<c10::SymInt>(obj);
177
return py::cast<int64_t>(obj);
178
case TypeKind::SymFloatType:
179
if (torch::is_symfloat(obj.ptr())) {
180
return py::cast<c10::SymFloat>(obj);
182
return py::cast<double>(obj);
183
case TypeKind::SymBoolType:
184
if (torch::is_symbool(obj.ptr())) {
185
return py::cast<c10::SymBool>(obj);
187
return py::cast<bool>(obj);
188
case TypeKind::NoneType:
189
if (!obj.is_none()) {
190
throw py::cast_error(
191
c10::str("Cannot cast ", py::str(obj), " to None"));
194
case TypeKind::BoolType:
195
if (torch::is_symbool(obj.ptr())) {
196
return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
198
return py::cast<bool>(obj);
199
case TypeKind::TupleType: {
200
// Element-wise conversion; tuple length must match the static type.
py::tuple tuple = py::cast<py::tuple>(obj);
201
size_t tuple_size = tuple.size();
202
auto tuple_type = type->cast<TupleType>();
203
const auto& elem_types = tuple_type->elements();
204
if (elem_types.size() != tuple_size) {
205
throw py::cast_error(c10::str(
208
" had a different number of elements than type ",
211
std::vector<IValue> values;
212
values.reserve(tuple_size);
213
for (const auto i : c10::irange(tuple_size)) {
214
values.push_back(toIValue(tuple[i], elem_types[i]));
216
return tuple_type->name()
217
? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
218
: c10::ivalue::Tuple::create(std::move(values));
220
case TypeKind::UnionType: {
221
// Infer the dynamic type first, then check union membership.
auto actual_type = toTypeInferredIValue(obj);
222
auto actual_type_ptr = actual_type.type();
223
auto union_type = type->expect<UnionType>();
224
if (!actual_type_ptr->isSubtypeOf(union_type)) {
225
throw py::cast_error(c10::str(
226
"Expected a member of ",
227
union_type->annotation_str(),
228
" but instead found type ",
229
actual_type.type()->annotation_str()));
233
case TypeKind::StringType:
234
return ConstantString::create(py::cast<std::string>(obj));
235
case TypeKind::DeviceObjType: {
236
if (THPDevice_Check(obj.ptr())) {
237
auto device = reinterpret_cast<THPDevice*>(obj.ptr());
238
return device->device;
240
// Fall back to parsing a device string like "cuda:0".
return c10::Device(py::cast<std::string>(obj.ptr()));
242
case TypeKind::StreamObjType: {
243
auto thp_stream = reinterpret_cast<THPStream*>(obj.ptr());
244
auto stream = c10::Stream::unpack3(
245
thp_stream->stream_id,
246
thp_stream->device_index,
247
static_cast<c10::DeviceType>(thp_stream->device_type));
250
case TypeKind::ListType: {
251
// If the object is a ScriptList, retrieve the c10::List
252
// instance inside it.
253
if (py::isinstance<ScriptList>(obj)) {
254
return py::cast<ScriptList>(obj).list_;
257
// If not (i.e. it is a regular Python list), make a new
259
const auto& elem_type = type->expectRef<ListType>().getElementType();
260
switch (elem_type->kind()) {
261
// allows single int/float to be broadcasted to a fixed size list
262
case TypeKind::IntType:
263
if (!N || !py::isinstance<py::int_>(obj)) {
264
return IValue(py::cast<std::vector<int64_t>>(obj));
266
int64_t value = py::cast<int64_t>(obj);
267
c10::List<int64_t> repeated;
268
repeated.reserve(*N);
269
for (int i = 0; i < *N; ++i) {
270
repeated.push_back(value);
274
// For Sym* element types: scan the list; if any element is symbolic,
// build a symbolic list, otherwise decay to the concrete type
// (see listToIValue's precondition).
case TypeKind::SymIntType: {
275
bool is_symbolic = false;
276
for (auto it = obj.begin(); it != obj.end(); it++) {
278
if (torch::is_symint(elm)) {
284
return listToIValue<c10::SymInt>(obj);
286
return listToIValue<int64_t>(obj);
289
case TypeKind::SymFloatType: {
290
bool is_symbolic = false;
291
for (auto it = obj.begin(); it != obj.end(); it++) {
293
// TODO: what about SymInt conversion to SymFloat?
294
if (torch::is_symfloat(elm)) {
300
return listToIValue<c10::SymFloat>(obj);
302
return listToIValue<double>(obj);
305
case TypeKind::SymBoolType: {
306
bool is_symbolic = false;
307
for (auto it = obj.begin(); it != obj.end(); it++) {
309
if (torch::is_symbool(elm)) {
315
return listToIValue<c10::SymBool>(obj);
317
return listToIValue<bool>(obj);
320
case TypeKind::FloatType:
321
if (!N || !py::isinstance<py::float_>(obj)) {
322
return IValue(py::cast<std::vector<double>>(obj));
324
double value = py::cast<double>(obj);
325
c10::List<double> repeated;
326
repeated.reserve(*N);
327
for (int i = 0; i < *N; ++i) {
328
repeated.push_back(value);
332
case TypeKind::BoolType:
333
return IValue(py::cast<std::vector<bool>>(obj));
334
case TypeKind::TensorType:
335
return IValue(py::cast<std::vector<at::Tensor>>(obj));
337
return createGenericList(obj, elem_type);
340
case TypeKind::DictType: {
341
const auto& dict_type = type->expect<DictType>();
343
// If the object is a ScriptDict, retrieve the c10::Dict
344
// instance inside it.
346
auto script_dict = py::cast<ScriptDict>(obj);
347
return script_dict.dict_;
348
} catch (py::cast_error& e) {
351
// If not (i.e. it is a regular Python dictionary), make a new
353
return createGenericDict(
354
py::cast<py::dict>(obj),
355
dict_type->getKeyType(),
356
dict_type->getValueType());
358
case TypeKind::OptionalType: {
359
// check if it's a none obj since optional accepts NoneType
361
// check if it's a none obj since optional accepts NoneType
362
// return an IValue() to denote a NoneType
365
// Non-None values recurse with the unwrapped element type.
return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
367
case TypeKind::ClassType: {
368
auto classType = type->expect<ClassType>();
369
auto object = py::cast<py::object>(obj);
370
if (auto mod = as_module(object)) {
371
// if obj is already a ScriptModule, just return its ivalue
372
return mod.value()._ivalue();
375
// Check if the obj is a ScriptObject.
376
if (auto script_obj = as_object(object)) {
377
return script_obj.value()._ivalue();
380
// otherwise is a normal class object, we create a fresh
381
// ivalue::Object to use from the py object.
382
// 1. create a bare ivalue
383
const size_t numAttrs = classType->numAttributes();
384
auto cu = classType->compilation_unit();
385
auto userObj = c10::ivalue::Object::create(
386
c10::StrongTypePtr(cu, classType), numAttrs);
388
// 2. copy all the contained types
389
for (const auto slot : c10::irange(numAttrs)) {
390
const auto& attrType = classType->getAttribute(slot);
391
const auto& attrName = classType->getAttributeName(slot);
393
if (!py::hasattr(obj, attrName.c_str())) {
394
throw py::cast_error(c10::str(
395
"Tried to cast object to type ",
398
" was missing attribute ",
403
const auto& contained = py::getattr(obj, attrName.c_str());
404
userObj->setSlot(slot, toIValue(contained, attrType));
405
} catch (std::exception& e) {
406
throw py::cast_error(c10::str(
407
"Could not cast attribute '",
410
attrType->repr_str(),
417
case TypeKind::InterfaceType: {
418
auto interfaceType = type->expect<InterfaceType>();
419
// When converting an pyobj to an interface, we check if rhs
420
// is module or normal torchscript class, get the type and ivalue
421
// from them correspondingly.
422
c10::ClassTypePtr classType = nullptr;
424
if (auto mod = as_module(py::cast<py::object>(obj))) {
425
classType = mod.value().type();
426
res = mod.value()._ivalue();
427
} else if (auto object = as_object(py::cast<py::object>(obj))) {
428
classType = object.value().type();
429
res = object.value()._ivalue();
431
// We inspect the value to found the compiled TorchScript class
432
// and then create a ivalue::Object from that class type.
433
py::str qualified_name = py::module::import("torch._jit_internal")
434
.attr("_qualified_name")(obj.get_type());
435
auto pyCu = get_python_cu();
436
classType = pyCu->get_class(c10::QualifiedName(qualified_name));
438
throw std::runtime_error(c10::str(
439
"Assigning the object ",
441
" to an interface fails because the value is not "
442
"a TorchScript compatible type, did you forget to",
443
"turn it into a user defined TorchScript class?"));
445
res = toIValue(obj, classType);
447
// check if the classType conform with the interface or not
448
std::stringstream why_not;
449
if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
450
throw py::cast_error(c10::str(
452
classType->repr_str(),
453
" is not compatible with interface ",
454
interfaceType->repr_str(),
460
case TypeKind::NumberType: {
461
if (THPDtype_Check(obj.ptr())) {
462
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
463
return static_cast<int64_t>(dtype->scalar_type);
465
if (THPQScheme_Check(obj.ptr())) {
466
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
467
return static_cast<uint8_t>(qscheme->qscheme);
469
if (THPLayout_Check(obj.ptr())) {
470
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
471
return static_cast<int8_t>(layout->layout);
473
// Note: bool is checked before int because Python bool is a subtype
// of int.
if (py::isinstance<py::bool_>(obj)) {
474
return py::cast<bool>(obj);
475
} else if (py::isinstance<py::int_>(obj)) {
476
return py::cast<int64_t>(obj);
477
} else if (py::isinstance<py::float_>(obj)) {
478
return py::cast<double>(obj);
479
} else if (PyComplex_CheckExact(obj.ptr())) {
480
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
481
return static_cast<c10::complex<double>>(c_obj);
482
} else if (torch::is_symint(obj)) {
483
return py::cast<c10::SymInt>(obj);
484
} else if (torch::is_symfloat(obj)) {
485
return py::cast<c10::SymFloat>(obj);
486
} else if (torch::is_symbool(obj)) {
487
return py::cast<c10::SymBool>(obj);
489
throw py::cast_error(
490
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
493
case TypeKind::RRefType: {
495
return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
497
AT_ERROR("RRef is only supported with the distributed package");
500
case TypeKind::PyObjectType: {
501
return c10::ivalue::ConcretePyObjectHolder::create(obj);
503
case TypeKind::CapsuleType: {
504
return IValue::make_capsule(py::cast<c10::Capsule>(obj).obj_ptr);
506
case TypeKind::FutureType: {
507
return obj.cast<std::shared_ptr<PythonFutureWrapper>>()->fut;
509
case TypeKind::AwaitType: {
510
return obj.cast<std::shared_ptr<PythonAwaitWrapper>>()->aw_;
512
case TypeKind::AnyType:
513
return toTypeInferredIValue(obj);
514
case TypeKind::QSchemeType: {
515
if (py::isinstance<py::int_>(obj)) {
516
return static_cast<at::QScheme>(py::cast<int64_t>(obj));
518
throw py::cast_error(
519
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
521
case TypeKind::GeneratorType:
522
return py::cast<at::Generator>(obj);
523
case TypeKind::DynamicType:
524
case TypeKind::FunctionType:
525
case TypeKind::QuantizerType:
526
case TypeKind::VarType:
527
case TypeKind::AnyListType:
528
case TypeKind::AnyTupleType:
529
case TypeKind::AnyClassType:
530
case TypeKind::AnyEnumType:
532
case TypeKind::EnumType:
533
// Enum conversion: read the Python enum's `name`/`value` attributes and
// wrap them in an EnumHolder.
EnumTypePtr enum_type = type->expect<EnumType>();
534
py::object py_obj = py::reinterpret_borrow<py::object>(obj);
535
std::string name = py::cast<std::string>(obj.attr("name"));
536
IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {});
538
c10::make_intrusive<c10::ivalue::EnumHolder>(enum_type, name, value);
539
return IValue(enum_holder);
541
throw py::cast_error(c10::str(
542
"toIValue() cannot handle converting to type: ", type->repr_str()));
545
// Converts an IValue back into a Python object (the inverse of toIValue).
// Takes `ivalue` by value so subobjects can be moved out of it.
// NOTE(review): truncated fragment -- interleaved original line numbers and
// gaps show that many statements (e.g. `return py::none();`, closing braces,
// some TORCH_CHECK/declaration lines such as `py_dict` and `RRefPtr`) were
// dropped by the extraction. Code left byte-identical.
py::object toPyObject(IValue ivalue) {
546
if (ivalue.isNone()) {
548
} else if (ivalue.isTensor()) {
549
auto tensor = std::move(ivalue).toTensor();
550
// "Wrapped numbers" (0-dim tensors created from Python scalars in
// toIValue) are unwrapped back to Python scalars here.
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
551
TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
552
auto py_tensor = py::cast(tensor);
553
if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
554
// Prefer the original Python object stashed by toIValue.
return py_tensor.attr("_wrapped_number");
556
auto scalar_type = tensor.scalar_type();
557
switch (scalar_type) {
558
case at::ScalarType::Bool:
559
return py::cast(*tensor.const_data_ptr<bool>());
560
case at::ScalarType::Long:
561
return py::cast(*tensor.const_data_ptr<int64_t>());
562
case at::ScalarType::Double:
563
return py::cast(*tensor.const_data_ptr<double>());
564
case at::ScalarType::ComplexDouble:
565
// TODO: https://github.com/pytorch/pytorch/issues/77134
566
return py::cast(static_cast<std::complex<double>>(
567
*tensor.const_data_ptr<c10::complex<double>>()));
571
"Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
573
" to a Python object");
576
guardAgainstNamedTensor<at::Tensor>(tensor);
577
return py::cast(autograd::Variable(std::move(tensor)));
579
} else if (ivalue.isStorage()) {
580
return py::cast(std::move(ivalue).toStorage());
581
} else if (ivalue.isGenerator()) {
582
return py::cast(std::move(ivalue).toGenerator());
583
} else if (ivalue.isDouble()) {
584
return py::cast(std::move(ivalue).toDouble());
585
} else if (ivalue.isComplexDouble()) {
587
static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
588
} else if (ivalue.isInt()) {
589
return py::cast(std::move(ivalue).toInt());
590
} else if (ivalue.isBool()) {
591
return py::cast(std::move(ivalue).toBool());
592
} else if (ivalue.isString()) {
593
if (getUTF8DecodingIgnore()) {
594
// Decode ignoring invalid UTF-8 bytes when the thread-local
// "utf8 decoding ignore" mode is active.
std::string s = std::move(ivalue).toStringRef();
595
PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
596
return py::reinterpret_steal<py::object>(pyObj);
598
return py::cast(std::move(ivalue).toStringRef());
600
} else if (ivalue.isList()) {
601
auto list = std::move(ivalue).toList();
602
py::list t{list.size()};
603
for (const auto i : c10::irange(list.size())) {
604
t[i] = toPyObject(IValue{list.get(i)});
607
} else if (ivalue.isTuple()) {
608
auto tuple = std::move(ivalue).toTuple();
609
const auto& elements = tuple->elements();
611
py::tuple t{elements.size()};
612
for (const auto i : c10::irange(elements.size())) {
613
t[i] = toPyObject(IValue{elements.at(i)});
616
// If we have a NamedTuple
617
if (tuple->type() && tuple->type()->schema() &&
618
!tuple->type()->schema()->name().empty()) {
619
auto unqualName = tuple->type()->name()->name();
621
std::vector<Argument> tuple_args = tuple->type()->schema()->arguments();
623
std::vector<pybind11::object> defaults;
624
// Collect trailing default values (arguments after the first one that
// has a default).
auto it = std::find_if(
625
tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
626
return arg.default_value().has_value();
631
std::back_inserter(defaults),
632
[](const Argument& arg) { return toPyObject(*arg.default_value()); });
634
std::vector<std::string> fieldNames =
635
fmap(tuple_args, [](const Argument& arg) { return arg.name(); });
637
return py::module::import("torch._jit_internal")
638
.attr("_create_named_tuple")(
639
t, unqualName, fieldNames, py::make_tuple(defaults));
643
} else if (ivalue.isDevice()) {
644
return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
645
} else if (ivalue.isStream()) {
646
return py::cast(std::move(ivalue).toStream());
647
} else if (ivalue.isGenericDict()) {
648
auto dict = std::move(ivalue).toGenericDict();
650
for (auto& pair : dict) {
651
py_dict[toPyObject(IValue{pair.key()})] =
652
toPyObject(IValue{pair.value()});
654
return std::move(py_dict);
655
} else if (ivalue.isRRef()) {
658
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
659
std::move(ivalue).toRRef());
660
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
662
AT_ERROR("RRef is only supported with the distributed package");
664
} else if (ivalue.isObject()) {
665
const auto obj = std::move(ivalue).toObject();
666
if (obj->type()->is_module()) {
667
return py::cast(Module(obj));
670
auto pyCu = get_python_cu();
671
if (obj->name().find("__torch__.torch.classes") == 0) {
672
// Custom-class (torchbind) objects are returned as script Objects.
return py::cast(Object(obj));
674
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
675
AT_ASSERT(classType);
676
auto pyClass = getScriptedClassOrError(obj->type());
677
// __new__ bypasses the class __init__ (which may have side effects);
// attributes are restored slot by slot below.
auto pyObj = pyClass.attr("__new__")(pyClass);
679
const auto numAttrs = classType->numAttributes();
681
for (const auto slot : c10::irange(numAttrs)) {
682
const auto& attrName = classType->getAttributeName(slot);
683
IValue v = obj->getSlot(slot);
684
py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
687
} else if (ivalue.isPyObject()) {
688
// return borrowed reference to ensure it correctly incref the underlying
690
return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
691
} else if (ivalue.isCapsule()) {
692
return py::cast(c10::Capsule(ivalue.toCapsule()));
693
} else if (ivalue.isFuture()) {
694
return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
695
} else if (ivalue.isAwait()) {
696
return py::cast(std::make_shared<PythonAwaitWrapper>(ivalue.toAwait()));
697
} else if (ivalue.isEnum()) {
698
auto enum_holder = ivalue.toEnumHolder();
699
auto py_class = getScriptedClassOrError(enum_holder->type());
700
return py_class.attr(enum_holder->name().c_str());
701
} else if (ivalue.isRRef()) {
703
return py::cast(torch::distributed::rpc::PyRRef(
704
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
707
TORCH_CHECK(false, "RRef is only supported with the distributed package");
709
} else if (ivalue.isSymInt()) {
710
return py::cast(std::move(ivalue).toSymInt());
711
} else if (ivalue.isSymFloat()) {
712
return py::cast(std::move(ivalue).toSymFloat());
713
} else if (ivalue.isSymBool()) {
714
return py::cast(std::move(ivalue).toSymBool());
717
"Missing cases in 'toPyObject'! Can't convert ",
719
" to a Python object");
723
// Resolves which overload of `operations` matches the given Python
// args/kwargs and builds the interpreter Stack for it.
// Single-candidate fast path: build the stack directly (schema errors
// propagate). Multi-candidate path: try each schema, collecting
// schema_match_error per failure; if none match, aggregate the errors
// into one runtime_error.
// NOTE(review): truncated fragment -- the `py::args args` parameter line,
// the `Stack stack;` declarations, the `try {` lines and several closing
// braces are missing from this view. Code left byte-identical.
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
724
const std::vector<std::shared_ptr<Operator>>& operations,
726
const py::kwargs& kwargs) {
728
if (operations.size() == 1) {
729
std::shared_ptr<Operator> op = operations.at(0);
730
// Create a stack full of the arguments and keyword arguments.
731
stack = createStackForSchema(
732
op->schema(), std::move(args), kwargs, c10::nullopt);
734
return std::make_pair(std::move(op), std::move(stack));
736
std::vector<schema_match_error> errors;
737
std::shared_ptr<Operator> found_op = nullptr;
738
for (const auto& op : operations) {
740
stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt);
743
} catch (schema_match_error& error) {
744
errors.push_back(std::move(error));
748
std::stringstream ss;
749
ss << "Overloaded torch operator invoked from Python failed to match any schema:\n";
750
for (const auto& err : errors) {
751
ss << err.what() << "\n\n";
753
throw std::runtime_error(ss.str());
756
return std::make_pair(std::move(found_op), std::move(stack));
760
// Resolves an operator overload via getOpWithStack, runs it with the GIL
// released, and converts the resulting stack back to a Python object.
// `dk`: optional dispatch key; when set, the operation registered for that
// key is invoked instead of the default one.
// NOTE(review): truncated fragment -- the `py::args args` parameter line and
// the `if (dk) { ... } else { ... }` branch lines are missing from this
// view (both call forms appear consecutively below). Code left
// byte-identical, including the stray space in
// `getOperationForDispatchKey (*dk)`.
py::object invokeOperatorFromPython(
761
const std::vector<std::shared_ptr<Operator>>& operations,
763
const py::kwargs& kwargs,
764
c10::optional<c10::DispatchKey> dk) {
765
auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
767
// Release the GIL while the (potentially long-running) op executes.
pybind11::gil_scoped_release no_gil_guard;
769
found_op->getOperationForDispatchKey (*dk)(stack);
771
found_op->getOperation()(stack);
775
return createPyObjectForStack(std::move(stack));
778
// Entry point for calling a torch.ops operator (or overload packet) from
// Python. Scans args/kwargs for Tensor subclasses that override
// __torch_function__; if any are found (or a torch-function mode is active),
// delegates to the __torch_function__ handler, otherwise invokes the
// operator directly via invokeOperatorFromPython.
// NOTE(review): truncated fragment -- the `Symbol symbol`/`py::args args`
// parameter lines and several argument lines of the two
// is_tensor_list_and_append_overloaded calls and of
// handle_torch_function_no_python_arg_parser are missing from this view.
// Code left byte-identical.
py::object _get_operation_for_overload_or_packet(
779
const std::vector<std::shared_ptr<Operator>>& operations,
782
const py::kwargs& kwargs,
784
c10::optional<c10::DispatchKey> dk) {
785
std::vector<PyObject*> overloaded_args;
786
size_t total_arg_num = args.size() + kwargs.size();
787
// Collect positional arguments that carry a __torch_function__ override.
for (const auto i : c10::irange(args.size())) {
788
is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
789
is_tensor_list_and_append_overloaded(
792
static_cast<int>(total_arg_num),
793
false /* throw_error */);
795
// NB: for kwargs, we cannot guarantee the order of appending
796
// is the same as the argument order in operator's schema.
797
// This is suboptimal, but should be fine. Later when we have
798
// better schema matching and argument parsing, we could
799
// match the operator in `operations` first, then the order will
801
for (auto item : kwargs) {
802
is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
803
is_tensor_list_and_append_overloaded(
807
false /* throw_error */);
809
if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
811
std::string ns = symbol.ns().toUnqualString();
812
std::string method_name = symbol.toUnqualString();
813
// Look up the torch.ops.<ns>.<name> callable to hand to
// __torch_function__ as the "original" function.
auto self_func = py::module::import("torch")
816
.attr(method_name.c_str());
818
auto overload_name = operations[0]->schema().overload_name();
819
if (overload_name.empty()) {
820
self_func = self_func.attr("default");
822
self_func = self_func.attr(overload_name.c_str());
825
std::string module_name("torch.ops");
826
module_name.append(ns);
827
return pybind11::reinterpret_steal<py::object>(
828
handle_torch_function_no_python_arg_parser(
834
module_name.c_str()));
836
// No overrides present: dispatch straight to the operator.
return invokeOperatorFromPython(operations, args, kwargs, dk);
839
} // namespace torch::jit