// torch/csrc/jit/python/pybind_utils.cpp
#include <torch/csrc/jit/ir/graph_utils.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_dict.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_list.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

#include <ATen/ScalarOps.h>

#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/csrc/utils/python_arg_parser.h>

#include <limits>

namespace torch::jit {

static thread_local bool allow_numbers_as_tensors = false;

ToIValueAllowNumbersAsTensors::ToIValueAllowNumbersAsTensors(bool enable)
    : old_(allow_numbers_as_tensors) {
  allow_numbers_as_tensors = enable;
}

ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() {
  allow_numbers_as_tensors = old_;
}

// This is a hack to remove instances deleted in C++ from the PyBind
// C++->Python instance cache. We need this because otherwise we may get the
// old Python object if C++ creates a new object at the memory location of the
// deleted object.
void clear_registered_instances(void* ptr) {
  auto& registered_instances =
      pybind11::detail::get_internals().registered_instances;
  auto range = registered_instances.equal_range(ptr);
  for (auto it = range.first; it != range.second; ++it) {
    auto vh = it->second->get_value_and_holder();
    vh.set_instance_registered(false);
  }
  registered_instances.erase(ptr);
}

// WARNING: Precondition for this function is that, e.g., you have tested if a
// SymIntList is in fact only ints, and if so, you called this with T=int64_t.
// This precondition is NOT checked at runtime.
template <typename T>
IValue listToIValue(py::handle obj) {
  c10::List<T> rs;
  for (auto it = obj.begin(); it != obj.end(); it++) {
    auto elm = *it;
    rs.push_back(py::cast<T>(elm));
  }
  // Promises that we have decayed the list appropriately
  return c10::impl::toList<T>(rs);
}

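// Converts a Python object into an IValue of the requested JIT type. `N`, when
// set, lets a single int/float be broadcast into a fixed-size list (see the
// ListType case below). Throws py::cast_error if the object cannot be
// represented as `type`.
//
// Illustrative round trip (a sketch, not code from this file; `py_obj` is a
// hypothetical py::object, and the two-argument call assumes the default value
// of `N` declared in pybind_utils.h):
//
//   IValue iv = toIValue(py_obj, IntType::get());
//   py::object back = toPyObject(std::move(iv));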
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
  switch (type->kind()) {
    case TypeKind::TensorType: {
      if (obj.ptr() == Py_None) {
        // None gets converted to undefined Tensors
        return autograd::Variable();
      }
      if (THPVariable_Check(obj.ptr())) {
        auto var = py::cast<autograd::Variable>(obj);
        guardAgainstNamedTensor<autograd::Variable>(var);
        return var;
      } else {
        if (!allow_numbers_as_tensors) {
          throw py::cast_error(
              c10::str("Unable to cast ", py::str(obj), " to Tensor"));
        }
        bool save_symint = false;
        at::Scalar scalar;
        if (PyBool_Check(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackBool(obj.ptr()));
        } else if (THPUtils_checkLong(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackLong(obj.ptr()));
        } else if (PyComplex_Check(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
        } else if (THPUtils_checkDouble(obj.ptr())) {
          scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
        } else if (torch::is_symint(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(7777777);
        } else if (torch::is_symfloat(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
        } else if (torch::is_symbool(py::handle(obj))) {
          save_symint = true;
          scalar = at::Scalar(true);
        } else {
          throw py::cast_error(
              c10::str("Unable to cast ", py::str(obj), " to Tensor"));
        }
        at::Tensor tensor = at::scalar_to_tensor(scalar);
        tensor.unsafeGetTensorImpl()->set_wrapped_number(true);

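        // For symbolic values the scalar above is only a placeholder; the
        // original SymInt/SymFloat/SymBool is stashed on the Python tensor as
        // `_wrapped_number` so that toPyObject() can recover it later.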
        if (save_symint) {
          auto py_tensor = py::cast(tensor);
          if (PyObject_SetAttrString(
                  py_tensor.ptr(), "_wrapped_number", obj.ptr()) < 0) {
            throw python_error();
          }
        }

        return tensor;
      }
    }
    case TypeKind::StorageType:
      return py::cast<at::Storage>(obj);
    case TypeKind::FloatType:
      if (torch::is_symfloat(py::handle(obj))) {
        return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
      }
      return py::cast<double>(obj);
    case TypeKind::ComplexType: {
      auto c_obj = py::cast<std::complex<double>>(obj.ptr());
      return static_cast<c10::complex<double>>(c_obj);
    }
    case TypeKind::IntType:
      // TODO: Properly fake this type
      if (THPQScheme_Check(obj.ptr())) {
        auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
        return static_cast<uint8_t>(qscheme->qscheme);
      }
      // For backwards compatibility
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      if (THPQScheme_Check(obj.ptr())) {
        auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
        return static_cast<uint8_t>(qscheme->qscheme);
      }
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      if (THPMemoryFormat_Check(obj.ptr())) {
        auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
        return static_cast<int8_t>(memory_format->memory_format);
      }
      if (torch::is_symint(py::handle(obj))) {
        return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
      }
      return py::cast<int64_t>(obj);
    case TypeKind::LayoutType: {
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::ScalarTypeType: {
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::MemoryFormatType: {
      if (THPMemoryFormat_Check(obj.ptr())) {
        auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
        return static_cast<int8_t>(memory_format->memory_format);
      }
      // For backwards compatibility
      return py::cast<int64_t>(obj);
    }
    case TypeKind::SymIntType:
      if (torch::is_symint(obj.ptr())) {
        return py::cast<c10::SymInt>(obj);
      }
      return py::cast<int64_t>(obj);
    case TypeKind::SymFloatType:
      if (torch::is_symfloat(obj.ptr())) {
        return py::cast<c10::SymFloat>(obj);
      }
      return py::cast<double>(obj);
    case TypeKind::SymBoolType:
      if (torch::is_symbool(obj.ptr())) {
        return py::cast<c10::SymBool>(obj);
      }
      return py::cast<bool>(obj);
    case TypeKind::NoneType:
      if (!obj.is_none()) {
        throw py::cast_error(
            c10::str("Cannot cast ", py::str(obj), " to None"));
      }
      return {};
    case TypeKind::BoolType:
      if (torch::is_symbool(obj.ptr())) {
        return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
      }
      return py::cast<bool>(obj);
    case TypeKind::TupleType: {
      py::tuple tuple = py::cast<py::tuple>(obj);
      size_t tuple_size = tuple.size();
      auto tuple_type = type->cast<TupleType>();
      const auto& elem_types = tuple_type->elements();
      if (elem_types.size() != tuple_size) {
        throw py::cast_error(c10::str(
            "Object ",
            py::str(obj),
            " had a different number of elements than type ",
            type->repr_str()));
      }
      std::vector<IValue> values;
      values.reserve(tuple_size);
      for (const auto i : c10::irange(tuple_size)) {
        values.push_back(toIValue(tuple[i], elem_types[i]));
      }
      return tuple_type->name()
          ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
          : c10::ivalue::Tuple::create(std::move(values));
    }
    case TypeKind::UnionType: {
      auto actual_type = toTypeInferredIValue(obj);
      auto actual_type_ptr = actual_type.type();
      auto union_type = type->expect<UnionType>();
      if (!actual_type_ptr->isSubtypeOf(union_type)) {
        throw py::cast_error(c10::str(
            "Expected a member of ",
            union_type->annotation_str(),
            " but instead found type ",
            actual_type.type()->annotation_str()));
      }
      return actual_type;
    }
    case TypeKind::StringType:
      return ConstantString::create(py::cast<std::string>(obj));
    case TypeKind::DeviceObjType: {
      if (THPDevice_Check(obj.ptr())) {
        auto device = reinterpret_cast<THPDevice*>(obj.ptr());
        return device->device;
      }
      return c10::Device(py::cast<std::string>(obj.ptr()));
    }
    case TypeKind::StreamObjType: {
      auto thp_stream = reinterpret_cast<THPStream*>(obj.ptr());
      auto stream = c10::Stream::unpack3(
          thp_stream->stream_id,
          thp_stream->device_index,
          static_cast<c10::DeviceType>(thp_stream->device_type));
      return stream;
    }
    case TypeKind::ListType: {
      // If the object is a ScriptList, retrieve the c10::List
      // instance inside it.
      if (py::isinstance<ScriptList>(obj)) {
        return py::cast<ScriptList>(obj).list_;
      }

      // If not (i.e. it is a regular Python list), make a new
      // c10::List.
      const auto& elem_type = type->expectRef<ListType>().getElementType();
      switch (elem_type->kind()) {
        // allows single int/float to be broadcasted to a fixed size list
        case TypeKind::IntType:
          if (!N || !py::isinstance<py::int_>(obj)) {
            return IValue(py::cast<std::vector<int64_t>>(obj));
          } else {
            int64_t value = py::cast<int64_t>(obj);
            c10::List<int64_t> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::SymIntType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            if (torch::is_symint(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymInt>(obj);
          } else {
            return listToIValue<int64_t>(obj);
          }
        }
        case TypeKind::SymFloatType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            // TODO: what about SymInt conversion to SymFloat?
            if (torch::is_symfloat(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymFloat>(obj);
          } else {
            return listToIValue<double>(obj);
          }
        }
        case TypeKind::SymBoolType: {
          bool is_symbolic = false;
          for (auto it = obj.begin(); it != obj.end(); it++) {
            auto elm = *it;
            if (torch::is_symbool(elm)) {
              is_symbolic = true;
              break;
            }
          }
          if (is_symbolic) {
            return listToIValue<c10::SymBool>(obj);
          } else {
            return listToIValue<bool>(obj);
          }
        }
        case TypeKind::FloatType:
          if (!N || !py::isinstance<py::float_>(obj)) {
            return IValue(py::cast<std::vector<double>>(obj));
          } else {
            double value = py::cast<double>(obj);
            c10::List<double> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::BoolType:
          return IValue(py::cast<std::vector<bool>>(obj));
        case TypeKind::TensorType:
          return IValue(py::cast<std::vector<at::Tensor>>(obj));
        default:
          return createGenericList(obj, elem_type);
      }
    }
    case TypeKind::DictType: {
      const auto& dict_type = type->expect<DictType>();

      // If the object is a ScriptDict, retrieve the c10::Dict
      // instance inside it.
      try {
        auto script_dict = py::cast<ScriptDict>(obj);
        return script_dict.dict_;
      } catch (py::cast_error& e) {
      }

      // If not (i.e. it is a regular Python dictionary), make a new
      // c10::Dict.
      return createGenericDict(
          py::cast<py::dict>(obj),
          dict_type->getKeyType(),
          dict_type->getValueType());
    }
    case TypeKind::OptionalType: {
      // check if it's a none obj since optional accepts NoneType
      if (obj.is_none()) {
        // return an IValue() to denote a NoneType
        return {};
      }
      return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
    }
    case TypeKind::ClassType: {
      auto classType = type->expect<ClassType>();
      auto object = py::cast<py::object>(obj);
      if (auto mod = as_module(object)) {
        // if obj is already a ScriptModule, just return its ivalue
        return mod.value()._ivalue();
      }

      // Check if the obj is a ScriptObject.
      if (auto script_obj = as_object(object)) {
        return script_obj.value()._ivalue();
      }

      // Otherwise it is a normal class object; we create a fresh
      // ivalue::Object from the py object.
      // 1. create a bare ivalue
      const size_t numAttrs = classType->numAttributes();
      auto cu = classType->compilation_unit();
      auto userObj = c10::ivalue::Object::create(
          c10::StrongTypePtr(cu, classType), numAttrs);

      // 2. copy all the contained types
      for (const auto slot : c10::irange(numAttrs)) {
        const auto& attrType = classType->getAttribute(slot);
        const auto& attrName = classType->getAttributeName(slot);

        if (!py::hasattr(obj, attrName.c_str())) {
          throw py::cast_error(c10::str(
              "Tried to cast object to type ",
              type->repr_str(),
              " but object",
              " was missing attribute ",
              attrName));
        }

        try {
          const auto& contained = py::getattr(obj, attrName.c_str());
          userObj->setSlot(slot, toIValue(contained, attrType));
        } catch (std::exception& e) {
          throw py::cast_error(c10::str(
              "Could not cast attribute '",
              attrName,
              "' to type ",
              attrType->repr_str(),
              ": ",
              e.what()));
        }
      }
      return userObj;
    }
    case TypeKind::InterfaceType: {
      auto interfaceType = type->expect<InterfaceType>();
      // When converting a pyobj to an interface, we check whether the rhs is a
      // module or a normal TorchScript class, and get the type and ivalue from
      // it accordingly.
      c10::ClassTypePtr classType = nullptr;
      IValue res;
      if (auto mod = as_module(py::cast<py::object>(obj))) {
        classType = mod.value().type();
        res = mod.value()._ivalue();
      } else if (auto object = as_object(py::cast<py::object>(obj))) {
        classType = object.value().type();
        res = object.value()._ivalue();
      } else {
        // We inspect the value to find the compiled TorchScript class and
        // then create an ivalue::Object from that class type.
        py::str qualified_name = py::module::import("torch._jit_internal")
                                     .attr("_qualified_name")(obj.get_type());
        auto pyCu = get_python_cu();
        classType = pyCu->get_class(c10::QualifiedName(qualified_name));
        if (!classType) {
          throw std::runtime_error(c10::str(
              "Assigning the object ",
              py::str(obj),
              " to an interface fails because the value is not "
              "a TorchScript compatible type, did you forget to ",
              "turn it into a user defined TorchScript class?"));
        }
        res = toIValue(obj, classType);
      }
      // check if the classType conforms with the interface or not
      std::stringstream why_not;
      if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
        throw py::cast_error(c10::str(
            "Object of type ",
            classType->repr_str(),
            " is not compatible with interface ",
            interfaceType->repr_str(),
            "\n",
            why_not.str()));
      }
      return res;
    }
    case TypeKind::NumberType: {
      if (THPDtype_Check(obj.ptr())) {
        auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
        return static_cast<int64_t>(dtype->scalar_type);
      }
      if (THPQScheme_Check(obj.ptr())) {
        auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
        return static_cast<uint8_t>(qscheme->qscheme);
      }
      if (THPLayout_Check(obj.ptr())) {
        auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
        return static_cast<int8_t>(layout->layout);
      }
      if (py::isinstance<py::bool_>(obj)) {
        return py::cast<bool>(obj);
      } else if (py::isinstance<py::int_>(obj)) {
        return py::cast<int64_t>(obj);
      } else if (py::isinstance<py::float_>(obj)) {
        return py::cast<double>(obj);
      } else if (PyComplex_CheckExact(obj.ptr())) {
        auto c_obj = py::cast<std::complex<double>>(obj.ptr());
        return static_cast<c10::complex<double>>(c_obj);
      } else if (torch::is_symint(obj)) {
        return py::cast<c10::SymInt>(obj);
      } else if (torch::is_symfloat(obj)) {
        return py::cast<c10::SymFloat>(obj);
      } else if (torch::is_symbool(obj)) {
        return py::cast<c10::SymBool>(obj);
      } else {
        throw py::cast_error(
            c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
      }
    }
    case TypeKind::RRefType: {
#ifdef USE_RPC
      return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
#else
      AT_ERROR("RRef is only supported with the distributed package");
#endif
    } break;
    case TypeKind::PyObjectType: {
      return c10::ivalue::ConcretePyObjectHolder::create(obj);
    }
    case TypeKind::CapsuleType: {
      return IValue::make_capsule(py::cast<c10::Capsule>(obj).obj_ptr);
    }
    case TypeKind::FutureType: {
      return obj.cast<std::shared_ptr<PythonFutureWrapper>>()->fut;
    }
    case TypeKind::AwaitType: {
      return obj.cast<std::shared_ptr<PythonAwaitWrapper>>()->aw_;
    }
    case TypeKind::AnyType:
      return toTypeInferredIValue(obj);
    case TypeKind::QSchemeType: {
      if (py::isinstance<py::int_>(obj)) {
        return static_cast<at::QScheme>(py::cast<int64_t>(obj));
      }
      throw py::cast_error(
          c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
    }
    case TypeKind::GeneratorType:
      return py::cast<at::Generator>(obj);
    case TypeKind::DynamicType:
    case TypeKind::FunctionType:
    case TypeKind::QuantizerType:
    case TypeKind::VarType:
    case TypeKind::AnyListType:
    case TypeKind::AnyTupleType:
    case TypeKind::AnyClassType:
    case TypeKind::AnyEnumType:
      break;
    case TypeKind::EnumType:
      EnumTypePtr enum_type = type->expect<EnumType>();
      py::object py_obj = py::reinterpret_borrow<py::object>(obj);
      std::string name = py::cast<std::string>(obj.attr("name"));
      IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {});
      auto enum_holder =
          c10::make_intrusive<c10::ivalue::EnumHolder>(enum_type, name, value);
      return IValue(enum_holder);
  }
  throw py::cast_error(c10::str(
      "toIValue() cannot handle converting to type: ", type->repr_str()));
}

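// Converts an IValue back into a Python object. This is the inverse of
// toIValue() above; together they implement the Python <-> IValue round trip
// used by the JIT bindings.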
py::object toPyObject(IValue ivalue) {
  if (ivalue.isNone()) {
    return py::none();
  } else if (ivalue.isTensor()) {
    auto tensor = std::move(ivalue).toTensor();
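    // Wrapped numbers are Python scalars that were packed into a 0-dim CPU
    // tensor (e.g. by toIValue() above); unwrap them back to Python numbers
    // here. A stashed `_wrapped_number` attribute (the symbolic-scalar path)
    // takes precedence over reading the tensor's data.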
    if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
      TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
      auto py_tensor = py::cast(tensor);
      if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
        return py_tensor.attr("_wrapped_number");
      }
      auto scalar_type = tensor.scalar_type();
      switch (scalar_type) {
        case at::ScalarType::Bool:
          return py::cast(*tensor.const_data_ptr<bool>());
        case at::ScalarType::Long:
          return py::cast(*tensor.const_data_ptr<int64_t>());
        case at::ScalarType::Double:
          return py::cast(*tensor.const_data_ptr<double>());
        case at::ScalarType::ComplexDouble:
          // TODO: https://github.com/pytorch/pytorch/issues/77134
          return py::cast(static_cast<std::complex<double>>(
              *tensor.const_data_ptr<c10::complex<double>>()));
        default:
          TORCH_CHECK(
              false,
              "Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
              scalar_type,
              " to a Python object");
      }
    } else {
      guardAgainstNamedTensor<at::Tensor>(tensor);
      return py::cast(autograd::Variable(std::move(tensor)));
    }
  } else if (ivalue.isStorage()) {
    return py::cast(std::move(ivalue).toStorage());
  } else if (ivalue.isGenerator()) {
    return py::cast(std::move(ivalue).toGenerator());
  } else if (ivalue.isDouble()) {
    return py::cast(std::move(ivalue).toDouble());
  } else if (ivalue.isComplexDouble()) {
    return py::cast(
        static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
  } else if (ivalue.isInt()) {
    return py::cast(std::move(ivalue).toInt());
  } else if (ivalue.isBool()) {
    return py::cast(std::move(ivalue).toBool());
  } else if (ivalue.isString()) {
    if (getUTF8DecodingIgnore()) {
      std::string s = std::move(ivalue).toStringRef();
      PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
      return py::reinterpret_steal<py::object>(pyObj);
    } else {
      return py::cast(std::move(ivalue).toStringRef());
    }
  } else if (ivalue.isList()) {
    auto list = std::move(ivalue).toList();
    py::list t{list.size()};
    for (const auto i : c10::irange(list.size())) {
      t[i] = toPyObject(IValue{list.get(i)});
    }
    return std::move(t);
  } else if (ivalue.isTuple()) {
    auto tuple = std::move(ivalue).toTuple();
    const auto& elements = tuple->elements();

    py::tuple t{elements.size()};
    for (const auto i : c10::irange(elements.size())) {
      t[i] = toPyObject(IValue{elements.at(i)});
    }

    // If we have a NamedTuple
    if (tuple->type() && tuple->type()->schema() &&
        !tuple->type()->schema()->name().empty()) {
      auto unqualName = tuple->type()->name()->name();

      std::vector<Argument> tuple_args = tuple->type()->schema()->arguments();

      std::vector<pybind11::object> defaults;
      auto it = std::find_if(
          tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
            return arg.default_value().has_value();
          });
      std::transform(
          it,
          tuple_args.end(),
          std::back_inserter(defaults),
          [](const Argument& arg) { return toPyObject(*arg.default_value()); });

      std::vector<std::string> fieldNames =
          fmap(tuple_args, [](const Argument& arg) { return arg.name(); });

      return py::module::import("torch._jit_internal")
          .attr("_create_named_tuple")(
              t, unqualName, fieldNames, py::make_tuple(defaults));
    } else {
      return std::move(t);
    }
  } else if (ivalue.isDevice()) {
    return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
  } else if (ivalue.isStream()) {
    return py::cast(std::move(ivalue).toStream());
  } else if (ivalue.isGenericDict()) {
    auto dict = std::move(ivalue).toGenericDict();
    py::dict py_dict;
    for (auto& pair : dict) {
      py_dict[toPyObject(IValue{pair.key()})] =
          toPyObject(IValue{pair.value()});
    }
    return std::move(py_dict);
  } else if (ivalue.isRRef()) {
#ifdef USE_RPC
    auto RRefPtr =
        c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
            std::move(ivalue).toRRef());
    return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
#else
    AT_ERROR("RRef is only supported with the distributed package");
#endif
  } else if (ivalue.isObject()) {
    const auto obj = std::move(ivalue).toObject();
    if (obj->type()->is_module()) {
      return py::cast(Module(obj));
    }

    auto pyCu = get_python_cu();
    if (obj->name().find("__torch__.torch.classes") == 0) {
      return py::cast(Object(obj));
    }
    const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
    AT_ASSERT(classType);
    auto pyClass = getScriptedClassOrError(obj->type());
    auto pyObj = pyClass.attr("__new__")(pyClass);

    const auto numAttrs = classType->numAttributes();

    for (const auto slot : c10::irange(numAttrs)) {
      const auto& attrName = classType->getAttributeName(slot);
      IValue v = obj->getSlot(slot);
      py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
    }
    return pyObj;
  } else if (ivalue.isPyObject()) {
    // return a borrowed reference to ensure it correctly increfs the
    // underlying PyObject
    return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
  } else if (ivalue.isCapsule()) {
    return py::cast(c10::Capsule(ivalue.toCapsule()));
  } else if (ivalue.isFuture()) {
    return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
  } else if (ivalue.isAwait()) {
    return py::cast(std::make_shared<PythonAwaitWrapper>(ivalue.toAwait()));
  } else if (ivalue.isEnum()) {
    auto enum_holder = ivalue.toEnumHolder();
    auto py_class = getScriptedClassOrError(enum_holder->type());
    return py_class.attr(enum_holder->name().c_str());
  } else if (ivalue.isRRef()) {
#ifdef USE_RPC
    return py::cast(torch::distributed::rpc::PyRRef(
        c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
            ivalue.toRRef())));
#else
    TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
  } else if (ivalue.isSymInt()) {
    return py::cast(std::move(ivalue).toSymInt());
  } else if (ivalue.isSymFloat()) {
    return py::cast(std::move(ivalue).toSymFloat());
  } else if (ivalue.isSymBool()) {
    return py::cast(std::move(ivalue).toSymBool());
  } else {
    AT_ERROR(
        "Missing cases in 'toPyObject'! Can't convert ",
        ivalue.tagKind(),
        " to a Python object");
  }
}

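// Resolves which registered overload in `operations` matches the given Python
// args/kwargs and builds the corresponding interpreter Stack. With a single
// candidate the schema is used directly; with multiple candidates each schema
// is tried in turn (see below).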
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
    const std::vector<std::shared_ptr<Operator>>& operations,
    py::args args,
    const py::kwargs& kwargs) {
  Stack stack;
  if (operations.size() == 1) {
    std::shared_ptr<Operator> op = operations.at(0);
    // Create a stack full of the arguments and keyword arguments.
    stack = createStackForSchema(
        op->schema(), std::move(args), kwargs, c10::nullopt);

    return std::make_pair(std::move(op), std::move(stack));
  } else {
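    // Multiple overloads: try each schema in turn and keep the first one that
    // matches; collect the schema_match_errors so they can all be reported if
    // nothing matches.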
    std::vector<schema_match_error> errors;
    std::shared_ptr<Operator> found_op = nullptr;
    for (const auto& op : operations) {
      try {
        stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt);
        found_op = op;
        break;
      } catch (schema_match_error& error) {
        errors.push_back(std::move(error));
      }
    }
    if (!found_op) {
      std::stringstream ss;
      ss << "Overloaded torch operator invoked from Python failed to match any schema:\n";
      for (const auto& err : errors) {
        ss << err.what() << "\n\n";
      }
      throw std::runtime_error(ss.str());
    }

    return std::make_pair(std::move(found_op), std::move(stack));
  }
}

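// Runs the matched operator on the stack built by getOpWithStack() and
// converts the results back to Python. The GIL is released while the operation
// itself executes; `dk`, when set, forces dispatch to a particular
// c10::DispatchKey.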
py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    py::args args,
    const py::kwargs& kwargs,
    c10::optional<c10::DispatchKey> dk) {
  auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
  {
    pybind11::gil_scoped_release no_gil_guard;
    if (dk) {
      found_op->getOperationForDispatchKey(*dk)(stack);
    } else {
      found_op->getOperation()(stack);
    }
  }

  return createPyObjectForStack(std::move(stack));
}

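// Entry point used by the torch.ops bindings for both individual overloads
// (is_overload == true) and overload packets. It first scans args/kwargs for
// Tensor-like arguments that carry a __torch_function__ override; if any are
// found (or a torch function mode is active), the call is routed through the
// Python __torch_function__ machinery, otherwise it falls through to
// invokeOperatorFromPython().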
py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    py::args args,
    const py::kwargs& kwargs,
    bool is_overload,
    c10::optional<c10::DispatchKey> dk) {
  std::vector<PyObject*> overloaded_args;
  size_t total_arg_num = args.size() + kwargs.size();
  for (const auto i : c10::irange(args.size())) {
    is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        args[i].ptr(),
        &overloaded_args,
        static_cast<int>(total_arg_num),
        false /* throw_error */);
  }
  // NB: for kwargs, we cannot guarantee the order of appending
  // is the same as the argument order in operator's schema.
  // This is suboptimal, but should be fine. Later when we have
  // better schema matching and argument parsing, we could
  // match the operator in `operations` first, then the order will
  // be guaranteed.
  for (auto item : kwargs) {
    is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        item.second.ptr(),
        &overloaded_args,
        total_arg_num,
        false /* throw_error */);
  }
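  // If any argument carries a __torch_function__ override (or a torch function
  // mode is active), hand the call off to the Python-side handler, passing
  // torch.ops.<ns>.<name>[.<overload>] as the reference function.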
  if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
    py::object ret;
    std::string ns = symbol.ns().toUnqualString();
    std::string method_name = symbol.toUnqualString();
    auto self_func = py::module::import("torch")
                         .attr("ops")
                         .attr(ns.c_str())
                         .attr(method_name.c_str());
    if (is_overload) {
      auto overload_name = operations[0]->schema().overload_name();
      if (overload_name.empty()) {
        self_func = self_func.attr("default");
      } else {
        self_func = self_func.attr(overload_name.c_str());
      }
    }
    std::string module_name("torch.ops");
    module_name.append(ns);
    return pybind11::reinterpret_steal<py::object>(
        handle_torch_function_no_python_arg_parser(
            overloaded_args,
            args.ptr(),
            kwargs.ptr(),
            method_name.c_str(),
            self_func.ptr(),
            module_name.c_str()));
  }
  return invokeOperatorFromPython(operations, args, kwargs, dk);
}

} // namespace torch::jit