1
#include <pybind11/detail/common.h>
2
#include <pybind11/pytypes.h>
3
#include <torch/csrc/jit/api/object.h>
4
#include <torch/csrc/jit/python/script_init.h>
5
#include <torch/csrc/utils/pybind.h>
7
#include <caffe2/serialize/versions.h>
8
#include <torch/csrc/Device.h>
9
#include <torch/csrc/DynamicTypes.h>
10
#include <torch/csrc/jit/api/module.h>
11
#include <torch/csrc/jit/frontend/ir_emitter.h>
12
#include <torch/csrc/jit/frontend/sugared_value.h>
13
#include <torch/csrc/jit/mobile/code.h>
14
#include <torch/csrc/jit/mobile/compatibility/backport.h>
15
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
16
#include <torch/csrc/jit/mobile/file_format.h>
17
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
18
#include <torch/csrc/jit/mobile/import.h>
19
#include <torch/csrc/jit/mobile/module.h>
20
#include <torch/csrc/jit/mobile/quantization.h>
21
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
22
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
23
#include <torch/csrc/jit/operator_upgraders/utils.h>
24
#include <torch/csrc/jit/operator_upgraders/version_map.h>
25
#include <torch/csrc/jit/python/module_python.h>
26
#include <torch/csrc/jit/python/python_ivalue.h>
27
#include <torch/csrc/jit/python/python_sugared_value.h>
28
#include <torch/csrc/jit/serialization/export_bytecode.h>
29
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
30
#include <torch/csrc/jit/serialization/import.h>
31
#include <torch/csrc/jit/testing/file_check.h>
33
#include <c10/util/Exception.h>
34
#include <c10/util/intrusive_ptr.h>
35
#include <c10/util/irange.h>
36
#include <torch/csrc/jit/frontend/parser.h>
37
#include <torch/csrc/jit/frontend/tracer.h>
38
#include <torch/csrc/jit/ir/constants.h>
39
#include <torch/csrc/jit/ir/graph_utils.h>
40
#include <torch/csrc/jit/ir/irparser.h>
41
#include <torch/csrc/jit/passes/inliner.h>
42
#include <torch/csrc/jit/passes/shape_analysis.h>
43
#include <torch/csrc/jit/python/pybind_utils.h>
44
#include <torch/csrc/jit/python/python_dict.h>
45
#include <torch/csrc/jit/python/python_list.h>
46
#include <torch/csrc/jit/python/python_tracer.h>
47
#include <torch/csrc/jit/runtime/graph_executor.h>
48
#include <torch/csrc/jit/runtime/instruction.h>
49
#include <torch/csrc/jit/runtime/interpreter.h>
50
#include <torch/csrc/jit/runtime/logging.h>
51
#include <torch/csrc/jit/serialization/export_bytecode.h>
52
#include <torch/csrc/jit/serialization/import_source.h>
53
#include <torch/csrc/jit/serialization/pickle.h>
54
#include <torch/csrc/jit/serialization/python_print.h>
55
#include <torch/csrc/jit/testing/hooks_for_testing.h>
57
#include <torch/csrc/api/include/torch/ordered_dict.h>
60
#include <ATen/core/function_schema.h>
61
#include <ATen/core/ivalue.h>
62
#include <ATen/core/qualified_name.h>
64
#include <pybind11/functional.h>
65
#include <pybind11/pybind11.h>
66
#include <pybind11/stl.h>
67
#include <pybind11/stl_bind.h>
68
#include <torch/csrc/jit/mobile/train/export_data.h>
81
// Type aliases used by the compilation helpers below.
// FunctionDefaults maps a parameter name to its Python default value (it is
// looked up by argument name in calcOverloadedFunctionDefaults).
using ::c10::FunctionSchema;
83
using FunctionDefaults = std::unordered_map<std::string, py::object>;
84
// Nested map of FunctionDefaults; presumably keyed by method name — confirm
// against callers.
using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
88
// A resolver that will inspect the outer Python scope to find `name`.
// It holds the Python resolution callback `rcb_` and, when compiling a class,
// the class name/type that must resolve to the class itself.
89
struct PythonResolver : public Resolver {
90
// Resolver for free functions: only the callback is stored; classname_ and
// classType_ stay default-constructed (empty / null).
explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {}
93
* While compiling classes, the class type we're compiling will not be
94
* available in Python, since we haven't finished defining the class yet. So
95
* in order to make the class type available to its own methods, we need to
96
* explicitly resolve it.
98
* @param rcb Python function to resolve a name to its Python object in the
100
* @param classname The unqualified classname of the class currently being
102
* @param classType The class's type.
104
// Resolver used while compiling a class. `classname` and `classType` let the
// class resolve references to itself (checked against classname_ in
// resolveType). All three arguments are moved into the members.
explicit PythonResolver(
105
ResolutionCallback rcb,
106
std::string classname,
107
ClassTypePtr classType)
108
: rcb_(std::move(rcb)),
109
classname_(std::move(classname)),
110
classType_(std::move(classType)) {}
112
// Resolves `name` to a sugared value. It acquires the GIL (rcb_ returns a
// py::object), calls rcb_(name), and converts the result with toSugaredValue
// at `loc`.
// NOTE(review): `m` is not declared in the visible lines. Its parameter
// declaration appears to have been dropped from this excerpt; confirm.
std::shared_ptr<SugaredValue> resolveValue(
113
const std::string& name,
115
const SourceRange& loc) override {
116
pybind11::gil_scoped_acquire ag;
117
py::object obj = rcb_(name);
121
return toSugaredValue(obj, m, loc);
124
// Returns true when `obj` is a subclass of `tuple` that also has a `_fields`
// attribute, i.e. it looks like a namedtuple class.
// NOTE(review): PyObject_IsSubclass returns -1 (with a Python error set) on
// failure. `&&` treats -1 as true, so the error is not detected here;
// consider checking for -1 explicitly.
static bool isNamedTupleClass(py::object obj) {
125
auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
126
return PyObject_IsSubclass(obj.ptr(), tuple_type) &&
127
py::hasattr(obj, "_fields");
130
// Maps a resolved Python object to a TorchScript type:
//  - a ScriptClass yields its stored class type;
//  - non-class objects are rejected (the branch body is not visible in this
//    excerpt; TODO confirm it returns nullptr);
//  - namedtuple classes are registered via registerNamedTuple;
//  - any other class is looked up in the Python compilation unit under the
//    name produced by torch._jit_internal._qualified_name.
TypePtr resolveTypeFromObject(const py::object& obj, const SourceRange& loc) {
131
if (py::isinstance<ScriptClass>(obj)) {
132
auto script_class = py::cast<ScriptClass>(obj);
133
return script_class.class_type_.type_;
136
py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj);
137
if (!py::cast<bool>(isClass)) {
141
if (isNamedTupleClass(obj)) {
142
return registerNamedTuple(obj, loc, rcb_);
145
auto qualifiedName = c10::QualifiedName(
146
py::cast<std::string>(py::module::import("torch._jit_internal")
147
.attr("_qualified_name")(obj)));
149
return get_python_cu()->get_type(qualifiedName);
152
// Resolves a type name. The class currently being compiled (classType_ /
// classname_) is handled first; that branch body is not visible here.
// Otherwise, with the GIL held, the name is resolved through rcb_ and passed
// to torch.jit.annotations.try_ann_to_type. If that returns None, the object
// is handed to resolveTypeFromObject.
TypePtr resolveType(const std::string& name, const SourceRange& loc)
154
if (classType_ && name == classname_) {
157
pybind11::gil_scoped_acquire ag;
158
py::object obj = rcb_(name);
163
auto annotation_type =
164
py::module::import("torch.jit.annotations")
165
.attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
166
if (!annotation_type.is_none()) {
167
return py::cast<TypePtr>(annotation_type);
169
return resolveTypeFromObject(obj, loc);
173
// Python-side name lookup; called as rcb_(name) and yields a py::object.
ResolutionCallback rcb_;
174
// Unqualified name of the class being compiled; empty for the callback-only
// constructor.
std::string classname_;
175
// Type of the class being compiled; null for the callback-only constructor.
ClassTypePtr classType_;
178
// Creates a shared PythonResolver that resolves names only through `rcb`.
std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
179
return std::make_shared<PythonResolver>(rcb);
181
// Creates a shared PythonResolver for compiling a class. `classname` and
// `classType` are moved into the resolver so the class can refer to itself.
std::shared_ptr<PythonResolver> pythonResolver(
182
const ResolutionCallback& rcb,
183
std::string classname,
184
ClassTypePtr classType) {
185
return std::make_shared<PythonResolver>(
186
rcb, std::move(classname), std::move(classType));
189
// Asserts that an overload declaration is compatible with a previous one:
// both must have the same number of parameters, with matching names in the
// same positions.
// NOTE(review): these are TORCH_INTERNAL_ASSERTs, yet a mismatch is a user
// error. An ErrorReport pointing at the declaration's source range would
// give a clearer diagnostic.
void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
190
const auto& new_params = new_decl.params();
191
const auto& old_params = old_decl.params();
193
// TODO. same number of parameters not strictly necessary.
194
TORCH_INTERNAL_ASSERT(
195
new_params.size() == old_params.size(),
196
"Overload must have same number of parameters\n",
199
for (const auto i : c10::irange(new_decl.params().size())) {
200
TORCH_INTERNAL_ASSERT(
201
new_params[i].ident().name() == old_params[i].ident().name(),
202
"Overload parameters must have the same names\n",
203
new_params[i].ident(),
204
old_params[i].ident());
208
// Tries to convert the Python default `def_value` to an IValue of the
// argument's type. For a BroadcastingList argument (`n`, presumably arg.N(),
// is positive and the type is a List), the default may be a single element,
// so it is converted to the list's element type instead.
// NOTE(review): `arg` and `n` are not declared in the visible lines; their
// declarations appear dropped from this excerpt. The c10::nullopt failure
// path implied by the return type is not visible either.
c10::optional<IValue> tryCalculateDefaultParam(
210
const py::object& def_value) {
212
auto list_type = arg.type()->cast<ListType>();
214
if (n && *n > 0 && list_type) {
215
// BroadcastingList, allow default values T for arg types List[T]
216
return toIValue(def_value, list_type->getElementType());
218
return toIValue(def_value, arg.type());
225
// An overloaded function may have a default that does not subtype all overloads
// Builds the subset of `defaults` relevant to `schema`. For each schema
// argument with an entry in `defaults`, conversion is attempted with
// tryCalculateDefaultParam and the original Python value is copied into the
// result.
// NOTE(review): the body of the `value == defaults.end()` branch and any use
// of `maybe_ivalue` are not visible in this excerpt. Confirm that defaults
// which fail to convert are actually filtered out.
229
FunctionDefaults calcOverloadedFunctionDefaults(
230
const FunctionSchema& schema,
231
const FunctionDefaults& defaults) {
232
FunctionDefaults updated_defaults;
233
for (const auto& arg : schema.arguments()) {
234
const std::string& arg_name = arg.name();
235
auto value = defaults.find(arg_name);
236
if (value == defaults.end()) {
239
auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
241
updated_defaults[arg_name] = value->second;
244
return updated_defaults;
249
// Recursive predicate: does `def_arg` contain a mutable Python value?
// Lists and dicts are checked directly; tuples are checked element by
// element via recursion. The return statements are not visible in this
// excerpt; presumably list/dict and any mutable tuple element yield true.
bool checkMutableFunctionDefault(const py::object& def_arg) {
250
if (py::isinstance<py::list>(def_arg) || py::isinstance<py::dict>(def_arg)) {
253
if (py::isinstance<py::tuple>(def_arg)) {
254
auto pytuple = def_arg.cast<py::tuple>();
255
for (py::handle t : pytuple) {
256
py::object obj = py::reinterpret_borrow<py::object>(t);
257
if (checkMutableFunctionDefault(obj)) {
265
// Throws an ErrorReport at `range` when the default `def_arg` is mutable
// (per the bool overload of this function) or when the argument's type is a
// ClassType. Python would share such a default across calls.
// NOTE(review): `arg` is not declared in the visible lines; its parameter
// appears dropped from this excerpt.
void checkMutableFunctionDefault(
266
const SourceRange& range,
268
const py::object& def_arg) {
269
if (checkMutableFunctionDefault(def_arg) || arg.type()->cast<ClassType>()) {
270
throw ErrorReport(range)
271
<< "Mutable default parameters are not supported because Python binds them to the function"
272
<< " and they persist across function calls.\n As a workaround, make the default None and instantiate"
273
<< " the default parameter within the body of the function. Found "
274
<< def_arg.get_type() << " on parameter " << arg.name();
278
// Returns a copy of `schema` whose arguments carry the Python defaults from
// `default_args`. Its name is `new_name` if given, otherwise the original
// name.
// Each supplied default is first rejected if mutable, then converted to the
// argument's type. An ErrorReport is prepared that explains the expected type
// and, for inferred types, that Tensor was assumed. Arguments without a
// default are copied unchanged.
// NOTE(review): the guard on `value` before the ErrorReport and the
// corresponding `throw` are not visible in this excerpt; `*value` must only
// be reached when the conversion succeeded.
FunctionSchema getSchemaWithNameAndDefaults(
279
const SourceRange& range,
280
const FunctionSchema& schema,
281
const at::optional<std::string>& new_name,
282
const FunctionDefaults& default_args) {
283
std::vector<Argument> new_args;
284
for (auto& arg : schema.arguments()) {
285
auto it = default_args.find(arg.name());
286
if (it != default_args.end()) {
287
checkMutableFunctionDefault(range, arg, it->second);
288
c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
290
ErrorReport error(range);
291
error << "Expected a default value of type " << arg.type()->repr_str()
292
<< " on parameter \"" << arg.name() << "\".";
293
if (arg.is_inferred_type()) {
294
error << "Because \"" << arg.name()
295
<< "\" was not annotated with an explicit type "
296
<< "it is assumed to be type 'Tensor'.";
300
new_args.emplace_back(
301
arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only());
303
new_args.push_back(arg);
306
return FunctionSchema(
307
new_name.value_or(schema.name()),
308
schema.overload_name(),
315
// Builds the declaration actually compiled for an overload. It contains the
// overload's own parameters (whose names must match the implementation's,
// position by position), followed by any extra trailing implementation
// parameters. Each extra parameter must have a default in `defaults` and a
// type annotation in the implementation. The result keeps the overload's
// source range and return type.
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
316
const Decl& overload_decl,
317
const Decl& impl_decl,
318
const FunctionDefaults& defaults) {
319
std::vector<Param> adjusted_params;
320
const auto& overload_params = overload_decl.params();
321
const auto& impl_params = impl_decl.params();
323
// following PEP specification that the following should work:
325
// def mouse_event(x1: int, y1: int) -> ClickEvent: ...
327
// def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2:
328
// Optional[int] = None)
330
overload_params.size() <= impl_params.size(),
331
"Overload should not have more parameters than implementation function",
332
overload_decl.range(),
335
for (const auto i : c10::irange(overload_params.size())) {
336
auto overload_name = overload_params[i].ident().name();
337
auto impl_name = impl_params[i].ident().name();
338
if (overload_name != impl_name) {
339
throw ErrorReport(overload_decl.range())
340
<< "Overload parameters must have the same names. "
341
<< "Found " << overload_name << " and " << impl_name
342
<< " on argument " << i;
344
adjusted_params.push_back(overload_params[i]);
346
// Trailing implementation-only parameters must be defaulted and typed.
for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
347
if (!defaults.count(impl_params[i].ident().name())) {
348
// NOTE(review): the message below has no space between "argument" and the
// parameter name, so it renders e.g. "argumentx"; fix in a behavior change.
throw ErrorReport(impl_decl.range())
349
<< "Expected to find default parameter on argument"
350
<< impl_params[i].ident().name()
351
<< " because it is not defined on the overloaded declaration";
353
if (!impl_params[i].type().present()) {
354
throw ErrorReport(impl_decl.range())
355
<< "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
356
<< " Did not find type for param " << impl_params[i].ident().name();
358
adjusted_params.push_back(impl_params[i]);
361
overload_decl.range(),
362
List<Param>::create(overload_decl.range(), adjusted_params),
363
overload_decl.return_type());
366
// Compiles one overload of a Python function. It requires an explicit
// signature (errors when `signature` is None). It then merges the
// implementation's extra defaulted parameters into the overload declaration
// and defines exactly one function in the Python compilation unit under
// `name`'s prefix. Finally it re-derives the schema using only the defaults
// that calcOverloadedFunctionDefaults keeps for this overload, and notifies
// the emit hook via didFinishEmitFunction.
static StrongFunctionPtr script_compile_overloaded_function(
367
const c10::QualifiedName& name,
368
const Decl& overload_decl,
369
const Def& implementation_def,
370
const ResolutionCallback& rcb,
371
const FunctionDefaults& implementation_defaults,
372
const py::object& signature) {
373
if (signature.is_none()) {
374
throw ErrorReport(overload_decl.range())
375
<< "Must explicitly add type annotations to overloaded functions";
378
auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl(
379
overload_decl, implementation_def.decl(), implementation_defaults);
380
auto new_def = implementation_def.withDecl(adjusted_decl);
381
auto cu = get_python_cu();
382
auto defined_functions = cu->define(
383
QualifiedName(name.prefix()),
385
/*propResolvers=*/{},
387
{pythonResolver(rcb)},
390
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
391
auto& defined = defined_functions[0];
392
FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults(
393
defined->getSchema(), implementation_defaults);
394
defined->setSchema(getSchemaWithNameAndDefaults(
396
defined->getSchema(),
397
new_def.name().name(),
399
StrongFunctionPtr ret(std::move(cu), defined);
400
didFinishEmitFunction(ret);
404
// Compiles a single Python function definition into the Python compilation
// unit (under `name`'s prefix). It attaches the Python `defaults` to the
// function's schema and notifies the emit hook via didFinishEmitFunction.
// NOTE(review): `def` is not declared in the visible lines; its parameter
// appears dropped from this excerpt.
static StrongFunctionPtr script_compile_function(
405
const c10::QualifiedName& name,
407
const FunctionDefaults& defaults,
408
const ResolutionCallback& rcb) {
409
auto cu = get_python_cu();
410
auto defined_functions = cu->define(
411
QualifiedName(name.prefix()),
413
/*propResolvers=*/{},
415
{pythonResolver(rcb)},
418
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
419
auto& defined = defined_functions[0];
420
defined->setSchema(getSchemaWithNameAndDefaults(
421
def.range(), defined->getSchema(), def.name().name(), defaults));
422
StrongFunctionPtr ret(std::move(cu), defined);
423
didFinishEmitFunction(ret);
427
// `self` handle used when compiling methods of a scripted module. Values
// representing self are typed as the module's class type and wrapped as a
// ModuleValue backed by the concrete module type.
// NOTE(review): the single-argument constructor is not `explicit`.
struct VISIBILITY_HIDDEN ModuleSelf : public Self {
428
ModuleSelf(std::shared_ptr<ConcreteModuleType> concreteType)
429
: Self(), concreteType_(std::move(concreteType)) {}
431
// Retypes `v` as the module class and wraps it for attribute/method lookup.
std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
432
v->setType(getClassType());
433
return std::make_shared<ModuleValue>(v, concreteType_);
436
// The JIT type of the concrete module, which expect<> requires to be a
// ClassType.
ClassTypePtr getClassType() const override {
437
return concreteType_->getJitType()->expect<ClassType>();
441
std::shared_ptr<ConcreteModuleType> concreteType_;
444
// Works on a copy of `graph`: its inputs get (incomplete) tensor types from
// `inputs`, then PropagateInputShapes runs. Only the copy is passed to the
// passes, so the original graph is not modified.
// NOTE(review): `graph` is not declared in the visible lines. `with_grad` is
// unused in the visible lines. `inputs` is taken by value, copying the
// vector.
static std::shared_ptr<Graph> _propagate_shapes(
446
std::vector<at::Tensor> inputs,
447
bool with_grad = false) {
448
Stack stack(inputs.begin(), inputs.end());
449
auto retval = graph.copy();
450
setInputTensorTypes(*retval, stack, /*complete=*/false);
451
PropagateInputShapes(retval);
455
// Like _propagate_shapes, but assigns complete tensor types (complete=true)
// to the copied graph's inputs, using `param_count_list`. It then runs
// PropagateInputShapes on the copy.
// NOTE(review): the call head of the input-typing statement, the
// `propagate` guard and the return are on lines dropped from this excerpt;
// `with_grad` is unused in the visible lines.
static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
457
const std::vector<at::Tensor>& inputs,
458
const std::vector<int>& param_count_list,
459
bool with_grad = false,
460
bool propagate = true) {
461
auto retval = graph.copy();
463
*retval, fmap<IValue>(inputs), /*complete=*/true, param_count_list);
465
PropagateInputShapes(retval);
470
// Adds `func` to `module` as its `forward` method. The function's graph is
// copied and given a leading `self` input typed as the module. The copy is
// registered in the module's compilation unit under <module type>.forward
// and attached to the module's type.
// NOTE(review): `method` is not declared in the visible lines; it is
// presumably the result of create_function — confirm.
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
471
// Make a graph with a fake self argument
472
auto graph = toGraphFunction(*func.function_).graph()->copy();
473
auto v = graph->insertInput(0, "self");
474
v->setType(module._ivalue()->type());
475
const auto name = QualifiedName(*module.type()->name(), "forward");
477
module._ivalue()->compilation_unit()->create_function(name, graph);
478
module.type()->addMethod(method);
481
// this is used in our test suite to check that we correctly preserved type tags
// Walks the IValue graphs of `lhs` and `rhs` in lockstep using an explicit
// work list. It descends into object slots, tuple elements, list elements,
// dict values (looked up by key in rhs) and future values. Pointer-typed
// values are visited at most once, keyed on the lhs pointer.
482
bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
487
std::unordered_set<const void*> visited;
488
std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
489
while (!work.empty()) {
490
Work item = work.back();
492
if (item.a.isPtrType()) {
493
// uncomment to debug type matching errors
494
// std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") "
495
// << item.a.internalToPointer() << " " << /*item.b <<*/ " ("
496
// << *item.b.type() << ") " << item.b.internalToPointer() <<
499
if (visited.count(item.a.internalToPointer())) {
502
visited.emplace(item.a.internalToPointer());
504
// NOTE(review): this compares item.b's type with itself, so the condition
// is always false and the tag check never fires. Presumably the left
// operand was meant to be item.a.type() — confirm and fix.
if (!unshapedType(item.b.type())
505
->isSubtypeOf(unshapedType(item.b.type()))) {
506
// Since named types are saved and loaded in the test suite, we cannot
507
// expect them to be equal. We should still check their slots however.
508
if (!item.a.type()->cast<c10::NamedType>()) {
512
// check tags for objects that contain subobjects
513
if (item.a.isObject()) {
514
auto ao = item.a.toObject();
515
auto bo = item.b.toObject();
516
for (size_t i = 0; i < ao->slots().size(); ++i) {
517
work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)});
519
} else if (item.a.isTuple()) {
520
auto at = item.a.toTuple();
521
auto bt = item.b.toTuple();
522
for (size_t i = 0; i < at->elements().size(); ++i) {
523
work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)});
525
} else if (item.a.isList()) {
526
auto al = item.a.toList();
527
auto bl = item.b.toList();
528
for (const auto i : c10::irange(al.size())) {
529
work.emplace_back(Work{al.get(i), bl.get(i)});
531
} else if (item.a.isGenericDict()) {
532
auto ad = item.a.toGenericDict();
533
auto bd = item.b.toGenericDict();
534
// NOTE(review): this loop variable shadows the outer `item`.
for (auto& item : ad) {
535
// Dictionary keys cannot contain List/Dicts that require tags
536
// so we do not have to check them.
537
// Furthermore without ordered dicts it is expensive to find the
539
work.emplace_back(Work{item.value(), bd.at(item.key())});
541
} else if (item.a.isFuture()) {
542
auto af = item.a.toFuture();
543
auto bf = item.b.toFuture();
546
work.emplace_back(Work{af->value(), bf->value()});
553
// helper used to implement ._parameters, ._buffers, ._modules dicts
554
// inside of script nn.Module
// Each Policy supplies a static valid(type, slot_index, value) predicate that
// decides which module attributes belong to this view.
555
template <typename Policy>
556
struct slot_dict_impl {
557
slot_dict_impl(ModulePtr module) : module_(std::move(module)) {}
558
// True when `name` is an attribute slot accepted by Policy (the remaining
// return paths are not visible in this excerpt).
bool contains(const std::string& name) const {
559
if (auto slot = module_->type()->findAttributeSlot(name)) {
560
if (Policy::valid(module_->type(), *slot, module_->getSlot(*slot))) {
567
// (name, Python value) pairs for every attribute slot accepted by Policy,
// in slot order.
std::vector<std::pair<std::string, py::object>> items() const {
568
std::vector<std::pair<std::string, py::object>> result;
569
for (size_t i = 0, N = module_->type()->numAttributes(); i < N; ++i) {
570
if (Policy::valid(module_->type(), i, module_->getSlot(i))) {
572
module_->type()->getAttributeName(i),
573
toPyObject(module_->getSlot(i)));
579
// Converts `value` to the attribute's declared type and assigns it.
void setattr(const std::string& name, py::object value) {
580
const TypePtr& type = module_->type()->getAttribute(name);
581
Module(module_).setattr(name, toIValue(std::move(value), type));
584
// Reads attribute `name` and converts it to a Python object.
py::object getattr(const std::string& name) {
585
return toPyObject(Module(module_).attr(name));
588
// Registers this view as a pybind11 class called `name` on `m`, constructed
// from a Module.
static void bind(const py::module& m, const char* name) {
589
py::class_<slot_dict_impl<Policy>>(m, name)
591
[](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
592
.def("contains", &slot_dict_impl<Policy>::contains)
593
.def("items", &slot_dict_impl<Policy>::items)
594
.def("setattr", &slot_dict_impl<Policy>::setattr)
595
.def("getattr", &slot_dict_impl<Policy>::getattr);
603
// Converts each element of `list` with py::cast and appends it to a py::list
// (debug helper used by _jit_debug_module_iterators).
py::list debugMakeList(const T& list) {
605
for (const auto& elem : list) {
606
result.append(py::cast(elem));
611
// Converts each element's (name, value) pair into a Python 2-tuple and
// appends it to a py::list.
// NOTE(review): `auto elem` copies every element; `const auto&` would avoid
// the copy.
py::list debugMakeNamedList(const T& list) {
613
for (auto elem : list) {
614
result.append(py::cast(std::make_pair(elem.name, elem.value)));
619
// Like debugMakeList, but adds the converted elements to a py::set.
py::set debugMakeSet(const T& list) {
621
for (const auto& elem : list) {
622
result.add(py::cast(elem));
627
// Debug/test helper: materializes each Module iterator as a Python dict
// entry. For parameters, buffers and attributes, keys ending in "_r" use the
// recursive (true) variant and the others pass false.
static py::dict _jit_debug_module_iterators(Module& module) {
629
result["children"] = debugMakeList(module.children());
630
result["named_children"] = debugMakeNamedList(module.named_children());
631
result["modules"] = debugMakeList(module.modules());
632
result["named_modules"] = debugMakeNamedList(module.named_modules());
634
result["parameters"] = debugMakeList(module.parameters(false));
635
result["named_parameters"] =
636
debugMakeNamedList(module.named_parameters(false));
637
result["parameters_r"] = debugMakeList(module.parameters(true));
638
result["named_parameters_r"] =
639
debugMakeNamedList(module.named_parameters(true));
641
result["buffers"] = debugMakeList(module.buffers(false));
642
result["named_buffers"] = debugMakeNamedList(module.named_buffers(false));
643
result["buffers_r"] = debugMakeList(module.buffers(true));
644
result["named_buffers_r"] = debugMakeNamedList(module.named_buffers(true));
646
result["named_attributes"] =
647
debugMakeNamedList(module.named_attributes(false));
648
result["named_attributes_r"] =
649
debugMakeNamedList(module.named_attributes(true));
653
// Python magic-method names bound on ScriptObject. When the bindings are
// initialized, each name forwards to the TorchScript method of the same name
// (raising NotImplementedError if the class lacks it). __str__ and __repr__
// get built-in fallbacks instead.
// Keep the size (48) equal to the number of entries: fewer initializers would
// silently leave trailing nullptr names.
static constexpr std::array<const char*, 48> magic_method_names = {
654
"__lt__", "__le__", "__eq__", "__ne__",
655
"__ge__", "__gt__", "__not__", "__abs__",
656
"__add__", "__and__", "__floordiv__", "__index__",
657
"__inv__", "__invert__", "__lshift__", "__mod__",
658
"__mul__", "__matmul__", "__neg__", "__or__",
659
"__pos__", "__pow__", "__rshift__", "__sub__",
660
"__truediv__", "__xor__", "__concat__", "__contains__",
661
"__delitem__", "__getitem__", "__setitem__", "__iadd__",
662
"__iand__", "__iconcat__", "__ifloordiv__", "__ilshift__",
663
"__imod__", "__imul__", "__imatmul__", "__ior__",
664
"__ipow__", "__irshift__", "__isub__", "__itruediv__",
665
"__ixor__", "__str__", "__len__", "__repr__",
668
// Wrapper stored in a Python __deepcopy__ memo dict. It lets one alias-aware
// IValue memo map be shared by every deepcopy call that uses that memo.
struct DeepCopyMemoTable {
669
std::shared_ptr<IValue::HashAliasedIValueMap> map;
672
// Deep-copies `ivalue` using an IValue memo map kept in the Python `memo`
// dict under "__torch_script_memo_table" (created on first use). This
// preserves aliasing across calls that share the same memo.
// NOTE(review): the declaration of `ivalue_memo` is not visible; presumably
// it is a reference to the map — TODO confirm it is not a copy.
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
673
if (!memo.contains(py::str("__torch_script_memo_table"))) {
674
memo["__torch_script_memo_table"] =
675
DeepCopyMemoTable{std::make_shared<IValue::HashAliasedIValueMap>()};
678
*py::cast<DeepCopyMemoTable>(memo["__torch_script_memo_table"]).map;
679
return ivalue.deepcopy(ivalue_memo);
682
// Builds an ExtraFilesMap containing the keys of `pydict`, each mapped to an
// empty string; the dict's values are ignored. Presumably the loader fills
// in the contents later — confirm.
ExtraFilesMap extra_files_from_python(const py::dict& pydict) {
684
for (const auto& it : pydict) {
685
r[py::cast<std::string>(it.first)] = "";
690
// Copies every entry of `m` into `pydict` as str -> bytes, overwriting any
// existing keys.
void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
691
// py::dict is pointer-like type so it gets modified despite const&
692
for (const auto& it : m) {
693
pydict[py::str(it.first)] = py::bytes(it.second);
697
// Defines TorchScript source `src` in compilation unit `cu`. If a resolution
// callback is supplied, it is used. Otherwise one is built from the Python
// frame `_frames_up` levels up, via
// torch._jit_internal.createResolutionCallbackFromFrame.
// NOTE(review): `cu` and the branch condition on `rcb` are not visible in
// this excerpt. Since *rcb is dereferenced, that condition must reject
// nullptr.
void pyCompilationUnitDefine(
699
const std::string& src,
700
const ResolutionCallback* rcb,
701
const uint32_t _frames_up) {
703
cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr);
705
py::object py_default_rcb =
706
py::module::import("torch._jit_internal")
707
.attr("createResolutionCallbackFromFrame")(_frames_up);
708
auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
709
cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr);
713
// This function will copy bytes into a shared_ptr of chars aligned
714
// at kFlatbufferDataAlignmentBytes boundary (currently 16).
715
// This is required because tensors need to be aligned at 16 bytes boundary.
// The allocation size is rounded up to the next multiple of the alignment
// strictly greater than bytes.size() (e.g. 0 -> 16, 16 -> 32). That also
// satisfies aligned_alloc's requirement that size be a multiple of the
// alignment. Only bytes.size() bytes are copied; the padding is not written
// in the visible code.
// NOTE(review): no null check is visible on the _aligned_malloc /
// aligned_alloc results before the memcpy.
716
static std::shared_ptr<char> copyStr(const std::string& bytes) {
717
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
718
kFlatbufferDataAlignmentBytes;
720
std::shared_ptr<char> bytes_copy(
721
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
723
#elif defined(__APPLE__)
725
// NOTE(review): posix_memalign's return code is ignored, and it leaves `p`
// unmodified on failure. The assert below is therefore only meaningful if
// `p` is initialized to nullptr at its (not visible) declaration.
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
726
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
727
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
729
std::shared_ptr<char> bytes_copy(
730
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
733
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
737
void initJitScriptBindings(PyObject* module) {
738
auto m = py::handle(module).cast<py::module>();
740
// NOLINTNEXTLINE(bugprone-unused-raii)
741
py::class_<c10::Capsule>(m, "Capsule");
744
py::class_<Object>(m, "ScriptObject")
745
.def("_type", [](Object& o) { return o.type(); })
748
[](Object& self, const std::string& name) -> Method {
749
return self.get_method(name);
751
py::keep_alive<0, 1>())
754
[](Object& self, const std::string& name, py::object value) {
755
if (self.type()->hasConstant(name)) {
758
"Can't set constant '",
760
"' which has value:",
761
self.type()->getConstant(name));
763
TypePtr type = self.type()->getAttribute(name);
765
auto ivalue = toIValue(std::move(value), type);
766
self.setattr(name, ivalue);
767
} catch (std::exception& e) {
768
throw py::cast_error(c10::str(
769
"Could not cast attribute '",
779
[](Object& self, const std::string& name) {
781
return toPyObject(self.attr(name));
782
} catch (const ObjectAttributeError& err) {
783
throw AttributeError("%s", err.what());
788
[](Object& self, const std::string& name) -> py::object {
790
if (name == "__qualname__") {
791
return py::cast(self.type()->name()->name());
793
if (auto method = self.find_method(name)) {
794
return py::cast(*method);
796
if (self.has_property(name)) {
797
auto prop = self.get_property(name);
798
// wrap the Method into callable PyObject
799
auto getter_func = py::cast(prop.getter_func);
800
return getter_func();
802
return toPyObject(self.attr(name));
803
} catch (const ObjectAttributeError& err) {
804
throw AttributeError("%s", err.what());
809
[](Object& self, const std::string& name, py::object value) {
811
if (self.has_property(name)) {
812
auto prop = self.get_property(name);
813
if (!prop.setter_func.has_value()) {
814
TORCH_CHECK(false, "can't set attribute");
816
// wrap the Method into callable PyObject
817
auto setter_func = py::cast(prop.setter_func);
822
if (self.type()->hasConstant(name)) {
825
"Can't set constant '",
827
"' which has value:",
828
self.type()->getConstant(name));
830
TypePtr type = self.type()->getAttribute(name);
831
auto ivalue = toIValue(std::move(value), type);
832
self.setattr(name, ivalue);
833
} catch (const ObjectAttributeError& err) {
834
throw AttributeError("%s", err.what());
839
[](Object& self, const std::string& name) {
840
return self.hasattr(name);
844
[](Object& self, const std::string& name) {
845
return bool(self.find_method(name));
850
return fmap(self.get_methods(), [](const Method& method) {
851
return method.name();
855
"_properties", [](Object& self) { return self.get_properties(); })
856
.def("__copy__", &Object::copy)
859
[](const Object& self) {
860
// Similar to Tensor's `__hash__`, which is `id()`.
861
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
864
[](const Object& self)
865
-> std::tuple<py::object, std::string> { // __getstate__
866
if (auto getstate_method = self.find_method("__getstate__")) {
867
auto object_state = toPyObject((*getstate_method)(Stack{}));
868
TORCH_INTERNAL_ASSERT(self.type()->name());
869
return std::make_tuple(
870
object_state, self.type()->name()->qualifiedName());
872
std::stringstream err;
873
err << "Tried to serialize object ";
874
if (auto qualname = self.type()->name()) {
875
err << qualname->qualifiedName() << " ";
877
err << "which does not have a __getstate__ method defined!";
878
throw std::runtime_error(err.str());
880
[](const std::tuple<py::object, std::string>& state_tup)
882
auto [state, qualname] = state_tup;
883
auto class_type = getCustomClass(qualname);
886
"Tried to deserialize class ",
888
" which is not known to the runtime. "
889
"If this is a custom C++ class, make "
890
"sure the appropriate code is linked.");
892
auto self = Object(c10::ivalue::Object::create(
894
std::shared_ptr<torch::jit::CompilationUnit>(),
897
if (auto setstate_method = self.find_method("__setstate__")) {
898
auto setstate_schema =
899
setstate_method->function().getSchema();
900
TORCH_INTERNAL_ASSERT(
901
setstate_schema.arguments().size() == 2,
902
"__setstate__ method for class ",
903
class_type->repr_str(),
904
" must have exactly 2 arguments!");
905
auto state_type = setstate_schema.arguments().at(1).type();
906
(*setstate_method)(Stack{toIValue(state, state_type)});
909
std::stringstream err;
910
err << "Tried to deserialize object ";
911
if (auto qualname = class_type->name()) {
912
err << qualname->qualifiedName() << " ";
914
err << "which does not have a __setstate__ method defined!";
915
throw std::runtime_error(err.str());
918
py::class_<Object::Property>(m, "ScriptObjectProperty")
919
.def_property_readonly(
920
"name", [](const Object::Property& self) { return self.name; })
921
.def_property_readonly(
923
[](const Object::Property& self) { return self.getter_func; })
924
.def_property_readonly("setter", [](const Object::Property& self) {
925
return self.setter_func;
928
// Special case __str__ and __repr__ to make sure we can print Objects/Modules
929
// regardless of if the user defined __str__/__repr__
930
using MagicMethodImplType = std::function<py::object(
931
const Object& self, py::args args, py::kwargs kwargs)>;
933
std::unordered_map<std::string, MagicMethodImplType> special_magic_methods;
934
special_magic_methods.emplace(
936
[](const Object& self, py::args args, py::kwargs kwargs) -> py::object {
937
auto method = self.find_method("__str__");
939
return py::str("ScriptObject <" + self.type()->str() + ">");
941
return invokeScriptMethodFromPython(
943
// NOLINTNEXTLINE(performance-move-const-arg)
945
// NOLINTNEXTLINE(performance-move-const-arg)
949
special_magic_methods.emplace(
951
[](const Object& self, py::args args, py::kwargs kwargs) -> py::object {
952
auto method = self.find_method("__repr__");
954
std::stringstream ss;
955
ss << std::hex << static_cast<const void*>(&self);
956
return py::str("<torch.ScriptObject object at " + ss.str() + ">");
958
return invokeScriptMethodFromPython(
960
// NOLINTNEXTLINE(performance-move-const-arg)
962
// NOLINTNEXTLINE(performance-move-const-arg)
966
for (const char* mm_name : magic_method_names) {
967
if (special_magic_methods.count(mm_name)) {
968
object_class.def(mm_name, special_magic_methods[mm_name]);
972
[mm_name](const Object& self, py::args args, py::kwargs kwargs) {
973
auto method = self.find_method(mm_name);
975
throw c10::NotImplementedError(
976
"'%s' is not implemented for %s",
978
self.type()->str().c_str());
980
return invokeScriptMethodFromPython(
982
// NOLINTNEXTLINE(performance-move-const-arg)
984
// NOLINTNEXTLINE(performance-move-const-arg)
990
// NOLINTNEXTLINE(bugprone-unused-raii)
991
py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
993
py::class_<UpgraderEntry>(m, "_UpgraderEntry")
994
.def(py::init<int, std::string, std::string>())
995
.def_property_readonly(
997
[](const UpgraderEntry& self) { return self.bumped_at_version; })
998
.def_property_readonly(
1000
[](const UpgraderEntry& self) { return self.upgrader_name; })
1001
.def_property_readonly("old_schema", [](const UpgraderEntry& self) {
1002
return self.old_schema;
1005
py::class_<UpgraderRange>(m, "_UpgraderRange")
1006
.def(py::init<int, int>())
1007
.def_property_readonly(
1009
[](const UpgraderRange& self) { return self.min_version; })
1010
.def_property_readonly("max_version", [](const UpgraderRange& self) {
1011
return self.max_version;
1015
"__deepcopy__", [](const Object& self, const py::dict& memo) {
1017
pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1020
// Used by torch.package to save ScriptModule objects in unified format.
1021
py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
1022
.def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
1023
.def("serialize", &ScriptModuleSerializer::serialize_unified_format)
1026
&ScriptModuleSerializer::writeFiles,
1027
py::arg("code_dir") = ".data/ts_code/code/")
1030
&ScriptModuleSerializer::storage_context,
1031
pybind11::return_value_policy::reference_internal);
1033
// Used by torch.package to coordinate sharing of storages between eager
1034
// and ScriptModules.
1036
SerializationStorageContext,
1037
std::shared_ptr<SerializationStorageContext>>(
1038
m, "SerializationStorageContext")
1039
.def("has_storage", &SerializationStorageContext::hasStorage)
1040
.def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
1042
// torch.jit.ScriptModule is a subclass of this C++ object.
1043
// Methods here are prefixed with _ since they should not be
1045
py::class_<Module, Object>(m, "ScriptModule")
1046
.def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
1050
const std::string& filename,
1051
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1052
m.save(filename, _extra_files);
1054
py::arg("filename"),
1055
py::arg("_extra_files") = ExtraFilesMap())
1058
[](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1059
std::ostringstream buf;
1060
m.save(buf, _extra_files);
1061
return py::bytes(buf.str());
1063
py::arg("_extra_files") = ExtraFilesMap())
1067
const std::string& filename,
1068
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1069
bool _save_mobile_debug_info = false,
1070
bool _use_flatbuffer = false) {
1074
_save_mobile_debug_info,
1077
py::arg("filename"),
1078
py::arg("_extra_files") = ExtraFilesMap(),
1079
py::arg("_save_mobile_debug_info") = false,
1080
py::arg("_use_flatbuffer") = false)
1082
"_save_to_buffer_for_mobile",
1084
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1085
bool _save_mobile_debug_info = false,
1086
bool _use_flatbuffer = false) {
1087
std::ostringstream buf;
1089
buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
1090
return py::bytes(buf.str());
1092
py::arg("_extra_files") = ExtraFilesMap(),
1093
py::arg("_save_mobile_debug_info") = false,
1094
py::arg("_use_flatbuffer") = false)
1095
.def("_set_optimized", &Module::set_optimized)
1099
py::arg("code") = true,
1100
py::arg("attrs") = true,
1101
py::arg("params") = true)
1104
&Module::dump_to_str,
1105
py::arg("code") = true,
1106
py::arg("attrs") = true,
1107
py::arg("params") = true)
1109
"_replicate_for_data_parallel",
1110
[](Module& module) {
1111
const ModulePtr& obj = module._ivalue();
1112
auto copy = c10::ivalue::Object::create(
1113
c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
1114
obj->slots().size());
1115
for (size_t i = 0; i < obj->slots().size(); ++i) {
1116
copy->setSlot(i, obj->getSlot(i));
1118
return Module(std::move(copy));
1123
if (auto m = self.find_method("forward")) {
1124
return m->get_executor().getDebugState();
1126
throw std::runtime_error(
1127
"Attempted to call get_debug_state on a Module without a compiled forward()");
1132
std::shared_ptr<ConcreteModuleType> concreteType,
1133
const std::string& script,
1134
const ResolutionCallback& rcb) {
1135
const auto self = ModuleSelf(std::move(concreteType));
1136
m._ivalue()->compilation_unit()->define(
1137
*m.type()->name(), script, pythonResolver(rcb), &self);
1138
didFinishEmitModule(m);
1141
"_register_attribute",
1143
const std::string& name,
1144
const TypePtr& type,
1146
m.register_attribute(name, type, toIValue(value, type));
1149
"_create_method_from_trace",
1151
const std::string& name,
1152
const py::function& func,
1153
const py::tuple& input_tuple,
1154
const py::function& var_name_lookup_fn,
1156
bool force_outplace,
1157
const std::vector<std::string>& argument_names,
1158
bool store_inputs) {
1159
// prereq: Module's buffers and parameters are unique
1160
// this was ensured in python before calling this function
1161
auto typed_inputs = toTraceableStack(input_tuple);
1163
std::shared_ptr<Graph> graph =
1164
std::get<0>(tracer::createGraphByTracing(
1172
const auto method_name = QualifiedName(*self.type()->name(), name);
1173
auto fn = self._ivalue()->compilation_unit()->create_function(
1174
method_name, graph);
1175
self.type()->addMethod(fn);
1177
self.store_traced_inputs(name, typed_inputs);
1179
didFinishEmitModule(self);
1183
py::arg("input_tuple"),
1184
py::arg("var_name_lookup_fn"),
1186
py::arg("force_outplace"),
1187
py::arg("argument_names") = std::vector<std::string>(),
1188
py::arg("store_inputs"))
1190
"_create_method_from_trace_with_dict",
1192
const std::string& name,
1193
const py::function& func,
1194
const py::dict& input_dict,
1195
const py::function& var_name_lookup_fn,
1197
bool force_outplace,
1198
const std::vector<std::string>& argument_names,
1199
bool store_inputs) {
1200
// prereq: Module's buffers and parameters are unique
1201
// this was ensured in python before calling this function
1202
auto typed_inputs = toTraceableStack(input_dict);
1204
std::shared_ptr<Graph> graph =
1205
std::get<0>(tracer::createGraphByTracingWithDict(
1214
const auto method_name = QualifiedName(*self.type()->name(), name);
1215
auto fn = self._ivalue()->compilation_unit()->create_function(
1216
method_name, graph);
1218
self.store_traced_inputs(name, typed_inputs);
1220
self.type()->addMethod(fn);
1221
didFinishEmitModule(self);
1225
py::arg("input_dict"),
1226
py::arg("var_name_lookup_fn"),
1228
py::arg("force_outplace"),
1229
py::arg("argument_names") = std::vector<std::string>(),
1230
py::arg("store_inputs"))
1232
"_get_forward_hooks",
1233
[](const Module& m) {
1234
std::vector<StrongFunctionPtr> funcs;
1235
for (auto& hook : m.type()->getForwardHooks()) {
1236
funcs.emplace_back(m.type()->compilation_unit(), hook);
1241
"_get_forward_pre_hooks",
1242
[](const Module& m) {
1243
std::vector<StrongFunctionPtr> funcs;
1244
for (auto& pre_hook : m.type()->getForwardPreHooks()) {
1245
funcs.emplace_back(m.type()->compilation_unit(), pre_hook);
1250
"_retrieve_traced_inputs",
1251
[](const Module& m) {
1252
return ScriptDict(m.retrieve_traced_inputs());
1254
.def_property_readonly(
1257
std::vector<at::IValue> constants;
1258
PrintDepsTable deps;
1259
PythonPrint pp(constants, deps);
1260
pp.printNamedType(self.type());
1263
.def_property_readonly(
1264
"code_with_constants",
1266
std::vector<at::IValue> constants;
1267
PrintDepsTable deps;
1268
PythonPrint pp(constants, deps);
1269
pp.printNamedType(self.type());
1270
std::map<std::string, at::IValue> consts;
1272
for (auto const& constant : constants) {
1273
consts["c" + std::to_string(i)] = constant;
1276
return std::make_tuple(pp.str(), consts);
1278
.def("apply", &Module::apply)
1279
.def("__copy__", &Module::copy)
1282
[](const Module& self) {
1283
// Similar to Tensor's `__hash__`, which is `id()`.
1284
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
1288
[](const Module& self, const py::object& other) {
1289
// TODO: call UDF if it exists
1290
if (!py::isinstance<Module>(other)) {
1293
return self._ivalue().get() ==
1294
py::cast<Module>(other)._ivalue().get();
1298
[](const Module& self, const py::dict& memo) {
1300
pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1302
.def("children", &Module::children)
1303
.def_property_readonly("qualified_name", [](const Module& self) {
1304
return self.type()->name()->qualifiedName();
1307
py::class_<mobile::Module>(m, "LiteScriptModule")
1309
c10::intrusive_ptr<c10::ivalue::Object>,
1310
std::shared_ptr<mobile::CompilationUnit>>())
1313
[](mobile::Module& m, const std::string& method_name) {
1314
auto method = m.find_method(method_name);
1315
return method != c10::nullopt;
1317
py::arg("method_name"))
1320
[](mobile::Module& m,
1321
const std::string& method_name,
1322
const py::tuple& input_tuple) {
1324
for (auto& input : input_tuple) {
1325
stack.push_back(toTypeInferredIValue(input));
1327
return m.get_method(method_name)(stack);
1329
py::arg("method_name"),
1330
py::arg("input_tuple"))
1333
[](mobile::Module& m, const py::tuple& input_tuple) {
1335
for (auto& input : input_tuple) {
1336
stack.push_back(toTypeInferredIValue(input));
1338
return m.get_method("forward")(stack);
1340
py::arg("input_tuple"));
1342
slot_dict_impl<detail::ParameterPolicy>::bind(m, "ParameterDict");
1343
slot_dict_impl<detail::BufferPolicy>::bind(m, "BufferDict");
1344
slot_dict_impl<detail::ModulePolicy>::bind(m, "ModuleDict");
1346
py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
1347
.def(py::init<SourceRange>())
1348
.def("what", &ErrorReport::what)
1349
.def_static("call_stack", ErrorReport::current_call_stack);
1351
py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
1352
m, "CompilationUnit")
1354
py::init([](const std::string& lang, const uint32_t _frames_up) {
1355
auto cu = std::make_shared<CompilationUnit>();
1356
if (!lang.empty()) {
1357
pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
1361
py::arg("lang") = "",
1362
py::arg("_frames_up") = 0)
1366
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1367
auto fn = self->find_function(QualifiedName(name));
1369
return c10::optional<StrongFunctionPtr>(
1370
StrongFunctionPtr(std::move(self), fn));
1372
return c10::optional<StrongFunctionPtr>(c10::nullopt);
1377
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1378
auto fn = self->find_function(QualifiedName(name));
1380
return StrongFunctionPtr(std::move(self), fn);
1382
throw AttributeError(
1383
"'CompilationUnit' has no attribute '%s'", name.c_str());
1388
[](const std::shared_ptr<CompilationUnit>& self) {
1389
auto raw_functions = self->get_functions();
1390
std::vector<StrongFunctionPtr> functions;
1391
functions.reserve(raw_functions.size());
1392
for (auto fn : raw_functions) {
1394
functions.emplace_back(self, fn);
1399
.def("set_optimized", &CompilationUnit::set_optimized)
1402
pyCompilationUnitDefine,
1404
py::arg("rcb") = nullptr,
1405
py::arg("_frames_up") = 0)
1408
[](std::shared_ptr<CompilationUnit>& self,
1409
const std::string& qualified_name,
1410
std::shared_ptr<Graph> graph,
1411
bool should_mangle) {
1412
Function* fn = self->create_function(
1413
qualified_name, std::move(graph), should_mangle);
1414
return StrongFunctionPtr(std::move(self), fn);
1416
py::arg("qualified_name"),
1418
py::arg("should_mangle") = false)
1421
[](const std::shared_ptr<CompilationUnit>& self,
1422
const std::string& name) { return self->get_interface(name); })
1425
[](const std::shared_ptr<CompilationUnit>& self,
1426
const std::string& name) { return self->get_class(name); })
1428
"drop_all_functions",
1429
[](const std::shared_ptr<CompilationUnit>& self) {
1430
self->drop_all_functions();
1433
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
1436
[](py::args args, py::kwargs kwargs) {
1438
// see: [pybind11 varargs]
1439
auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
1440
Function& callee = *strongPtr.function_;
1441
py::object result = invokeScriptFunctionFromPython(
1443
// NOLINTNEXTLINE(performance-move-const-arg)
1444
tuple_slice(std::move(args), 1),
1445
// NOLINTNEXTLINE(performance-move-const-arg)
1448
END_HANDLE_TH_ERRORS_PYBIND
1452
[](const StrongFunctionPtr& self,
1453
const std::string& filename,
1454
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1455
Module module("__torch__.PlaceholderModule");
1457
// Modules have 'training' attributes by default, but due to
1458
// https://github.com/pytorch/pytorch/issues/27343, functions end
1459
// up having a training attribute when they are loaded. This adds
1460
// a fake 'training' attribute that shouldn't be used, but prevents
1461
// jitter on saving and loading. Once that issue is fixed this can
1463
module.register_attribute("training", BoolType::get(), true);
1464
addFunctionToModule(module, self);
1465
module.save(filename, _extra_files);
1467
py::arg("filename"),
1468
py::arg("_extra_files") = ExtraFilesMap())
1471
[](const StrongFunctionPtr& self,
1472
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1473
std::ostringstream buf;
1474
Module module("__torch__.PlaceholderModule");
1475
// see [issue 27343]
1476
module.register_attribute("training", BoolType::get(), true);
1477
addFunctionToModule(module, self);
1478
module.save(buf, _extra_files);
1479
return py::bytes(buf.str());
1481
py::arg("_extra_files") = ExtraFilesMap())
1482
.def_property_readonly(
1484
[](const StrongFunctionPtr& self) {
1485
return toGraphFunction(*self.function_).graph();
1487
.def_property_readonly(
1489
[](const StrongFunctionPtr& self) {
1490
auto g = toGraphFunction(*self.function_).graph()->copy();
1494
.def_property_readonly(
1496
[](const StrongFunctionPtr& self) {
1497
return self.function_->getSchema();
1499
.def_property_readonly(
1501
[](const StrongFunctionPtr& self) {
1502
std::vector<at::IValue> constants;
1503
PrintDepsTable deps;
1505
PythonPrint pp(constants, deps);
1506
pp.printFunction(*self.function_);
1511
[](const StrongFunctionPtr& self) {
1512
return toGraphFunction(*self.function_)
1517
"_debug_flush_compilation_cache",
1518
[](const StrongFunctionPtr& self) {
1519
toGraphFunction(*self.function_)
1521
.debugFlushCompilationCache();
1523
.def_property_readonly(
1525
[](const StrongFunctionPtr& self) { return self.function_->name(); })
1528
[](StrongFunctionPtr& self, bool ignore) {
1529
auto fn = self.function_;
1530
TORCH_INTERNAL_ASSERT(fn->isGraphFunction());
1531
GraphFunction& g_fn = toGraphFunction(*fn);
1532
g_fn._set_ignore_amp(ignore);
1534
.def_property_readonly(
1536
[](const StrongFunctionPtr& self) {
1537
return self.function_->qualname().qualifiedName();
1539
.def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
1540
return self.function_->doc_string();
1543
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
1546
[](py::args args, py::kwargs kwargs) {
1547
// see: [pybind11 varargs]
1549
Method& method = py::cast<Method&>(args[0]);
1551
return invokeScriptMethodFromPython(
1553
// NOLINTNEXTLINE(performance-move-const-arg)
1554
tuple_slice(std::move(args), 1),
1555
// NOLINTNEXTLINE(performance-move-const-arg)
1557
END_HANDLE_TH_ERRORS_PYBIND
1559
.def_property_readonly("graph", &Method::graph)
1560
.def_property_readonly(
1562
[](const Method& self) {
1563
auto g = toGraphFunction(self.function()).graph()->copy();
1567
.def_property_readonly(
1568
"schema", [](Method& m) { return m.function().getSchema(); })
1569
.def_property_readonly("name", &Method::name)
1570
.def_property_readonly(
1573
std::vector<at::IValue> constants;
1574
PrintDepsTable deps;
1575
PythonPrint pp(constants, deps);
1576
pp.printMethod(self.function());
1580
"_debug_flush_compilation_cache",
1582
return self.get_executor().debugFlushCompilationCache();
1584
.def_property_readonly(
1585
"code_with_constants",
1587
std::vector<at::IValue> constants;
1588
PrintDepsTable deps;
1589
PythonPrint pp(constants, deps);
1590
pp.printMethod(self.function());
1591
std::map<std::string, at::IValue> consts;
1593
for (auto const& constant : constants) {
1594
consts["c" + std::to_string(i)] = constant;
1597
return std::make_tuple(pp.str(), consts);
1599
.def_property_readonly("owner", &Method::owner)
1600
.def_property_readonly("raw_owner", [](const Method& self) {
1601
return Object(self.raw_owner());
1603
m.def("_generate_upgraders_graph", &generate_upgraders_graph);
1605
"_calculate_package_version_based_on_upgraders",
1606
&calculate_package_version_based_on_upgraders);
1607
m.def("_get_version_calculator_flag", &get_version_calculator_flag);
1609
"_compile_graph_to_code_table",
1610
[](const std::string& name, const std::shared_ptr<Graph>& graph) {
1611
CompilationOptions options;
1612
GraphFunction jitFunc(name, graph, nullptr);
1613
auto mobileFunc = convertJitFunctionToMobileFunction(jitFunc, options);
1614
return convertMobileFunctionToCodeTable(*mobileFunc, options);
1617
"_jit_script_compile",
1618
[](const std::string& qualname,
1620
const ResolutionCallback& rcb,
1621
const FunctionDefaults& defaults) {
1622
C10_LOG_API_USAGE_ONCE("torch.script.compile");
1623
const auto name = c10::QualifiedName(qualname);
1624
TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
1625
return script_compile_function(name, def, defaults, rcb);
1628
"_jit_script_compile_overload",
1629
[](const std::string& qualname,
1630
const Decl& overload_decl,
1631
const Def& implementation_def,
1632
const ResolutionCallback& rcb,
1633
const FunctionDefaults& implementation_defaults,
1634
const py::object& signature) {
1635
const auto name = c10::QualifiedName(qualname);
1636
return script_compile_overloaded_function(
1641
implementation_defaults,
1645
"_replace_overloaded_method_decl",
1646
[](const Decl& overload_decl,
1647
const Def& implementation_def,
1648
const std::string& new_name) {
1649
checkOverloadDecl(overload_decl, implementation_def.decl());
1650
return implementation_def.withDecl(overload_decl).withName(new_name);
1653
"_create_function_from_trace",
1654
[](const std::string& qualname,
1655
const py::function& func,
1656
const py::tuple& input_tuple,
1657
const py::function& var_name_lookup_fn,
1659
bool force_outplace,
1660
const std::vector<std::string>& argument_names) {
1661
auto typed_inputs = toTraceableStack(input_tuple);
1662
std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
1671
auto cu = get_python_cu();
1672
auto name = c10::QualifiedName(qualname);
1673
auto result = cu->create_function(
1674
std::move(name), std::move(graph), /*shouldMangle=*/true);
1675
StrongFunctionPtr ret(std::move(cu), result);
1676
didFinishEmitFunction(ret);
1681
py::arg("input_tuple"),
1682
py::arg("var_name_lookup_fn"),
1684
py::arg("force_outplace"),
1685
py::arg("argument_names") = std::vector<std::string>());
1688
"_create_function_from_trace_with_dict",
1689
[](const std::string& qualname,
1690
const py::function& func,
1691
const py::dict& input_dict,
1692
const py::function& var_name_lookup_fn,
1694
bool force_outplace,
1695
const std::vector<std::string>& argument_names) {
1696
auto typed_inputs = toTraceableStack(input_dict);
1697
std::shared_ptr<Graph> graph =
1698
std::get<0>(tracer::createGraphByTracingWithDict(
1708
auto cu = get_python_cu();
1709
auto name = c10::QualifiedName(qualname);
1710
auto result = cu->create_function(
1711
std::move(name), std::move(graph), /*shouldMangle=*/true);
1712
StrongFunctionPtr ret(std::move(cu), result);
1713
didFinishEmitFunction(ret);
1718
py::arg("input_dict"),
1719
py::arg("var_name_lookup_fn"),
1721
py::arg("force_outplace"),
1722
py::arg("argument_names") = std::vector<std::string>());
1725
"_jit_script_class_compile",
1726
[](const std::string& qualifiedName,
1727
const ClassDef& classDef,
1728
const ClassMethodDefaults& defaults,
1729
const ResolutionCallback& rcb) {
1730
C10_LOG_API_USAGE_ONCE("torch.script.class");
1731
if (classDef.superclass().present()) {
1732
throw ErrorReport(classDef.range())
1733
<< "Torchscript does not support class inheritance.";
1735
auto cu = get_python_cu();
1736
auto classname = c10::QualifiedName(qualifiedName);
1737
if (cu->get_type(classname) != nullptr) {
1738
classname = cu->mangle(classname);
1741
auto classType = ClassType::create(
1744
/* is_module = */ false,
1745
/* doc_string = */ "",
1746
getUnresolvedClassAttributes(classDef));
1747
cu->register_type(classType);
1748
std::vector<ResolverPtr> methodRcbs, propRcbs;
1749
std::vector<Def> methodDefs;
1750
std::vector<Property> props;
1752
for (const auto& def : classDef.body()) {
1753
if (def.kind() != TK_DEF) {
1754
throw ErrorReport(def.range())
1755
<< "Currently class bodies can only contain method "
1756
"definitions. File an issue on GitHub if you want "
1759
methodDefs.emplace_back(def);
1760
methodRcbs.push_back(
1761
pythonResolver(rcb, classDef.name().name(), classType));
1764
// Gather definitions for property getters and setters as well as
1765
// corresponding resolution callbacks.
1766
if (classDef.properties().present()) {
1767
for (const auto& prop : classDef.properties().get()) {
1768
props.emplace_back(prop);
1770
pythonResolver(rcb, classDef.name().name(), classType));
1774
const auto self = SimpleSelf(classType);
1775
cu->define(classname, props, propRcbs, methodDefs, methodRcbs, &self);
1777
// Stitch in default arguments for methods. Properties don't need to be
1778
// considered since there is no way to invoke setters without passing in
1780
auto defs_it = methodDefs.begin();
1781
while (defs_it != methodDefs.end()) {
1782
auto def_name = (*defs_it).name().name();
1783
// If the method is not in the defaults map, assume there are
1784
// no default arguments for it.
1785
auto default_it = defaults.find(def_name);
1786
if (default_it == defaults.end()) {
1790
const auto method_name =
1791
QualifiedName(classname, (*defs_it).name().name());
1792
auto& method = cu->get_function(method_name);
1793
method.setSchema(getSchemaWithNameAndDefaults(
1797
default_it->second));
1803
"_jit_script_interface_compile",
1804
[](const std::string& qualifiedName,
1805
const ClassDef& classDef,
1806
const ResolutionCallback& rcb,
1808
auto cu = get_python_cu();
1809
auto className = c10::QualifiedName(qualifiedName);
1810
if (cu->get_type(className) != nullptr) {
1811
className = cu->mangle(className);
1814
get_python_cu()->define_interface(
1815
className, classDef, pythonResolver(rcb), is_module);
1816
return className.qualifiedName();
1819
py::class_<torch::jit::ErrorReport::CallStack>(
1820
m, "CallStack", py::dynamic_attr())
1821
.def(py::init<const std::string&, const SourceRange&>());
1823
m.def("_parse_source_def", [](const std::string& src) {
1824
Parser p(std::make_shared<Source>(src));
1825
return Def(p.parseFunction(/*is_method=*/true));
1827
m.def("parse_type_comment", [](const std::string& comment) {
1828
Parser p(std::make_shared<Source>(comment));
1829
return Decl(p.parseTypeComment());
1832
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
1833
m.def("_dump_upgraders_map", &dump_upgraders_map);
1835
m.def("_test_only_populate_upgraders", &test_only_populate_upgraders);
1836
m.def("_test_only_remove_upgraders", &test_only_remove_upgraders);
1838
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
1839
m.def("_get_max_operator_version", &getMaxOperatorVersion);
1840
m.def("_get_operator_version_map", &get_operator_version_map);
1841
m.def("_get_upgraders_entry_map", &get_upgraders_entry_map);
1842
m.def("_get_upgrader_ranges", &getUpgradersRangeForOp);
1843
m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
1844
m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
1847
[](std::shared_ptr<CompilationUnit> cu,
1848
const std::string& filename,
1849
py::object map_location,
1850
const py::dict& extra_files,
1851
bool restore_shapes = false) {
1852
c10::optional<at::Device> optional_device;
1853
if (!map_location.is_none()) {
1854
AT_ASSERT(THPDevice_Check(map_location.ptr()));
1856
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1858
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1859
auto ret = import_ir_module(
1864
/*load_debug_files*/ true,
1866
extra_files_to_python(extra_files_map, extra_files);
1870
"_import_ir_module_from_package",
1871
[](std::shared_ptr<CompilationUnit> cu,
1872
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
1873
std::shared_ptr<torch::jit::DeserializationStorageContext>
1875
py::object map_location,
1876
std::string ts_id) {
1877
c10::optional<at::Device> optional_device;
1878
if (!map_location.is_none()) {
1879
AT_ASSERT(THPDevice_Check(map_location.ptr()));
1881
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1883
return import_ir_module(
1886
std::move(storage_context),
1891
"import_ir_module_from_buffer",
1892
[](std::shared_ptr<CompilationUnit> cu,
1893
const std::string& buffer,
1894
py::object map_location,
1895
const py::dict& extra_files,
1896
bool restore_shapes = false) {
1897
std::istringstream in(buffer);
1898
c10::optional<at::Device> optional_device;
1899
if (!map_location.is_none()) {
1900
AT_ASSERT(THPDevice_Check(map_location.ptr()));
1902
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1904
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1905
auto ret = import_ir_module(
1910
/*load_debug_files*/ true,
1912
extra_files_to_python(extra_files_map, extra_files);
1916
"_load_for_lite_interpreter",
1917
[](const std::string& filename, py::object map_location) {
1918
c10::optional<at::Device> optional_device;
1919
if (!map_location.is_none()) {
1920
AT_ASSERT(THPDevice_Check(map_location.ptr()));
1922
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1924
return _load_for_mobile(filename, optional_device);
1927
"_load_for_lite_interpreter_from_buffer",
1928
[](const std::string& buffer, py::object map_location) {
1929
std::istringstream in(buffer);
1930
c10::optional<at::Device> optional_device;
1931
if (!map_location.is_none()) {
1932
AT_ASSERT(THPDevice_Check(map_location.ptr()));
1934
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1936
return _load_for_mobile(in, optional_device);
1939
"_backport_for_mobile",
1940
[](const std::string& filename_input,
1941
const std::string& filename_output,
1942
const int64_t version) {
1943
return _backport_for_mobile(filename_input, filename_output, version);
1946
"_backport_for_mobile_from_buffer",
1947
[](const std::string& buffer_input,
1948
const std::string& filename_output,
1949
const int64_t version) {
1950
std::istringstream in(buffer_input);
1951
return _backport_for_mobile(in, filename_output, version);
1954
"_backport_for_mobile_to_buffer",
1955
[](const std::string& filename_input, const int64_t version) {
1956
std::ostringstream buffer_output;
1958
_backport_for_mobile(filename_input, buffer_output, version);
1959
return success ? py::bytes(buffer_output.str()) : py::bytes("");
1962
"_backport_for_mobile_from_buffer_to_buffer",
1963
[](const std::string& buffer_input, const int64_t version) {
1964
std::istringstream in(buffer_input);
1965
std::ostringstream buffer_output;
1966
bool success = _backport_for_mobile(in, buffer_output, version);
1967
return success ? py::bytes(buffer_output.str()) : py::bytes("");
1969
m.def("_get_model_bytecode_version", [](const std::string& filename) {
1970
return _get_model_bytecode_version(filename);
1973
"_get_model_extra_files",
1974
[](const std::string& filename, const py::dict& py_extra_files) {
1975
c10::optional<at::Device> optional_device;
1976
ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1977
_load_for_mobile(filename, optional_device, cpp_extra_files);
1978
extra_files_to_python(cpp_extra_files, py_extra_files);
1980
return py_extra_files;
1983
"_get_model_bytecode_version_from_buffer", [](const std::string& buffer) {
1984
std::istringstream in(buffer);
1985
return _get_model_bytecode_version(in);
1988
"_get_model_extra_files_from_buffer",
1989
[](const std::string& buffer, const py::dict& py_extra_files) {
1990
c10::optional<at::Device> optional_device;
1991
ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1992
std::istringstream in(buffer);
1993
_load_for_mobile(in, optional_device, cpp_extra_files);
1994
extra_files_to_python(cpp_extra_files, py_extra_files);
1996
return py_extra_files;
1998
m.def("_get_mobile_model_contained_types", [](const std::string& filename) {
1999
return _get_mobile_model_contained_types(filename);
2002
"_get_mobile_model_contained_types_from_buffer",
2003
[](const std::string& buffer) {
2004
std::istringstream in(buffer);
2005
return _get_mobile_model_contained_types(in);
2007
m.def("_nn_module_to_mobile", [](const Module& module) {
2008
CompilationOptions options;
2009
return jitModuleToMobile(module, options);
2011
py::class_<OperatorInfo>(m, "OperatorInfo")
2012
.def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
2013
m.def("_get_model_ops_and_info", [](const std::string& filename) {
2014
return _get_model_ops_and_info(filename);
2016
m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
2017
std::istringstream in(buffer);
2018
return _get_model_ops_and_info(in);
2020
m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
2021
return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
2024
"_quantize_ondevice_ptq_dynamic",
2025
[](mobile::Module& m, const std::string& method_name) {
2026
mobile::quantization::PTQQuanizationHelper ptq_helper;
2027
ptq_helper.quantize_dynamic(m, method_name);
2030
m.def("_jit_set_emit_hooks", setEmitHooks);
2031
m.def("_jit_get_emit_hooks", getEmitHooks);
2032
m.def("_jit_clear_class_registry", []() {
2033
get_python_cu()->_clear_python_cu();
2036
"_debug_set_autodiff_subgraph_inlining",
2037
debugSetAutodiffSubgraphInlining);
2038
m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining);
2039
m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining);
2040
m.def("_propagate_shapes", _propagate_shapes);
2042
"_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
2044
"_last_executed_optimized_graph",
2045
[]() { return lastExecutedOptimizedGraph(); },
2046
"Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
2048
"_create_function_from_graph",
2049
[](const std::string& qualname, std::shared_ptr<Graph> graph) {
2050
// TODO this should go in the global Python CU
2051
auto cu = std::make_shared<CompilationUnit>();
2052
c10::QualifiedName name(qualname);
2053
auto fn = cu->create_function(std::move(name), std::move(graph));
2054
return StrongFunctionPtr(std::move(cu), fn);
2056
m.def("_ivalue_tags_match", ivalue_tags_match);
2057
m.def("_ivalue_debug_python_object", [](py::object py_obj) {
2058
// convert to IValue first, IValue will incref via py::object
2059
IValue pyobj_ivalue = toIValue(std::move(py_obj), PyObjectType::get());
2060
// convert back to PyObject by borrowing the reference, which also
2061
// incref, after the return of this function, IValue is out of scope
2062
// which decref, so the return value is original refcount + 1
2063
py::object ret = toPyObject(pyobj_ivalue);
2066
m.def("_jit_debug_module_iterators", _jit_debug_module_iterators);
2068
py::class_<testing::FileCheck>(m, "FileCheck")
2070
.def("check", &testing::FileCheck::check)
2071
.def("check_not", &testing::FileCheck::check_not)
2072
.def("check_same", &testing::FileCheck::check_same)
2073
.def("check_next", &testing::FileCheck::check_next)
2074
.def("check_count", &testing::FileCheck::check_count)
2075
.def("check_dag", &testing::FileCheck::check_dag)
2077
"check_source_highlighted",
2078
&testing::FileCheck::check_source_highlighted)
2079
.def("check_regex", &testing::FileCheck::check_regex)
2082
[](testing::FileCheck& f,
2083
const std::string& str,
2085
bool exactly) { return f.check_count(str, count, exactly); },
2089
py::arg("exactly") = false)
2092
[](testing::FileCheck& f, const std::string& str) {
2096
"run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
2099
[](testing::FileCheck& f,
2100
const std::string& input,
2101
const std::string& output) { return f.run(input, output); },
2103
py::arg("checks_file"),
2104
py::arg("test_file"))
2107
[](testing::FileCheck& f, const std::string& input, const Graph& g) {
2108
return f.run(input, g);
2111
py::arg("checks_file"),
2115
"_logging_set_logger",
2116
[](logging::LoggerBase* logger) { return logging::setLogger(logger); },
2117
py::return_value_policy::reference);
2118
m.def("_set_graph_executor_optimize", [](bool optimize) {
2119
setGraphExecutorOptimize(optimize);
2123
"_get_graph_executor_optimize",
2124
[](c10::optional<bool> new_setting = c10::nullopt) {
2125
bool old_value = getGraphExecutorOptimize();
2127
setGraphExecutorOptimize(*new_setting);
2131
py::arg("new_settings") = nullptr);
2134
"_enable_mobile_interface_call_export",
2135
&torch::jit::enableMobileInterfaceCallExport);
2137
m.def("_create_module_with_type", [](const ClassTypePtr& type) {
2138
return Module(get_python_cu(), type);
2139
}).def("_create_object_with_type", [](const ClassTypePtr& type) {
2140
return Object(get_python_cu(), type);
2143
m.def("_export_opnames", [](Module& sm) {
2144
return debugMakeList(torch::jit::export_opnames(sm));
2148
ConcreteModuleTypeBuilder,
2149
std::shared_ptr<ConcreteModuleTypeBuilder>>(
2150
m, "ConcreteModuleTypeBuilder")
2151
.def(py::init<py::object>())
2154
[](ConcreteModuleTypeBuilder& self,
2157
self.addConstant(std::move(name), std::move(value));
2159
.def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
2161
"add_function_attribute",
2162
&ConcreteModuleTypeBuilder::addFunctionAttribute)
2164
"add_builtin_function",
2165
&ConcreteModuleTypeBuilder::addBuiltinFunction)
2166
.def("add_forward_hook", &ConcreteModuleTypeBuilder::addForwardHook)
2168
"add_forward_pre_hook", &ConcreteModuleTypeBuilder::addForwardPreHook)
2169
.def("add_module", &ConcreteModuleTypeBuilder::addModule)
2170
.def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
2171
.def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
2173
"add_failed_attribute",
2174
&ConcreteModuleTypeBuilder::addFailedAttribute)
2176
"add_ignored_attribute",
2177
&ConcreteModuleTypeBuilder::addIgnoredAttribute)
2179
"add_ignored_attributes",
2180
[](ConcreteModuleTypeBuilder& self,
2181
const std::vector<std::string>& names) {
2182
for (auto& name : names) {
2183
self.addIgnoredAttribute(name);
2188
[](ConcreteModuleTypeBuilder& self) {
2189
self.setIterableModuleKind(IterableModuleKind::DICT);
2191
.def("build", &ConcreteModuleTypeBuilder::build)
2194
[](const ConcreteModuleTypeBuilder& self,
2195
const ConcreteModuleTypeBuilder& other) {
2196
return self.equals(other);
2200
[](ConcreteModuleTypeBuilder& self) {
2201
self.setIterableModuleKind(IterableModuleKind::LIST);
2204
"set_parameter_list",
2205
[](ConcreteModuleTypeBuilder& self) {
2206
self.setIterableModuleKind(IterableModuleKind::PARAMLIST);
2208
.def("set_parameter_dict", [](ConcreteModuleTypeBuilder& self) {
2209
self.setIterableModuleKind(IterableModuleKind::PARAMDICT);
2212
py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
2213
m, "ConcreteModuleType")
2214
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
2215
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
2216
.def_static("from_jit_type", &ConcreteModuleType::fromJitType)
2217
.def("get_constants", &ConcreteModuleType::getConstantsPy)
2218
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
2219
.def("get_modules", &ConcreteModuleType::getModulesPy)
2220
.def("dump", &ConcreteModuleType::dump)
2221
.def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
2224
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
2225
return self.equals(other);
2229
[](const ConcreteModuleType& self,
2230
const ConcreteModuleTypeBuilder& other) {
2231
return self.equals(other);
2234
"_create_methods_and_properties",
2235
[](std::shared_ptr<ConcreteModuleType> concreteType,
2236
const std::vector<Property>& properties,
2237
const std::vector<ResolutionCallback>& propertyRcbs,
2238
const std::vector<Def>& methodDefs,
2239
const std::vector<ResolutionCallback>& methodRcbs,
2240
const std::vector<FunctionDefaults>& defaults) {
2241
TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
2242
TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
2244
std::vector<ResolverPtr> methodResolvers, propertyResolvers;
2245
methodResolvers.reserve(methodRcbs.size());
2246
for (auto& callback : methodRcbs) {
2247
methodResolvers.push_back(pythonResolver(callback));
2250
propertyResolvers.reserve(propertyRcbs.size());
2251
for (auto& callback : propertyRcbs) {
2252
propertyResolvers.push_back(pythonResolver(callback));
2255
const auto& selfType =
2256
concreteType->getJitType()->expect<ClassType>();
2257
const auto& prefix = selfType->name().value();
2258
const auto self = ModuleSelf(std::move(concreteType));
2259
auto cu = selfType->compilation_unit();
2267
// Stitch in default arguments for each Def if provided
2268
auto defaults_it = defaults.begin();
2269
auto defs_it = methodDefs.begin();
2270
while (defs_it != methodDefs.end()) {
2271
const auto method_name =
2272
QualifiedName(prefix, (*defs_it).name().name());
2273
auto& method = cu->get_function(method_name);
2274
method.setSchema(getSchemaWithNameAndDefaults(
2285
[](std::shared_ptr<ConcreteModuleType> concreteType,
2286
const std::vector<Def>& hookDefs,
2287
const std::vector<ResolutionCallback>& hookRcbs,
2288
const std::vector<Def>& preHookDefs,
2289
const std::vector<ResolutionCallback>& preHookRcbs) {
2290
TORCH_INTERNAL_ASSERT(hookDefs.size() == hookRcbs.size());
2291
TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookRcbs.size());
2293
std::vector<ResolverPtr> hookResolvers, preHookResolvers;
2295
hookResolvers.reserve(hookRcbs.size());
2296
for (auto& callback : hookRcbs) {
2297
hookResolvers.push_back(pythonResolver(callback));
2300
preHookResolvers.reserve(preHookRcbs.size());
2301
for (auto& callback : preHookRcbs) {
2302
preHookResolvers.push_back(pythonResolver(callback));
2305
const auto& selfType =
2306
concreteType->getJitType()->expect<ClassType>();
2307
const auto& prefix = selfType->name().value();
2308
const auto self = ModuleSelf(std::move(concreteType));
2309
auto cu = selfType->compilation_unit();
2321
[](const std::string& name,
2322
const SourceRange& range,
2323
const ResolutionCallback& rcb) {
2324
return pythonResolver(rcb)->resolveType(name, range);
2327
"_resolve_type_from_object",
2328
[](const py::object& obj,
2329
const SourceRange& range,
2330
const ResolutionCallback& rcb) {
2331
return pythonResolver(rcb)->resolveTypeFromObject(obj, range);
2335
"_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });
2338
"_set_should_use_format_with_string_table",
2339
setShouldUseFormatWithStringTable);
2341
// NOLINTNEXTLINE(bugprone-unused-raii)
2342
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
2344
py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
2345
.value("SUM", logging::LockingLogger::AggregationType::SUM)
2346
.value("AVG", logging::LockingLogger::AggregationType::AVG)
2349
logging::LockingLogger,
2350
logging::LoggerBase,
2351
std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
2353
.def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
2354
.def("get_counter_val", &logging::LockingLogger::getCounterValue);
2356
logging::NoopLogger,
2357
logging::LoggerBase,
2358
std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
2360
m.def("_jit_is_script_object", [](const py::object& obj) {
2361
return py::isinstance<Object>(obj);
2364
m.def("_get_file_format", [](const std::string& path) {
2365
switch (getFileFormat(path)) {
2366
case FileFormat::FlatbufferFileFormat:
2367
return "flatbuffer";
2368
case FileFormat::ZipFileFormat:
2377
[](const std::map<std::string, at::Tensor>& map,
2378
const std::string& filename,
2379
bool use_flatbuffer = false) {
2380
_save_parameters(map, filename, use_flatbuffer);
2383
m.def("_load_mobile_module_from_file", [](const std::string& filename) {
2384
return torch::jit::load_mobile_module_from_file(filename);
2386
m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
2387
auto bytes_copy = copyStr(bytes);
2388
return torch::jit::parse_and_initialize_mobile_module(
2389
bytes_copy, bytes.size());
2391
m.def("_load_jit_module_from_file", [](const std::string& filename) {
2392
ExtraFilesMap extra_files = ExtraFilesMap();
2393
return torch::jit::load_jit_module_from_file(filename, extra_files);
2395
m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
2396
auto bytes_copy = copyStr(bytes);
2397
ExtraFilesMap extra_files = ExtraFilesMap();
2398
return torch::jit::parse_and_initialize_jit_module(
2399
bytes_copy, bytes.size(), extra_files);
2402
"_save_mobile_module",
2403
[](const torch::jit::mobile::Module& module,
2404
const std::string& filename,
2405
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2406
return torch::jit::save_mobile_module(module, filename, _extra_files);
2410
[](const torch::jit::Module& module,
2411
const std::string& filename,
2412
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2413
return torch::jit::save_jit_module(module, filename, _extra_files);
2416
"_save_mobile_module_to_bytes",
2417
[](const torch::jit::mobile::Module& module,
2418
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2419
auto detached_buffer =
2420
torch::jit::save_mobile_module_to_bytes(module, _extra_files);
2422
reinterpret_cast<char*>(detached_buffer->data()),
2423
detached_buffer->size());
2426
"_save_jit_module_to_bytes",
2427
[](const torch::jit::Module& module,
2428
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2429
auto detached_buffer =
2430
torch::jit::save_jit_module_to_bytes(module, _extra_files);
2432
reinterpret_cast<char*>(detached_buffer->data()),
2433
detached_buffer->size());
2435
m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
2436
py::gil_scoped_acquire acquire;
2438
mobile::ModuleInfo minfo =
2439
torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
2440
result["bytecode_version"] = minfo.bytecode_version;
2441
result["operator_version"] = minfo.operator_version;
2442
result["function_names"] = minfo.function_names;
2443
result["type_names"] = minfo.type_names;
2444
result["opname_to_num_args"] = minfo.opname_to_num_args;
2448
m.def("_pickle_save", [](IValue v) {
2449
auto bytes = torch::jit::pickle_save(std::move(v));
2450
return py::bytes(bytes.data(), bytes.size());
2453
initScriptDictBindings(module);
2454
initScriptListBindings(module);
2457
} // namespace torch::jit