pytorch

Форк
0
/
script_init.cpp 
2457 строк · 90.0 Кб
1
#include <pybind11/detail/common.h>
2
#include <pybind11/pytypes.h>
3
#include <torch/csrc/jit/api/object.h>
4
#include <torch/csrc/jit/python/script_init.h>
5
#include <torch/csrc/utils/pybind.h>
6

7
#include <caffe2/serialize/versions.h>
8
#include <torch/csrc/Device.h>
9
#include <torch/csrc/DynamicTypes.h>
10
#include <torch/csrc/jit/api/module.h>
11
#include <torch/csrc/jit/frontend/ir_emitter.h>
12
#include <torch/csrc/jit/frontend/sugared_value.h>
13
#include <torch/csrc/jit/mobile/code.h>
14
#include <torch/csrc/jit/mobile/compatibility/backport.h>
15
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
16
#include <torch/csrc/jit/mobile/file_format.h>
17
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
18
#include <torch/csrc/jit/mobile/import.h>
19
#include <torch/csrc/jit/mobile/module.h>
20
#include <torch/csrc/jit/mobile/quantization.h>
21
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
22
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
23
#include <torch/csrc/jit/operator_upgraders/utils.h>
24
#include <torch/csrc/jit/operator_upgraders/version_map.h>
25
#include <torch/csrc/jit/python/module_python.h>
26
#include <torch/csrc/jit/python/python_ivalue.h>
27
#include <torch/csrc/jit/python/python_sugared_value.h>
28
#include <torch/csrc/jit/serialization/export_bytecode.h>
29
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
30
#include <torch/csrc/jit/serialization/import.h>
31
#include <torch/csrc/jit/testing/file_check.h>
32

33
#include <c10/util/Exception.h>
34
#include <c10/util/intrusive_ptr.h>
35
#include <c10/util/irange.h>
36
#include <torch/csrc/jit/frontend/parser.h>
37
#include <torch/csrc/jit/frontend/tracer.h>
38
#include <torch/csrc/jit/ir/constants.h>
39
#include <torch/csrc/jit/ir/graph_utils.h>
40
#include <torch/csrc/jit/ir/irparser.h>
41
#include <torch/csrc/jit/passes/inliner.h>
42
#include <torch/csrc/jit/passes/shape_analysis.h>
43
#include <torch/csrc/jit/python/pybind_utils.h>
44
#include <torch/csrc/jit/python/python_dict.h>
45
#include <torch/csrc/jit/python/python_list.h>
46
#include <torch/csrc/jit/python/python_tracer.h>
47
#include <torch/csrc/jit/runtime/graph_executor.h>
48
#include <torch/csrc/jit/runtime/instruction.h>
49
#include <torch/csrc/jit/runtime/interpreter.h>
50
#include <torch/csrc/jit/runtime/logging.h>
51
#include <torch/csrc/jit/serialization/export_bytecode.h>
52
#include <torch/csrc/jit/serialization/import_source.h>
53
#include <torch/csrc/jit/serialization/pickle.h>
54
#include <torch/csrc/jit/serialization/python_print.h>
55
#include <torch/csrc/jit/testing/hooks_for_testing.h>
56

57
#include <torch/csrc/api/include/torch/ordered_dict.h>
58

59
#include <ATen/ATen.h>
60
#include <ATen/core/function_schema.h>
61
#include <ATen/core/ivalue.h>
62
#include <ATen/core/qualified_name.h>
63

64
#include <pybind11/functional.h>
65
#include <pybind11/pybind11.h>
66
#include <pybind11/stl.h>
67
#include <pybind11/stl_bind.h>
68
#include <torch/csrc/jit/mobile/train/export_data.h>
69
#include <chrono>
70
#include <cstddef>
71
#include <memory>
72
#include <sstream>
73
#include <string>
74
#include <tuple>
75
#include <utility>
76
#include <vector>
77

78
namespace torch::jit {
79

80
using ::c10::Argument;
81
using ::c10::FunctionSchema;
82

83
using FunctionDefaults = std::unordered_map<std::string, py::object>;
84
using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
85

86
namespace {
87

88
// A resolver that will inspect the outer Python scope to find `name`.
89
struct PythonResolver : public Resolver {
90
  explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {}
91

92
  /**
93
   * While compiling classes, the class type we're compiling will not be
94
   * available in Python, since we haven't fowner_ defining the class yet. So
95
   * in order to make the class type available to its own methods, we need to
96
   * explicitly resolve it.
97
   *
98
   * @param rcb Python function to resolve a name to its Python object in the
99
   *            enclosing scope
100
   * @param classname The unqualified classname of the class currently being
101
   *                  compiled.
102
   * @param classType The class's type.
103
   */
104
  explicit PythonResolver(
105
      ResolutionCallback rcb,
106
      std::string classname,
107
      ClassTypePtr classType)
108
      : rcb_(std::move(rcb)),
109
        classname_(std::move(classname)),
110
        classType_(std::move(classType)) {}
111

112
  std::shared_ptr<SugaredValue> resolveValue(
113
      const std::string& name,
114
      GraphFunction& m,
115
      const SourceRange& loc) override {
116
    pybind11::gil_scoped_acquire ag;
117
    py::object obj = rcb_(name);
118
    if (obj.is_none()) {
119
      return nullptr;
120
    }
121
    return toSugaredValue(obj, m, loc);
122
  }
123

124
  static bool isNamedTupleClass(py::object obj) {
125
    auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
126
    return PyObject_IsSubclass(obj.ptr(), tuple_type) &&
127
        py::hasattr(obj, "_fields");
128
  }
129

130
  TypePtr resolveTypeFromObject(const py::object& obj, const SourceRange& loc) {
131
    if (py::isinstance<ScriptClass>(obj)) {
132
      auto script_class = py::cast<ScriptClass>(obj);
133
      return script_class.class_type_.type_;
134
    }
135

136
    py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj);
137
    if (!py::cast<bool>(isClass)) {
138
      return nullptr;
139
    }
140

141
    if (isNamedTupleClass(obj)) {
142
      return registerNamedTuple(obj, loc, rcb_);
143
    }
144

145
    auto qualifiedName = c10::QualifiedName(
146
        py::cast<std::string>(py::module::import("torch._jit_internal")
147
                                  .attr("_qualified_name")(obj)));
148

149
    return get_python_cu()->get_type(qualifiedName);
150
  }
151

152
  TypePtr resolveType(const std::string& name, const SourceRange& loc)
153
      override {
154
    if (classType_ && name == classname_) {
155
      return classType_;
156
    }
157
    pybind11::gil_scoped_acquire ag;
158
    py::object obj = rcb_(name);
159
    if (obj.is_none()) {
160
      return nullptr;
161
    }
162

163
    auto annotation_type =
164
        py::module::import("torch.jit.annotations")
165
            .attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
166
    if (!annotation_type.is_none()) {
167
      return py::cast<TypePtr>(annotation_type);
168
    }
169
    return resolveTypeFromObject(obj, loc);
170
  }
171

172
 private:
173
  ResolutionCallback rcb_;
174
  std::string classname_;
175
  ClassTypePtr classType_;
176
};
177

178
std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
179
  return std::make_shared<PythonResolver>(rcb);
180
}
181
std::shared_ptr<PythonResolver> pythonResolver(
182
    const ResolutionCallback& rcb,
183
    std::string classname,
184
    ClassTypePtr classType) {
185
  return std::make_shared<PythonResolver>(
186
      rcb, std::move(classname), std::move(classType));
187
}
188

189
void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
190
  const auto& new_params = new_decl.params();
191
  const auto& old_params = old_decl.params();
192

193
  // TODO. same number of parameters not strictly necessary.
194
  TORCH_INTERNAL_ASSERT(
195
      new_params.size() == old_params.size(),
196
      "Overload must have same number of parameters\n",
197
      new_decl.range(),
198
      old_decl.range());
199
  for (const auto i : c10::irange(new_decl.params().size())) {
200
    TORCH_INTERNAL_ASSERT(
201
        new_params[i].ident().name() == old_params[i].ident().name(),
202
        "Overload parameters must have the same names\n",
203
        new_params[i].ident(),
204
        old_params[i].ident());
205
  }
206
}
207

208
// Tries to convert the Python default `def_value` into an IValue for `arg`.
// For a BroadcastingList argument (a List[T] with a positive N), the default
// is converted against the element type T rather than List[T].
// Returns nullopt if the conversion throws anything.
c10::optional<IValue> tryCalculateDefaultParam(
    const Argument& arg,
    const py::object& def_value) {
  const auto n = arg.N();
  const auto list_type = arg.type()->cast<ListType>();
  const bool broadcasting_list = n && *n > 0 && list_type;
  try {
    if (broadcasting_list) {
      // BroadcastingList, allow default values T for arg types List[T]
      return toIValue(def_value, list_type->getElementType());
    }
    return toIValue(def_value, arg.type());
  } catch (...) {
    // Deliberate best-effort: any conversion failure means "no valid default".
    return c10::nullopt;
  }
}
224

225
// An overloaded function may have a default that does not subtype all overloads
226
// @overload
227
// def foo(x: str)
228
// def foo(x=1)
229
FunctionDefaults calcOverloadedFunctionDefaults(
230
    const FunctionSchema& schema,
231
    const FunctionDefaults& defaults) {
232
  FunctionDefaults updated_defaults;
233
  for (const auto& arg : schema.arguments()) {
234
    const std::string& arg_name = arg.name();
235
    auto value = defaults.find(arg_name);
236
    if (value == defaults.end()) {
237
      continue;
238
    }
239
    auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
240
    if (maybe_ivalue) {
241
      updated_defaults[arg_name] = value->second;
242
    }
243
  }
244
  return updated_defaults;
245
}
246

247
} // namespace
248

249
// Returns true if the Python default `def_arg` is mutable: a list or a dict,
// or a tuple that (recursively) contains one.
bool checkMutableFunctionDefault(const py::object& def_arg) {
  const bool is_list_or_dict =
      py::isinstance<py::list>(def_arg) || py::isinstance<py::dict>(def_arg);
  if (is_list_or_dict) {
    return true;
  }
  if (!py::isinstance<py::tuple>(def_arg)) {
    return false;
  }
  for (py::handle element : def_arg.cast<py::tuple>()) {
    if (checkMutableFunctionDefault(
            py::reinterpret_borrow<py::object>(element))) {
      return true;
    }
  }
  return false;
}
264

265
void checkMutableFunctionDefault(
266
    const SourceRange& range,
267
    const Argument& arg,
268
    const py::object& def_arg) {
269
  if (checkMutableFunctionDefault(def_arg) || arg.type()->cast<ClassType>()) {
270
    throw ErrorReport(range)
271
        << "Mutable default parameters are not supported because Python binds them to the function"
272
        << " and they persist across function calls.\n As a workaround, make the default None and instantiate"
273
        << " the default parameter within the body of the function. Found "
274
        << def_arg.get_type() << " on parameter " << arg.name();
275
  }
276
}
277

278
// Returns a copy of `schema`, renamed to `new_name` when one is given. Each
// argument that has an entry in `default_args` (keyed by parameter name)
// gets that Python value converted and attached as its default.
// Throws ErrorReport at `range` if a default is mutable or cannot be
// converted to its parameter's type.
FunctionSchema getSchemaWithNameAndDefaults(
    const SourceRange& range,
    const FunctionSchema& schema,
    const at::optional<std::string>& new_name,
    const FunctionDefaults& default_args) {
  std::vector<Argument> new_args;
  new_args.reserve(schema.arguments().size());
  for (const auto& arg : schema.arguments()) {
    auto it = default_args.find(arg.name());
    if (it == default_args.end()) {
      new_args.push_back(arg);
      continue;
    }
    checkMutableFunctionDefault(range, arg, it->second);
    c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
    if (!value) {
      ErrorReport error(range);
      error << "Expected a default value of type " << arg.type()->repr_str()
            << " on parameter \"" << arg.name() << "\".";
      if (arg.is_inferred_type()) {
        // Leading space: this hint is a separate sentence. Previously it was
        // glued onto the preceding '.'.
        error << " Because \"" << arg.name()
              << "\" was not annotated with an explicit type, "
              << "it is assumed to be type 'Tensor'.";
      }
      throw error;
    }
    new_args.emplace_back(
        arg.name(), arg.type(), arg.N(), std::move(*value), arg.kwarg_only());
  }
  return FunctionSchema(
      new_name.value_or(schema.name()),
      schema.overload_name(),
      std::move(new_args),
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
}
314

315
// Builds the declaration used to compile one overload of an overloaded
// function. It contains the overload's own parameters, followed by any
// trailing parameters that exist only on the implementation.
// Each trailing parameter must have a default in `defaults` and a type
// annotation on the implementation. The result keeps the overload's range and
// return type.
//
// Following the PEP 484 @overload specification, this should work:
//   @overload
//   def mouse_event(x1: int, y1: int) -> ClickEvent: ...
//   ...
//   def mouse_event(x1: int, y1: int, x2: Optional[int] = None,
//                   y2: Optional[int] = None)
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
    const Decl& overload_decl,
    const Decl& impl_decl,
    const FunctionDefaults& defaults) {
  const auto& overload_params = overload_decl.params();
  const auto& impl_params = impl_decl.params();

  TORCH_CHECK(
      overload_params.size() <= impl_params.size(),
      "Overload should not have more parameters than implementation function",
      overload_decl.range(),
      impl_decl.range());

  std::vector<Param> adjusted_params;
  adjusted_params.reserve(impl_params.size());

  // The overload's parameters must be a name-for-name prefix of the
  // implementation's parameters.
  for (const auto i : c10::irange(overload_params.size())) {
    auto overload_name = overload_params[i].ident().name();
    auto impl_name = impl_params[i].ident().name();
    if (overload_name != impl_name) {
      throw ErrorReport(overload_decl.range())
          << "Overload parameters must have the same names. "
          << "Found " << overload_name << " and " << impl_name
          << " on argument " << i;
    }
    adjusted_params.push_back(overload_params[i]);
  }
  // The overload says nothing about implementation-only trailing parameters,
  // so they need both a default and an explicit type annotation.
  for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
    auto param_name = impl_params[i].ident().name();
    if (!defaults.count(param_name)) {
      // Trailing space after "argument" so the name is not glued to it.
      throw ErrorReport(impl_decl.range())
          << "Expected to find default parameter on argument "
          << param_name
          << " because it is not defined on the overloaded declaration";
    }
    if (!impl_params[i].type().present()) {
      throw ErrorReport(impl_decl.range())
          << "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
          << " Did not find type for param " << param_name;
    }
    adjusted_params.push_back(impl_params[i]);
  }
  return Decl::create(
      overload_decl.range(),
      List<Param>::create(overload_decl.range(), adjusted_params),
      overload_decl.return_type());
}
365

366
// Compiles `implementation_def` under the signature of one `@overload`
// declaration and registers the result in the Python compilation unit under
// `name`'s prefix. Free names are resolved through `rcb`.
// Only the implementation defaults that convert to this overload's argument
// types are attached to the resulting schema.
// Throws ErrorReport if `signature` is None, meaning the overload has no type
// annotations.
static StrongFunctionPtr script_compile_overloaded_function(
    const c10::QualifiedName& name,
    const Decl& overload_decl,
    const Def& implementation_def,
    const ResolutionCallback& rcb,
    const FunctionDefaults& implementation_defaults,
    const py::object& signature) {
  if (signature.is_none()) {
    throw ErrorReport(overload_decl.range())
        << "Must explicitly add type annotations to overloaded functions";
  }

  // Implementation body + overload's parameters (and implementation-only
  // trailing parameters).
  const Def overload_def = implementation_def.withDecl(
      mergeDefaultsAndExtraParametersToOverloadDecl(
          overload_decl, implementation_def.decl(), implementation_defaults));

  auto cu = get_python_cu();
  auto compiled = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {overload_def},
      {pythonResolver(rcb)},
      /*self=*/nullptr,
      /*shouldMangle=*/true);
  TORCH_INTERNAL_ASSERT(compiled.size() == 1);
  auto& fn = compiled[0];

  const FunctionDefaults overload_defaults =
      calcOverloadedFunctionDefaults(fn->getSchema(), implementation_defaults);
  fn->setSchema(getSchemaWithNameAndDefaults(
      overload_def.range(),
      fn->getSchema(),
      overload_def.name().name(),
      overload_defaults));

  StrongFunctionPtr result(std::move(cu), fn);
  didFinishEmitFunction(result);
  return result;
}
403

404
// Compiles the single definition `def` into the Python compilation unit under
// `name`'s prefix, resolving free names through `rcb`. The Python `defaults`
// (keyed by parameter name) are then attached to the function's schema.
static StrongFunctionPtr script_compile_function(
    const c10::QualifiedName& name,
    const Def& def,
    const FunctionDefaults& defaults,
    const ResolutionCallback& rcb) {
  auto cu = get_python_cu();
  auto compiled = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {def},
      {pythonResolver(rcb)},
      /*self=*/nullptr,
      /*shouldMangle=*/true);
  TORCH_INTERNAL_ASSERT(compiled.size() == 1);
  auto& fn = compiled.front();

  auto schema_with_defaults = getSchemaWithNameAndDefaults(
      def.range(), fn->getSchema(), def.name().name(), defaults);
  fn->setSchema(std::move(schema_with_defaults));

  StrongFunctionPtr result(std::move(cu), fn);
  didFinishEmitFunction(result);
  return result;
}
426

427
// `Self` used when compiling methods of a scripted nn.Module.
// The `self` value is typed as the module's class type and sugared as a
// ModuleValue that carries `concreteType_`.
struct VISIBILITY_HIDDEN ModuleSelf : public Self {
  // `explicit`: a ConcreteModuleType pointer should not silently convert into
  // a ModuleSelf.
  explicit ModuleSelf(std::shared_ptr<ConcreteModuleType> concreteType)
      : Self(), concreteType_(std::move(concreteType)) {}

  // Sets `v`'s type to the module class type and wraps it in a ModuleValue.
  std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
    v->setType(getClassType());
    return std::make_shared<ModuleValue>(v, concreteType_);
  }

  // The concrete type's JIT type, required (via expect<>) to be a ClassType.
  ClassTypePtr getClassType() const override {
    return concreteType_->getJitType()->expect<ClassType>();
  }

 private:
  std::shared_ptr<ConcreteModuleType> concreteType_;
};
443

444
static std::shared_ptr<Graph> _propagate_shapes(
445
    Graph& graph,
446
    std::vector<at::Tensor> inputs,
447
    bool with_grad = false) {
448
  Stack stack(inputs.begin(), inputs.end());
449
  auto retval = graph.copy();
450
  setInputTensorTypes(*retval, stack, /*complete=*/false);
451
  PropagateInputShapes(retval);
452
  return retval;
453
}
454

455
static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
456
    Graph& graph,
457
    const std::vector<at::Tensor>& inputs,
458
    const std::vector<int>& param_count_list,
459
    bool with_grad = false,
460
    bool propagate = true) {
461
  auto retval = graph.copy();
462
  setInputTensorTypes(
463
      *retval, fmap<IValue>(inputs), /*complete=*/true, param_count_list);
464
  if (propagate) {
465
    PropagateInputShapes(retval);
466
  }
467
  return retval;
468
}
469

470
// Installs `func` on `module` as its `forward` method.
// The function's graph is copied and a leading `self` input, typed as the
// module object, is inserted. The result is then created in the module's
// compilation unit and added to the module's class type.
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
  // Make a graph with a fake self argument; the original graph is untouched.
  std::shared_ptr<Graph> method_graph =
      toGraphFunction(*func.function_).graph()->copy();
  Value* self_input = method_graph->insertInput(0, "self");
  self_input->setType(module._ivalue()->type());

  const QualifiedName method_name(*module.type()->name(), "forward");
  auto method = module._ivalue()->compilation_unit()->create_function(
      method_name, method_graph);
  module.type()->addMethod(method);
}
480

481
// this is used in our test suite to check that we correctly preserved type tags
482
bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
483
  struct Work {
484
    IValue a;
485
    IValue b;
486
  };
487
  std::unordered_set<const void*> visited;
488
  std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
489
  while (!work.empty()) {
490
    Work item = work.back();
491
    work.pop_back();
492
    if (item.a.isPtrType()) {
493
      // uncomment to debug type matching errors
494
      // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") "
495
      //          << item.a.internalToPointer() << " " << /*item.b <<*/ " ("
496
      //          << *item.b.type() << ") " << item.b.internalToPointer() <<
497
      //          "\n";
498

499
      if (visited.count(item.a.internalToPointer())) {
500
        continue;
501
      }
502
      visited.emplace(item.a.internalToPointer());
503
    }
504
    if (!unshapedType(item.b.type())
505
             ->isSubtypeOf(unshapedType(item.b.type()))) {
506
      // Since named types are saved and loaded in the test suite, we cannot
507
      // expect them to be equal. We should still check their slots however.
508
      if (!item.a.type()->cast<c10::NamedType>()) {
509
        return false;
510
      }
511
    }
512
    // check tags for objects that contain subobjects
513
    if (item.a.isObject()) {
514
      auto ao = item.a.toObject();
515
      auto bo = item.b.toObject();
516
      for (size_t i = 0; i < ao->slots().size(); ++i) {
517
        work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)});
518
      }
519
    } else if (item.a.isTuple()) {
520
      auto at = item.a.toTuple();
521
      auto bt = item.b.toTuple();
522
      for (size_t i = 0; i < at->elements().size(); ++i) {
523
        work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)});
524
      }
525
    } else if (item.a.isList()) {
526
      auto al = item.a.toList();
527
      auto bl = item.b.toList();
528
      for (const auto i : c10::irange(al.size())) {
529
        work.emplace_back(Work{al.get(i), bl.get(i)});
530
      }
531
    } else if (item.a.isGenericDict()) {
532
      auto ad = item.a.toGenericDict();
533
      auto bd = item.b.toGenericDict();
534
      for (auto& item : ad) {
535
        // Dictionaory keys cannot contain List/Dicts that require tags
536
        // so we do not have to check them.
537
        // Furthermore without ordered dicts it is expensive to find the
538
        // equivalent key
539
        work.emplace_back(Work{item.value(), bd.at(item.key())});
540
      }
541
    } else if (item.a.isFuture()) {
542
      auto af = item.a.toFuture();
543
      auto bf = item.b.toFuture();
544
      af->wait();
545
      bf->wait();
546
      work.emplace_back(Work{af->value(), bf->value()});
547
    }
548
  }
549

550
  return true;
551
}
552

553
// helper used to implement ._parameters, ._buffers, ._modules dicts
554
// inside of script nn.Module
555
template <typename Policy>
556
struct slot_dict_impl {
557
  slot_dict_impl(ModulePtr module) : module_(std::move(module)) {}
558
  bool contains(const std::string& name) const {
559
    if (auto slot = module_->type()->findAttributeSlot(name)) {
560
      if (Policy::valid(module_->type(), *slot, module_->getSlot(*slot))) {
561
        return true;
562
      }
563
    }
564
    return false;
565
  }
566

567
  std::vector<std::pair<std::string, py::object>> items() const {
568
    std::vector<std::pair<std::string, py::object>> result;
569
    for (size_t i = 0, N = module_->type()->numAttributes(); i < N; ++i) {
570
      if (Policy::valid(module_->type(), i, module_->getSlot(i))) {
571
        result.emplace_back(
572
            module_->type()->getAttributeName(i),
573
            toPyObject(module_->getSlot(i)));
574
      }
575
    }
576
    return result;
577
  }
578

579
  void setattr(const std::string& name, py::object value) {
580
    const TypePtr& type = module_->type()->getAttribute(name);
581
    Module(module_).setattr(name, toIValue(std::move(value), type));
582
  }
583

584
  py::object getattr(const std::string& name) {
585
    return toPyObject(Module(module_).attr(name));
586
  }
587

588
  static void bind(const py::module& m, const char* name) {
589
    py::class_<slot_dict_impl<Policy>>(m, name)
590
        .def(py::init(
591
            [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
592
        .def("contains", &slot_dict_impl<Policy>::contains)
593
        .def("items", &slot_dict_impl<Policy>::items)
594
        .def("setattr", &slot_dict_impl<Policy>::setattr)
595
        .def("getattr", &slot_dict_impl<Policy>::getattr);
596
  }
597

598
 private:
599
  ModulePtr module_;
600
};
601

602
template <typename T>
603
py::list debugMakeList(const T& list) {
604
  py::list result;
605
  for (const auto& elem : list) {
606
    result.append(py::cast(elem));
607
  }
608
  return result;
609
}
610
template <typename T>
611
py::list debugMakeNamedList(const T& list) {
612
  py::list result;
613
  for (auto elem : list) {
614
    result.append(py::cast(std::make_pair(elem.name, elem.value)));
615
  }
616
  return result;
617
}
618
template <typename T>
619
py::set debugMakeSet(const T& list) {
620
  py::set result;
621
  for (const auto& elem : list) {
622
    result.add(py::cast(elem));
623
  }
624
  return result;
625
}
626

627
// Debug helper that returns a dict of every Module iterator turned into a
// Python list: children, modules, parameters, buffers and attributes, each in
// plain and named form. Parameters, buffers and attributes come in a
// non-recursive form and a recursive form; the recursive keys end in "_r".
static py::dict _jit_debug_module_iterators(Module& module) {
  py::dict result;
  result["children"] = debugMakeList(module.children());
  result["named_children"] = debugMakeNamedList(module.named_children());
  result["modules"] = debugMakeList(module.modules());
  result["named_modules"] = debugMakeNamedList(module.named_modules());

  // Recursive variants get an "_r" suffix. Each loop inserts the
  // non-recursive pair first, then the recursive pair, which fixes the
  // dict's insertion order.
  const auto key = [](const char* base, bool recurse) {
    return py::str(std::string(base) + (recurse ? "_r" : ""));
  };

  for (const bool recurse : {false, true}) {
    result[key("parameters", recurse)] =
        debugMakeList(module.parameters(recurse));
    result[key("named_parameters", recurse)] =
        debugMakeNamedList(module.named_parameters(recurse));
  }
  for (const bool recurse : {false, true}) {
    result[key("buffers", recurse)] = debugMakeList(module.buffers(recurse));
    result[key("named_buffers", recurse)] =
        debugMakeNamedList(module.named_buffers(recurse));
  }
  for (const bool recurse : {false, true}) {
    result[key("named_attributes", recurse)] =
        debugMakeNamedList(module.named_attributes(recurse));
  }
  return result;
}
652

653
// Names of Python magic ("dunder") methods, in a fixed order.
// NOTE(review): presumably these are the magic methods the bindings below
// forward to script objects -- confirm at the use site.
static constexpr std::array<const char*, 48> magic_method_names = {
    // comparisons, not, abs
    "__lt__", "__le__", "__eq__", "__ne__", "__ge__", "__gt__",
    "__not__", "__abs__",
    // operators
    "__add__", "__and__", "__floordiv__", "__index__", "__inv__",
    "__invert__", "__lshift__", "__mod__", "__mul__", "__matmul__",
    "__neg__", "__or__", "__pos__", "__pow__", "__rshift__", "__sub__",
    "__truediv__", "__xor__", "__concat__",
    // container access
    "__contains__", "__delitem__", "__getitem__", "__setitem__",
    // in-place operators
    "__iadd__", "__iand__", "__iconcat__", "__ifloordiv__", "__ilshift__",
    "__imod__", "__imul__", "__imatmul__", "__ior__", "__ipow__",
    "__irshift__", "__isub__", "__itruediv__", "__ixor__",
    // string conversion and length
    "__str__", "__len__", "__repr__",
};
667

668
struct DeepCopyMemoTable {
669
  std::shared_ptr<IValue::HashAliasedIValueMap> map;
670
};
671

672
// Deep-copies `ivalue` using an IValue memo map kept inside the Python memo
// dict `memo` under the key "__torch_script_memo_table". The map is created
// on first use and reused by later calls that pass the same `memo` dict.
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
  const char* const kMemoKey = "__torch_script_memo_table";
  if (!memo.contains(py::str(kMemoKey))) {
    memo[kMemoKey] =
        DeepCopyMemoTable{std::make_shared<IValue::HashAliasedIValueMap>()};
  }
  // Keep a local owner of the map while the copy runs.
  const auto table = py::cast<DeepCopyMemoTable>(memo[kMemoKey]);
  return ivalue.deepcopy(*table.map);
}
681

682
// Builds an ExtraFilesMap whose keys are the keys of `pydict`, cast to
// std::string. Every value starts out empty; the dict's own values are
// ignored.
ExtraFilesMap extra_files_from_python(const py::dict& pydict) {
  ExtraFilesMap files;
  for (const auto& entry : pydict) {
    files.insert_or_assign(py::cast<std::string>(entry.first), std::string());
  }
  return files;
}
689

690
// Writes each (name, contents) entry of `m` into `pydict` as a str -> bytes
// item, overwriting existing keys. py::dict is a handle type, so the Python
// dict is mutated even though it is passed by const reference.
void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
  for (const auto& [name, contents] : m) {
    pydict[py::str(name)] = py::bytes(contents);
  }
}
696

697
void pyCompilationUnitDefine(
698
    CompilationUnit& cu,
699
    const std::string& src,
700
    const ResolutionCallback* rcb,
701
    const uint32_t _frames_up) {
702
  if (rcb && *rcb) {
703
    cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr);
704
  } else {
705
    py::object py_default_rcb =
706
        py::module::import("torch._jit_internal")
707
            .attr("createResolutionCallbackFromFrame")(_frames_up);
708
    auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
709
    cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr);
710
  }
711
}
712

713
// This function will copy bytes into a shared_ptr of chars aligned
714
// at kFlatbufferDataAlignmentBytes boundary (currently 16).
715
// This is required because tensors need to be aligned at 16 bytes boundary.
716
static std::shared_ptr<char> copyStr(const std::string& bytes) {
717
  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
718
      kFlatbufferDataAlignmentBytes;
719
#ifdef _WIN32
720
  std::shared_ptr<char> bytes_copy(
721
      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
722
      _aligned_free);
723
#elif defined(__APPLE__)
724
  void* p;
725
  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
726
  TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
727
  std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
728
#else
729
  std::shared_ptr<char> bytes_copy(
730
      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
731
      free);
732
#endif
733
  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
734
  return bytes_copy;
735
}
736

737
void initJitScriptBindings(PyObject* module) {
738
  auto m = py::handle(module).cast<py::module>();
739

740
  // NOLINTNEXTLINE(bugprone-unused-raii)
741
  py::class_<c10::Capsule>(m, "Capsule");
742

743
  auto object_class =
744
      py::class_<Object>(m, "ScriptObject")
745
          .def("_type", [](Object& o) { return o.type(); })
746
          .def(
747
              "_get_method",
748
              [](Object& self, const std::string& name) -> Method {
749
                return self.get_method(name);
750
              },
751
              py::keep_alive<0, 1>())
752
          .def(
753
              "setattr",
754
              [](Object& self, const std::string& name, py::object value) {
755
                if (self.type()->hasConstant(name)) {
756
                  TORCH_CHECK(
757
                      false,
758
                      "Can't set constant '",
759
                      name,
760
                      "' which has value:",
761
                      self.type()->getConstant(name));
762
                }
763
                TypePtr type = self.type()->getAttribute(name);
764
                try {
765
                  auto ivalue = toIValue(std::move(value), type);
766
                  self.setattr(name, ivalue);
767
                } catch (std::exception& e) {
768
                  throw py::cast_error(c10::str(
769
                      "Could not cast attribute '",
770
                      name,
771
                      "' to type ",
772
                      type->repr_str(),
773
                      ": ",
774
                      e.what()));
775
                }
776
              })
777
          .def(
778
              "getattr",
779
              [](Object& self, const std::string& name) {
780
                try {
781
                  return toPyObject(self.attr(name));
782
                } catch (const ObjectAttributeError& err) {
783
                  throw AttributeError("%s", err.what());
784
                }
785
              })
786
          .def(
787
              "__getattr__",
788
              [](Object& self, const std::string& name) -> py::object {
789
                try {
790
                  if (name == "__qualname__") {
791
                    return py::cast(self.type()->name()->name());
792
                  }
793
                  if (auto method = self.find_method(name)) {
794
                    return py::cast(*method);
795
                  }
796
                  if (self.has_property(name)) {
797
                    auto prop = self.get_property(name);
798
                    // wrap the Method into callable PyObject
799
                    auto getter_func = py::cast(prop.getter_func);
800
                    return getter_func();
801
                  }
802
                  return toPyObject(self.attr(name));
803
                } catch (const ObjectAttributeError& err) {
804
                  throw AttributeError("%s", err.what());
805
                }
806
              })
807
          .def(
808
              "__setattr__",
809
              [](Object& self, const std::string& name, py::object value) {
810
                try {
811
                  if (self.has_property(name)) {
812
                    auto prop = self.get_property(name);
813
                    if (!prop.setter_func.has_value()) {
814
                      TORCH_CHECK(false, "can't set attribute");
815
                    }
816
                    // wrap the Method into callable PyObject
817
                    auto setter_func = py::cast(prop.setter_func);
818
                    setter_func(value);
819
                    return;
820
                  }
821

822
                  if (self.type()->hasConstant(name)) {
823
                    TORCH_CHECK(
824
                        false,
825
                        "Can't set constant '",
826
                        name,
827
                        "' which has value:",
828
                        self.type()->getConstant(name));
829
                  }
830
                  TypePtr type = self.type()->getAttribute(name);
831
                  auto ivalue = toIValue(std::move(value), type);
832
                  self.setattr(name, ivalue);
833
                } catch (const ObjectAttributeError& err) {
834
                  throw AttributeError("%s", err.what());
835
                }
836
              })
837
          .def(
838
              "hasattr",
839
              [](Object& self, const std::string& name) {
840
                return self.hasattr(name);
841
              })
842
          .def(
843
              "_has_method",
844
              [](Object& self, const std::string& name) {
845
                return bool(self.find_method(name));
846
              })
847
          .def(
848
              "_method_names",
849
              [](Object& self) {
850
                return fmap(self.get_methods(), [](const Method& method) {
851
                  return method.name();
852
                });
853
              })
854
          .def(
855
              "_properties", [](Object& self) { return self.get_properties(); })
856
          .def("__copy__", &Object::copy)
857
          .def(
858
              "__hash__",
859
              [](const Object& self) {
860
                // Similar to Tensor's `__hash__`, which is `id()`.
861
                return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
862
              })
863
          .def(py::pickle(
864
              [](const Object& self)
865
                  -> std::tuple<py::object, std::string> { // __getstate__
866
                if (auto getstate_method = self.find_method("__getstate__")) {
867
                  auto object_state = toPyObject((*getstate_method)(Stack{}));
868
                  TORCH_INTERNAL_ASSERT(self.type()->name());
869
                  return std::make_tuple(
870
                      object_state, self.type()->name()->qualifiedName());
871
                }
872
                std::stringstream err;
873
                err << "Tried to serialize object ";
874
                if (auto qualname = self.type()->name()) {
875
                  err << qualname->qualifiedName() << " ";
876
                }
877
                err << "which does not have a __getstate__ method defined!";
878
                throw std::runtime_error(err.str());
879
              },
880
              [](const std::tuple<py::object, std::string>& state_tup)
881
                  -> Object {
882
                auto [state, qualname] = state_tup;
883
                auto class_type = getCustomClass(qualname);
884
                TORCH_CHECK(
885
                    class_type,
886
                    "Tried to deserialize class ",
887
                    qualname,
888
                    " which is not known to the runtime. "
889
                    "If this is a custom C++ class, make "
890
                    "sure the appropriate code is linked.");
891

892
                auto self = Object(c10::ivalue::Object::create(
893
                    c10::StrongTypePtr(
894
                        std::shared_ptr<torch::jit::CompilationUnit>(),
895
                        class_type),
896
                    1));
897
                if (auto setstate_method = self.find_method("__setstate__")) {
898
                  auto setstate_schema =
899
                      setstate_method->function().getSchema();
900
                  TORCH_INTERNAL_ASSERT(
901
                      setstate_schema.arguments().size() == 2,
902
                      "__setstate__ method for class ",
903
                      class_type->repr_str(),
904
                      " must have exactly 2 arguments!");
905
                  auto state_type = setstate_schema.arguments().at(1).type();
906
                  (*setstate_method)(Stack{toIValue(state, state_type)});
907
                  return self;
908
                }
909
                std::stringstream err;
910
                err << "Tried to deserialize object ";
911
                if (auto qualname = class_type->name()) {
912
                  err << qualname->qualifiedName() << " ";
913
                }
914
                err << "which does not have a __setstate__ method defined!";
915
                throw std::runtime_error(err.str());
916
              }));
917

918
  py::class_<Object::Property>(m, "ScriptObjectProperty")
919
      .def_property_readonly(
920
          "name", [](const Object::Property& self) { return self.name; })
921
      .def_property_readonly(
922
          "getter",
923
          [](const Object::Property& self) { return self.getter_func; })
924
      .def_property_readonly("setter", [](const Object::Property& self) {
925
        return self.setter_func;
926
      });
927

928
  // Special case __str__ and __repr__ to make sure we can print Objects/Modules
929
  // regardless of if the user defined __str__/__repr__
930
  using MagicMethodImplType = std::function<py::object(
931
      const Object& self, py::args args, py::kwargs kwargs)>;
932

933
  std::unordered_map<std::string, MagicMethodImplType> special_magic_methods;
934
  special_magic_methods.emplace(
935
      "__str__",
936
      [](const Object& self, py::args args, py::kwargs kwargs) -> py::object {
937
        auto method = self.find_method("__str__");
938
        if (!method) {
939
          return py::str("ScriptObject <" + self.type()->str() + ">");
940
        }
941
        return invokeScriptMethodFromPython(
942
            *method,
943
            // NOLINTNEXTLINE(performance-move-const-arg)
944
            std::move(args),
945
            // NOLINTNEXTLINE(performance-move-const-arg)
946
            std::move(kwargs));
947
      });
948

949
  special_magic_methods.emplace(
950
      "__repr__",
951
      [](const Object& self, py::args args, py::kwargs kwargs) -> py::object {
952
        auto method = self.find_method("__repr__");
953
        if (!method) {
954
          std::stringstream ss;
955
          ss << std::hex << static_cast<const void*>(&self);
956
          return py::str("<torch.ScriptObject object at " + ss.str() + ">");
957
        }
958
        return invokeScriptMethodFromPython(
959
            *method,
960
            // NOLINTNEXTLINE(performance-move-const-arg)
961
            std::move(args),
962
            // NOLINTNEXTLINE(performance-move-const-arg)
963
            std::move(kwargs));
964
      });
965

966
  for (const char* mm_name : magic_method_names) {
967
    if (special_magic_methods.count(mm_name)) {
968
      object_class.def(mm_name, special_magic_methods[mm_name]);
969
    } else {
970
      object_class.def(
971
          mm_name,
972
          [mm_name](const Object& self, py::args args, py::kwargs kwargs) {
973
            auto method = self.find_method(mm_name);
974
            if (!method) {
975
              throw c10::NotImplementedError(
976
                  "'%s' is not implemented for %s",
977
                  mm_name,
978
                  self.type()->str().c_str());
979
            }
980
            return invokeScriptMethodFromPython(
981
                *method,
982
                // NOLINTNEXTLINE(performance-move-const-arg)
983
                std::move(args),
984
                // NOLINTNEXTLINE(performance-move-const-arg)
985
                std::move(kwargs));
986
          });
987
    }
988
  }
989

990
  // NOLINTNEXTLINE(bugprone-unused-raii)
991
  py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
992

993
  py::class_<UpgraderEntry>(m, "_UpgraderEntry")
994
      .def(py::init<int, std::string, std::string>())
995
      .def_property_readonly(
996
          "bumped_at_version",
997
          [](const UpgraderEntry& self) { return self.bumped_at_version; })
998
      .def_property_readonly(
999
          "upgrader_name",
1000
          [](const UpgraderEntry& self) { return self.upgrader_name; })
1001
      .def_property_readonly("old_schema", [](const UpgraderEntry& self) {
1002
        return self.old_schema;
1003
      });
1004

1005
  py::class_<UpgraderRange>(m, "_UpgraderRange")
1006
      .def(py::init<int, int>())
1007
      .def_property_readonly(
1008
          "min_version",
1009
          [](const UpgraderRange& self) { return self.min_version; })
1010
      .def_property_readonly("max_version", [](const UpgraderRange& self) {
1011
        return self.max_version;
1012
      });
1013

1014
  object_class.def(
1015
      "__deepcopy__", [](const Object& self, const py::dict& memo) {
1016
        return Object(
1017
            pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1018
      });
1019

1020
  // Used by torch.package to save ScriptModule objects in unified format.
1021
  py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
1022
      .def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
1023
      .def("serialize", &ScriptModuleSerializer::serialize_unified_format)
1024
      .def(
1025
          "write_files",
1026
          &ScriptModuleSerializer::writeFiles,
1027
          py::arg("code_dir") = ".data/ts_code/code/")
1028
      .def(
1029
          "storage_context",
1030
          &ScriptModuleSerializer::storage_context,
1031
          pybind11::return_value_policy::reference_internal);
1032

1033
  // Used by torch.package to coordinate sharing of storages between eager
1034
  // and ScriptModules.
1035
  py::class_<
1036
      SerializationStorageContext,
1037
      std::shared_ptr<SerializationStorageContext>>(
1038
      m, "SerializationStorageContext")
1039
      .def("has_storage", &SerializationStorageContext::hasStorage)
1040
      .def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
1041

1042
  // torch.jit.ScriptModule is a subclass of this C++ object.
1043
  // Methods here are prefixed with _ since they should not be
1044
  // public.
1045
  py::class_<Module, Object>(m, "ScriptModule")
1046
      .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
1047
      .def(
1048
          "save",
1049
          [](Module& m,
1050
             const std::string& filename,
1051
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1052
            m.save(filename, _extra_files);
1053
          },
1054
          py::arg("filename"),
1055
          py::arg("_extra_files") = ExtraFilesMap())
1056
      .def(
1057
          "save_to_buffer",
1058
          [](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1059
            std::ostringstream buf;
1060
            m.save(buf, _extra_files);
1061
            return py::bytes(buf.str());
1062
          },
1063
          py::arg("_extra_files") = ExtraFilesMap())
1064
      .def(
1065
          "_save_for_mobile",
1066
          [](Module& m,
1067
             const std::string& filename,
1068
             const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1069
             bool _save_mobile_debug_info = false,
1070
             bool _use_flatbuffer = false) {
1071
            m._save_for_mobile(
1072
                filename,
1073
                _extra_files,
1074
                _save_mobile_debug_info,
1075
                _use_flatbuffer);
1076
          },
1077
          py::arg("filename"),
1078
          py::arg("_extra_files") = ExtraFilesMap(),
1079
          py::arg("_save_mobile_debug_info") = false,
1080
          py::arg("_use_flatbuffer") = false)
1081
      .def(
1082
          "_save_to_buffer_for_mobile",
1083
          [](Module& m,
1084
             const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1085
             bool _save_mobile_debug_info = false,
1086
             bool _use_flatbuffer = false) {
1087
            std::ostringstream buf;
1088
            m._save_for_mobile(
1089
                buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
1090
            return py::bytes(buf.str());
1091
          },
1092
          py::arg("_extra_files") = ExtraFilesMap(),
1093
          py::arg("_save_mobile_debug_info") = false,
1094
          py::arg("_use_flatbuffer") = false)
1095
      .def("_set_optimized", &Module::set_optimized)
1096
      .def(
1097
          "dump",
1098
          &Module::dump,
1099
          py::arg("code") = true,
1100
          py::arg("attrs") = true,
1101
          py::arg("params") = true)
1102
      .def(
1103
          "dump_to_str",
1104
          &Module::dump_to_str,
1105
          py::arg("code") = true,
1106
          py::arg("attrs") = true,
1107
          py::arg("params") = true)
1108
      .def(
1109
          "_replicate_for_data_parallel",
1110
          [](Module& module) {
1111
            const ModulePtr& obj = module._ivalue();
1112
            auto copy = c10::ivalue::Object::create(
1113
                c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
1114
                obj->slots().size());
1115
            for (size_t i = 0; i < obj->slots().size(); ++i) {
1116
              copy->setSlot(i, obj->getSlot(i));
1117
            }
1118
            return Module(std::move(copy));
1119
          })
1120
      .def(
1121
          "get_debug_state",
1122
          [](Module& self) {
1123
            if (auto m = self.find_method("forward")) {
1124
              return m->get_executor().getDebugState();
1125
            }
1126
            throw std::runtime_error(
1127
                "Attempted to call get_debug_state on a Module without a compiled forward()");
1128
          })
1129
      .def(
1130
          "_define",
1131
          [](Module& m,
1132
             std::shared_ptr<ConcreteModuleType> concreteType,
1133
             const std::string& script,
1134
             const ResolutionCallback& rcb) {
1135
            const auto self = ModuleSelf(std::move(concreteType));
1136
            m._ivalue()->compilation_unit()->define(
1137
                *m.type()->name(), script, pythonResolver(rcb), &self);
1138
            didFinishEmitModule(m);
1139
          })
1140
      .def(
1141
          "_register_attribute",
1142
          [](Module& m,
1143
             const std::string& name,
1144
             const TypePtr& type,
1145
             py::handle value) {
1146
            m.register_attribute(name, type, toIValue(value, type));
1147
          })
1148
      .def(
1149
          "_create_method_from_trace",
1150
          [](Module& self,
1151
             const std::string& name,
1152
             const py::function& func,
1153
             const py::tuple& input_tuple,
1154
             const py::function& var_name_lookup_fn,
1155
             bool strict,
1156
             bool force_outplace,
1157
             const std::vector<std::string>& argument_names,
1158
             bool store_inputs) {
1159
            // prereq: Module's buffers and parameters are unique
1160
            // this was ensured in python before calling this function
1161
            auto typed_inputs = toTraceableStack(input_tuple);
1162

1163
            std::shared_ptr<Graph> graph =
1164
                std::get<0>(tracer::createGraphByTracing(
1165
                    func,
1166
                    typed_inputs,
1167
                    var_name_lookup_fn,
1168
                    strict,
1169
                    force_outplace,
1170
                    &self,
1171
                    argument_names));
1172
            const auto method_name = QualifiedName(*self.type()->name(), name);
1173
            auto fn = self._ivalue()->compilation_unit()->create_function(
1174
                method_name, graph);
1175
            self.type()->addMethod(fn);
1176
            if (store_inputs) {
1177
              self.store_traced_inputs(name, typed_inputs);
1178
            }
1179
            didFinishEmitModule(self);
1180
          },
1181
          py::arg("name"),
1182
          py::arg("func"),
1183
          py::arg("input_tuple"),
1184
          py::arg("var_name_lookup_fn"),
1185
          py::arg("strict"),
1186
          py::arg("force_outplace"),
1187
          py::arg("argument_names") = std::vector<std::string>(),
1188
          py::arg("store_inputs"))
1189
      .def(
1190
          "_create_method_from_trace_with_dict",
1191
          [](Module& self,
1192
             const std::string& name,
1193
             const py::function& func,
1194
             const py::dict& input_dict,
1195
             const py::function& var_name_lookup_fn,
1196
             bool strict,
1197
             bool force_outplace,
1198
             const std::vector<std::string>& argument_names,
1199
             bool store_inputs) {
1200
            // prereq: Module's buffers and parameters are unique
1201
            // this was ensured in python before calling this function
1202
            auto typed_inputs = toTraceableStack(input_dict);
1203

1204
            std::shared_ptr<Graph> graph =
1205
                std::get<0>(tracer::createGraphByTracingWithDict(
1206
                    func,
1207
                    input_dict,
1208
                    typed_inputs,
1209
                    var_name_lookup_fn,
1210
                    strict,
1211
                    force_outplace,
1212
                    &self,
1213
                    argument_names));
1214
            const auto method_name = QualifiedName(*self.type()->name(), name);
1215
            auto fn = self._ivalue()->compilation_unit()->create_function(
1216
                method_name, graph);
1217
            if (store_inputs) {
1218
              self.store_traced_inputs(name, typed_inputs);
1219
            }
1220
            self.type()->addMethod(fn);
1221
            didFinishEmitModule(self);
1222
          },
1223
          py::arg("name"),
1224
          py::arg("func"),
1225
          py::arg("input_dict"),
1226
          py::arg("var_name_lookup_fn"),
1227
          py::arg("strict"),
1228
          py::arg("force_outplace"),
1229
          py::arg("argument_names") = std::vector<std::string>(),
1230
          py::arg("store_inputs"))
1231
      .def(
1232
          "_get_forward_hooks",
1233
          [](const Module& m) {
1234
            std::vector<StrongFunctionPtr> funcs;
1235
            for (auto& hook : m.type()->getForwardHooks()) {
1236
              funcs.emplace_back(m.type()->compilation_unit(), hook);
1237
            }
1238
            return funcs;
1239
          })
1240
      .def(
1241
          "_get_forward_pre_hooks",
1242
          [](const Module& m) {
1243
            std::vector<StrongFunctionPtr> funcs;
1244
            for (auto& pre_hook : m.type()->getForwardPreHooks()) {
1245
              funcs.emplace_back(m.type()->compilation_unit(), pre_hook);
1246
            }
1247
            return funcs;
1248
          })
1249
      .def(
1250
          "_retrieve_traced_inputs",
1251
          [](const Module& m) {
1252
            return ScriptDict(m.retrieve_traced_inputs());
1253
          })
1254
      .def_property_readonly(
1255
          "code",
1256
          [](Module& self) {
1257
            std::vector<at::IValue> constants;
1258
            PrintDepsTable deps;
1259
            PythonPrint pp(constants, deps);
1260
            pp.printNamedType(self.type());
1261
            return pp.str();
1262
          })
1263
      .def_property_readonly(
1264
          "code_with_constants",
1265
          [](Module& self) {
1266
            std::vector<at::IValue> constants;
1267
            PrintDepsTable deps;
1268
            PythonPrint pp(constants, deps);
1269
            pp.printNamedType(self.type());
1270
            std::map<std::string, at::IValue> consts;
1271
            int i = 0;
1272
            for (auto const& constant : constants) {
1273
              consts["c" + std::to_string(i)] = constant;
1274
              i += 1;
1275
            }
1276
            return std::make_tuple(pp.str(), consts);
1277
          })
1278
      .def("apply", &Module::apply)
1279
      .def("__copy__", &Module::copy)
1280
      .def(
1281
          "__hash__",
1282
          [](const Module& self) {
1283
            // Similar to Tensor's `__hash__`, which is `id()`.
1284
            return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
1285
          })
1286
      .def(
1287
          "__eq__",
1288
          [](const Module& self, const py::object& other) {
1289
            // TODO: call UDF if it exists
1290
            if (!py::isinstance<Module>(other)) {
1291
              return false;
1292
            }
1293
            return self._ivalue().get() ==
1294
                py::cast<Module>(other)._ivalue().get();
1295
          })
1296
      .def(
1297
          "__deepcopy__",
1298
          [](const Module& self, const py::dict& memo) {
1299
            return Module(
1300
                pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1301
          })
1302
      .def("children", &Module::children)
1303
      .def_property_readonly("qualified_name", [](const Module& self) {
1304
        return self.type()->name()->qualifiedName();
1305
      });
1306

1307
  py::class_<mobile::Module>(m, "LiteScriptModule")
1308
      .def(py::init<
1309
           c10::intrusive_ptr<c10::ivalue::Object>,
1310
           std::shared_ptr<mobile::CompilationUnit>>())
1311
      .def(
1312
          "find_method",
1313
          [](mobile::Module& m, const std::string& method_name) {
1314
            auto method = m.find_method(method_name);
1315
            return method != c10::nullopt;
1316
          },
1317
          py::arg("method_name"))
1318
      .def(
1319
          "run_method",
1320
          [](mobile::Module& m,
1321
             const std::string& method_name,
1322
             const py::tuple& input_tuple) {
1323
            Stack stack;
1324
            for (auto& input : input_tuple) {
1325
              stack.push_back(toTypeInferredIValue(input));
1326
            }
1327
            return m.get_method(method_name)(stack);
1328
          },
1329
          py::arg("method_name"),
1330
          py::arg("input_tuple"))
1331
      .def(
1332
          "forward",
1333
          [](mobile::Module& m, const py::tuple& input_tuple) {
1334
            Stack stack;
1335
            for (auto& input : input_tuple) {
1336
              stack.push_back(toTypeInferredIValue(input));
1337
            }
1338
            return m.get_method("forward")(stack);
1339
          },
1340
          py::arg("input_tuple"));
1341

1342
  slot_dict_impl<detail::ParameterPolicy>::bind(m, "ParameterDict");
1343
  slot_dict_impl<detail::BufferPolicy>::bind(m, "BufferDict");
1344
  slot_dict_impl<detail::ModulePolicy>::bind(m, "ModuleDict");
1345

1346
  py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
1347
      .def(py::init<SourceRange>())
1348
      .def("what", &ErrorReport::what)
1349
      .def_static("call_stack", ErrorReport::current_call_stack);
1350

1351
  py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
1352
      m, "CompilationUnit")
1353
      .def(
1354
          py::init([](const std::string& lang, const uint32_t _frames_up) {
1355
            auto cu = std::make_shared<CompilationUnit>();
1356
            if (!lang.empty()) {
1357
              pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
1358
            }
1359
            return cu;
1360
          }),
1361
          py::arg("lang") = "",
1362
          py::arg("_frames_up") = 0)
1363

1364
      .def(
1365
          "find_function",
1366
          [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1367
            auto fn = self->find_function(QualifiedName(name));
1368
            if (fn) {
1369
              return c10::optional<StrongFunctionPtr>(
1370
                  StrongFunctionPtr(std::move(self), fn));
1371
            } else {
1372
              return c10::optional<StrongFunctionPtr>(c10::nullopt);
1373
            }
1374
          })
1375
      .def(
1376
          "__getattr__",
1377
          [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1378
            auto fn = self->find_function(QualifiedName(name));
1379
            if (fn) {
1380
              return StrongFunctionPtr(std::move(self), fn);
1381
            } else {
1382
              throw AttributeError(
1383
                  "'CompilationUnit' has no attribute '%s'", name.c_str());
1384
            }
1385
          })
1386
      .def(
1387
          "get_functions",
1388
          [](const std::shared_ptr<CompilationUnit>& self) {
1389
            auto raw_functions = self->get_functions();
1390
            std::vector<StrongFunctionPtr> functions;
1391
            functions.reserve(raw_functions.size());
1392
            for (auto fn : raw_functions) {
1393
              if (fn) {
1394
                functions.emplace_back(self, fn);
1395
              }
1396
            }
1397
            return functions;
1398
          })
1399
      .def("set_optimized", &CompilationUnit::set_optimized)
1400
      .def(
1401
          "define",
1402
          pyCompilationUnitDefine,
1403
          py::arg("src"),
1404
          py::arg("rcb") = nullptr,
1405
          py::arg("_frames_up") = 0)
1406
      .def(
1407
          "create_function",
1408
          [](std::shared_ptr<CompilationUnit>& self,
1409
             const std::string& qualified_name,
1410
             std::shared_ptr<Graph> graph,
1411
             bool should_mangle) {
1412
            Function* fn = self->create_function(
1413
                qualified_name, std::move(graph), should_mangle);
1414
            return StrongFunctionPtr(std::move(self), fn);
1415
          },
1416
          py::arg("qualified_name"),
1417
          py::arg("graph"),
1418
          py::arg("should_mangle") = false)
1419
      .def(
1420
          "get_interface",
1421
          [](const std::shared_ptr<CompilationUnit>& self,
1422
             const std::string& name) { return self->get_interface(name); })
1423
      .def(
1424
          "get_class",
1425
          [](const std::shared_ptr<CompilationUnit>& self,
1426
             const std::string& name) { return self->get_class(name); })
1427
      .def(
1428
          "drop_all_functions",
1429
          [](const std::shared_ptr<CompilationUnit>& self) {
1430
            self->drop_all_functions();
1431
          });
1432

1433
  py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
1434
      .def(
1435
          "__call__",
1436
          [](py::args args, py::kwargs kwargs) {
1437
            HANDLE_TH_ERRORS
1438
            // see: [pybind11 varargs]
1439
            auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
1440
            Function& callee = *strongPtr.function_;
1441
            py::object result = invokeScriptFunctionFromPython(
1442
                callee,
1443
                // NOLINTNEXTLINE(performance-move-const-arg)
1444
                tuple_slice(std::move(args), 1),
1445
                // NOLINTNEXTLINE(performance-move-const-arg)
1446
                std::move(kwargs));
1447
            return result;
1448
            END_HANDLE_TH_ERRORS_PYBIND
1449
          })
1450
      .def(
1451
          "save",
1452
          [](const StrongFunctionPtr& self,
1453
             const std::string& filename,
1454
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1455
            Module module("__torch__.PlaceholderModule");
1456
            // [issue 27343]
1457
            // Modules have 'training' attributes by default, but due to
1458
            // https://github.com/pytorch/pytorch/issues/27343, functions end
1459
            // up having a training attribute when they are loaded. This adds
1460
            // a fake 'training' attribute that shouldn't be used, but prevents
1461
            // jitter on saving and loading. Once that issue is fixed this can
1462
            // be deleted.
1463
            module.register_attribute("training", BoolType::get(), true);
1464
            addFunctionToModule(module, self);
1465
            module.save(filename, _extra_files);
1466
          },
1467
          py::arg("filename"),
1468
          py::arg("_extra_files") = ExtraFilesMap())
1469
      .def(
1470
          "save_to_buffer",
1471
          [](const StrongFunctionPtr& self,
1472
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1473
            std::ostringstream buf;
1474
            Module module("__torch__.PlaceholderModule");
1475
            // see [issue 27343]
1476
            module.register_attribute("training", BoolType::get(), true);
1477
            addFunctionToModule(module, self);
1478
            module.save(buf, _extra_files);
1479
            return py::bytes(buf.str());
1480
          },
1481
          py::arg("_extra_files") = ExtraFilesMap())
1482
      .def_property_readonly(
1483
          "graph",
1484
          [](const StrongFunctionPtr& self) {
1485
            return toGraphFunction(*self.function_).graph();
1486
          })
1487
      .def_property_readonly(
1488
          "inlined_graph",
1489
          [](const StrongFunctionPtr& self) {
1490
            auto g = toGraphFunction(*self.function_).graph()->copy();
1491
            Inline(*g);
1492
            return g;
1493
          })
1494
      .def_property_readonly(
1495
          "schema",
1496
          [](const StrongFunctionPtr& self) {
1497
            return self.function_->getSchema();
1498
          })
1499
      .def_property_readonly(
1500
          "code",
1501
          [](const StrongFunctionPtr& self) {
1502
            std::vector<at::IValue> constants;
1503
            PrintDepsTable deps;
1504

1505
            PythonPrint pp(constants, deps);
1506
            pp.printFunction(*self.function_);
1507
            return pp.str();
1508
          })
1509
      .def(
1510
          "get_debug_state",
1511
          [](const StrongFunctionPtr& self) {
1512
            return toGraphFunction(*self.function_)
1513
                .get_executor()
1514
                .getDebugState();
1515
          })
1516
      .def(
1517
          "_debug_flush_compilation_cache",
1518
          [](const StrongFunctionPtr& self) {
1519
            toGraphFunction(*self.function_)
1520
                .get_executor()
1521
                .debugFlushCompilationCache();
1522
          })
1523
      .def_property_readonly(
1524
          "name",
1525
          [](const StrongFunctionPtr& self) { return self.function_->name(); })
1526
      .def(
1527
          "_set_ignore_amp",
1528
          [](StrongFunctionPtr& self, bool ignore) {
1529
            auto fn = self.function_;
1530
            TORCH_INTERNAL_ASSERT(fn->isGraphFunction());
1531
            GraphFunction& g_fn = toGraphFunction(*fn);
1532
            g_fn._set_ignore_amp(ignore);
1533
          })
1534
      .def_property_readonly(
1535
          "qualified_name",
1536
          [](const StrongFunctionPtr& self) {
1537
            return self.function_->qualname().qualifiedName();
1538
          })
1539
      .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
1540
        return self.function_->doc_string();
1541
      });
1542

1543
  // Python-facing `ScriptMethod`: a Method bound to its owning object.
  py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
      .def(
          "__call__",
          [](py::args args, py::kwargs kwargs) {
            // see: [pybind11 varargs]
            HANDLE_TH_ERRORS
            // args[0] is the ScriptMethod itself. Take the reference before
            // `args` is moved into the forwarded slice.
            Method& target = py::cast<Method&>(args[0]);
            return invokeScriptMethodFromPython(
                target,
                // NOLINTNEXTLINE(performance-move-const-arg)
                tuple_slice(std::move(args), 1),
                // NOLINTNEXTLINE(performance-move-const-arg)
                std::move(kwargs));
            END_HANDLE_TH_ERRORS_PYBIND
          })
      .def_property_readonly("graph", &Method::graph)
      .def_property_readonly(
          "inlined_graph",
          [](const Method& self) {
            // Inline into a copy; the method's own graph is not modified.
            std::shared_ptr<Graph> inlined =
                toGraphFunction(self.function()).graph()->copy();
            Inline(*inlined);
            return inlined;
          })
      .def_property_readonly(
          "schema",
          [](Method& method) { return method.function().getSchema(); })
      .def_property_readonly("name", &Method::name)
      .def_property_readonly(
          "code",
          [](Method& self) {
            std::vector<at::IValue> constant_table;
            PrintDepsTable dep_table;
            PythonPrint printer(constant_table, dep_table);
            printer.printMethod(self.function());
            return printer.str();
          })
      .def(
          "_debug_flush_compilation_cache",
          [](Method& self) {
            self.get_executor().debugFlushCompilationCache();
          })
      .def_property_readonly(
          "code_with_constants",
          [](Method& self) {
            // Returns (source, {"c0": const0, "c1": const1, ...}), with the
            // constants in the order PythonPrint collected them.
            std::vector<at::IValue> constant_table;
            PrintDepsTable dep_table;
            PythonPrint printer(constant_table, dep_table);
            printer.printMethod(self.function());
            std::map<std::string, at::IValue> named_constants;
            for (size_t idx = 0; idx < constant_table.size(); ++idx) {
              named_constants["c" + std::to_string(idx)] = constant_table[idx];
            }
            return std::make_tuple(printer.str(), named_constants);
          })
      .def_property_readonly("owner", &Method::owner)
      .def_property_readonly(
          "raw_owner",
          [](const Method& self) { return Object(self.raw_owner()); });
1603
  // Upgrader-graph and package-version helpers, registered as one chain.
  m.def("_generate_upgraders_graph", &generate_upgraders_graph)
      .def(
          "_calculate_package_version_based_on_upgraders",
          &calculate_package_version_based_on_upgraders)
      .def("_get_version_calculator_flag", &get_version_calculator_flag);
1608
  // Wraps `graph` in a standalone GraphFunction (no creator callback),
  // lowers it to a mobile function, and returns that function's code table.
  m.def(
      "_compile_graph_to_code_table",
      [](const std::string& name, const std::shared_ptr<Graph>& graph) {
        GraphFunction graph_fn(name, graph, nullptr);
        CompilationOptions opts;
        const auto mobile_fn =
            convertJitFunctionToMobileFunction(graph_fn, opts);
        return convertMobileFunctionToCodeTable(*mobile_fn, opts);
      });
1616
  // Compiles a parsed free-function definition under `qualname`.
  m.def(
      "_jit_script_compile",
      [](const std::string& qualname,
         const Def& def,
         const ResolutionCallback& rcb,
         const FunctionDefaults& defaults) {
        C10_LOG_API_USAGE_ONCE("torch.script.compile");
        const c10::QualifiedName name(qualname);
        // The last component of `qualname` must match the def's own name.
        TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
        return script_compile_function(name, def, defaults, rcb);
      });
1627
  // Compiles one overload: the signature comes from `overload_decl`, the body
  // from `implementation_def`.
  m.def(
      "_jit_script_compile_overload",
      [](const std::string& qualname,
         const Decl& overload_decl,
         const Def& implementation_def,
         const ResolutionCallback& rcb,
         const FunctionDefaults& implementation_defaults,
         const py::object& signature) {
        return script_compile_overloaded_function(
            c10::QualifiedName(qualname),
            overload_decl,
            implementation_def,
            rcb,
            implementation_defaults,
            signature);
      });
1644
  // Validates `overload_decl` against the implementation's declaration, then
  // returns the implementation re-declared with it and renamed to `new_name`.
  m.def(
      "_replace_overloaded_method_decl",
      [](const Decl& overload_decl,
         const Def& implementation_def,
         const std::string& new_name) {
        checkOverloadDecl(overload_decl, implementation_def.decl());
        auto redeclared = implementation_def.withDecl(overload_decl);
        return redeclared.withName(new_name);
      });
1652
  m.def(
1653
      "_create_function_from_trace",
1654
      [](const std::string& qualname,
1655
         const py::function& func,
1656
         const py::tuple& input_tuple,
1657
         const py::function& var_name_lookup_fn,
1658
         bool strict,
1659
         bool force_outplace,
1660
         const std::vector<std::string>& argument_names) {
1661
        auto typed_inputs = toTraceableStack(input_tuple);
1662
        std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
1663
            func,
1664
            typed_inputs,
1665
            var_name_lookup_fn,
1666
            strict,
1667
            force_outplace,
1668
            /*self=*/nullptr,
1669
            argument_names));
1670

1671
        auto cu = get_python_cu();
1672
        auto name = c10::QualifiedName(qualname);
1673
        auto result = cu->create_function(
1674
            std::move(name), std::move(graph), /*shouldMangle=*/true);
1675
        StrongFunctionPtr ret(std::move(cu), result);
1676
        didFinishEmitFunction(ret);
1677
        return ret;
1678
      },
1679
      py::arg("name"),
1680
      py::arg("func"),
1681
      py::arg("input_tuple"),
1682
      py::arg("var_name_lookup_fn"),
1683
      py::arg("strict"),
1684
      py::arg("force_outplace"),
1685
      py::arg("argument_names") = std::vector<std::string>());
1686

1687
  m.def(
1688
      "_create_function_from_trace_with_dict",
1689
      [](const std::string& qualname,
1690
         const py::function& func,
1691
         const py::dict& input_dict,
1692
         const py::function& var_name_lookup_fn,
1693
         bool strict,
1694
         bool force_outplace,
1695
         const std::vector<std::string>& argument_names) {
1696
        auto typed_inputs = toTraceableStack(input_dict);
1697
        std::shared_ptr<Graph> graph =
1698
            std::get<0>(tracer::createGraphByTracingWithDict(
1699
                func,
1700
                input_dict,
1701
                typed_inputs,
1702
                var_name_lookup_fn,
1703
                strict,
1704
                force_outplace,
1705
                /*self=*/nullptr,
1706
                argument_names));
1707

1708
        auto cu = get_python_cu();
1709
        auto name = c10::QualifiedName(qualname);
1710
        auto result = cu->create_function(
1711
            std::move(name), std::move(graph), /*shouldMangle=*/true);
1712
        StrongFunctionPtr ret(std::move(cu), result);
1713
        didFinishEmitFunction(ret);
1714
        return ret;
1715
      },
1716
      py::arg("name"),
1717
      py::arg("func"),
1718
      py::arg("input_dict"),
1719
      py::arg("var_name_lookup_fn"),
1720
      py::arg("strict"),
1721
      py::arg("force_outplace"),
1722
      py::arg("argument_names") = std::vector<std::string>());
1723

1724
  // Compiles a Python class (already parsed into `classDef`) into a
  // TorchScript ClassType registered in the shared Python compilation unit.
  //  - qualifiedName: requested qualified name; mangled if a type with that
  //    name is already registered in the Python CU.
  //  - defaults: default argument values per method, keyed by method name.
  //    Methods absent from the map are treated as having no defaults.
  //  - rcb: Python resolution callback used for every method and property.
  // Throws ErrorReport if the class has a superclass or if its body holds
  // anything other than method definitions. Returns the new ClassType.
  m.def(
      "_jit_script_class_compile",
      [](const std::string& qualifiedName,
         const ClassDef& classDef,
         const ClassMethodDefaults& defaults,
         const ResolutionCallback& rcb) {
        C10_LOG_API_USAGE_ONCE("torch.script.class");
        if (classDef.superclass().present()) {
          throw ErrorReport(classDef.range())
              << "Torchscript does not support class inheritance.";
        }
        auto cu = get_python_cu();
        auto classname = c10::QualifiedName(qualifiedName);
        if (cu->get_type(classname) != nullptr) {
          classname = cu->mangle(classname);
        }

        auto classType = ClassType::create(
            classname,
            cu,
            /* is_module = */ false,
            /* doc_string = */ "",
            getUnresolvedClassAttributes(classDef));
        cu->register_type(classType);
        std::vector<ResolverPtr> methodRcbs;
        std::vector<ResolverPtr> propRcbs;
        std::vector<Def> methodDefs;
        std::vector<Property> props;

        for (const auto& def : classDef.body()) {
          if (def.kind() != TK_DEF) {
            throw ErrorReport(def.range())
                << "Currently class bodies can only contain method "
                   "definitions. File an issue on GitHub if you want "
                   "something else!";
          }
          methodDefs.emplace_back(def);
          methodRcbs.push_back(
              pythonResolver(rcb, classDef.name().name(), classType));
        }

        // Gather definitions for property getters and setters as well as
        // corresponding resolution callbacks.
        if (classDef.properties().present()) {
          for (const auto& prop : classDef.properties().get()) {
            props.emplace_back(prop);
            propRcbs.push_back(
                pythonResolver(rcb, classDef.name().name(), classType));
          }
        }

        const auto self = SimpleSelf(classType);
        cu->define(classname, props, propRcbs, methodDefs, methodRcbs, &self);

        // Stitch in default arguments for methods. Properties don't need to be
        // considered since there is no way to invoke setters without passing in
        // a value.
        // A range-for guarantees that skipping a method (`continue`) still
        // advances to the next one; a manually advanced iterator would spin
        // forever on the first method missing from `defaults`.
        for (const Def& method_def : methodDefs) {
          const std::string def_name = method_def.name().name();
          auto default_it = defaults.find(def_name);
          if (default_it == defaults.end()) {
            continue;
          }

          const auto method_name = QualifiedName(classname, def_name);
          auto& method = cu->get_function(method_name);
          method.setSchema(getSchemaWithNameAndDefaults(
              method_def.range(),
              method.getSchema(),
              at::nullopt,
              default_it->second));
        }
        return classType;
      });
1802
  // Defines an interface type in the shared Python compilation unit.
  // If a type named `qualifiedName` is already registered there, a mangled
  // name is used instead. Returns the qualified name actually used.
  m.def(
      "_jit_script_interface_compile",
      [](const std::string& qualifiedName,
         const ClassDef& classDef,
         const ResolutionCallback& rcb,
         bool is_module) {
        auto cu = get_python_cu();
        auto className = c10::QualifiedName(qualifiedName);
        if (cu->get_type(className) != nullptr) {
          className = cu->mangle(className);
        }

        // Define into the same CU that was used for the name check and
        // mangling above, rather than fetching the Python CU a second time.
        cu->define_interface(
            className, classDef, pythonResolver(rcb), is_module);
        return className.qualifiedName();
      });
1818

1819
  // Exposes ErrorReport::CallStack to Python, constructible from a
  // (name, SourceRange) pair.
  py::class_<ErrorReport::CallStack>(m, "CallStack", py::dynamic_attr())
      .def(py::init<const std::string&, const SourceRange&>());
1822

1823
  // Parses `source_text` as a single method definition.
  m.def("_parse_source_def", [](const std::string& source_text) {
    Parser parser(std::make_shared<Source>(source_text));
    return Def(parser.parseFunction(/*is_method=*/true));
  });
  // Parses a type comment into a Decl.
  m.def("parse_type_comment", [](const std::string& comment_text) {
    Parser parser(std::make_shared<Source>(comment_text));
    return Decl(parser.parseTypeComment());
  });
1831

1832
  // Upgrader map / operator version map accessors and test-only mutators,
  // registered as one chain.
  m.def("_get_upgraders_map_size", &get_upgraders_map_size)
      .def("_dump_upgraders_map", &dump_upgraders_map)
      .def("_test_only_populate_upgraders", &test_only_populate_upgraders)
      .def("_test_only_remove_upgraders", &test_only_remove_upgraders)
      .def("merge_type_from_type_comment", &mergeTypesFromTypeComment)
      .def("_get_max_operator_version", &getMaxOperatorVersion)
      .def("_get_operator_version_map", &get_operator_version_map)
      .def("_get_upgraders_entry_map", &get_upgraders_entry_map)
      .def("_get_upgrader_ranges", &getUpgradersRangeForOp)
      .def("_test_only_add_entry_to_op_version_map", &test_only_add_entry)
      .def(
          "_test_only_remove_entry_to_op_version_map",
          &test_only_remove_entry);
1845
  m.def(
1846
      "import_ir_module",
1847
      [](std::shared_ptr<CompilationUnit> cu,
1848
         const std::string& filename,
1849
         py::object map_location,
1850
         const py::dict& extra_files,
1851
         bool restore_shapes = false) {
1852
        c10::optional<at::Device> optional_device;
1853
        if (!map_location.is_none()) {
1854
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
1855
          optional_device =
1856
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1857
        }
1858
        ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1859
        auto ret = import_ir_module(
1860
            std::move(cu),
1861
            filename,
1862
            optional_device,
1863
            extra_files_map,
1864
            /*load_debug_files*/ true,
1865
            restore_shapes);
1866
        extra_files_to_python(extra_files_map, extra_files);
1867
        return ret;
1868
      });
1869
  m.def(
1870
      "_import_ir_module_from_package",
1871
      [](std::shared_ptr<CompilationUnit> cu,
1872
         std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
1873
         std::shared_ptr<torch::jit::DeserializationStorageContext>
1874
             storage_context,
1875
         py::object map_location,
1876
         std::string ts_id) {
1877
        c10::optional<at::Device> optional_device;
1878
        if (!map_location.is_none()) {
1879
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
1880
          optional_device =
1881
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1882
        }
1883
        return import_ir_module(
1884
            std::move(cu),
1885
            std::move(reader),
1886
            std::move(storage_context),
1887
            optional_device,
1888
            std::move(ts_id));
1889
      });
1890
  m.def(
1891
      "import_ir_module_from_buffer",
1892
      [](std::shared_ptr<CompilationUnit> cu,
1893
         const std::string& buffer,
1894
         py::object map_location,
1895
         const py::dict& extra_files,
1896
         bool restore_shapes = false) {
1897
        std::istringstream in(buffer);
1898
        c10::optional<at::Device> optional_device;
1899
        if (!map_location.is_none()) {
1900
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
1901
          optional_device =
1902
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1903
        }
1904
        ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1905
        auto ret = import_ir_module(
1906
            std::move(cu),
1907
            in,
1908
            optional_device,
1909
            extra_files_map,
1910
            /*load_debug_files*/ true,
1911
            restore_shapes);
1912
        extra_files_to_python(extra_files_map, extra_files);
1913
        return ret;
1914
      });
1915
  m.def(
1916
      "_load_for_lite_interpreter",
1917
      [](const std::string& filename, py::object map_location) {
1918
        c10::optional<at::Device> optional_device;
1919
        if (!map_location.is_none()) {
1920
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
1921
          optional_device =
1922
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1923
        }
1924
        return _load_for_mobile(filename, optional_device);
1925
      });
1926
  m.def(
1927
      "_load_for_lite_interpreter_from_buffer",
1928
      [](const std::string& buffer, py::object map_location) {
1929
        std::istringstream in(buffer);
1930
        c10::optional<at::Device> optional_device;
1931
        if (!map_location.is_none()) {
1932
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
1933
          optional_device =
1934
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1935
        }
1936
        return _load_for_mobile(in, optional_device);
1937
      });
1938
  // Rewrites a mobile model at an older bytecode version. The *_to_buffer
  // variants return the result as bytes, or empty bytes when the backport
  // reports failure.
  m.def(
      "_backport_for_mobile",
      [](const std::string& src_path,
         const std::string& dst_path,
         const int64_t to_version) {
        return _backport_for_mobile(src_path, dst_path, to_version);
      });
  m.def(
      "_backport_for_mobile_from_buffer",
      [](const std::string& src_bytes,
         const std::string& dst_path,
         const int64_t to_version) {
        std::istringstream src(src_bytes);
        return _backport_for_mobile(src, dst_path, to_version);
      });
  m.def(
      "_backport_for_mobile_to_buffer",
      [](const std::string& src_path, const int64_t to_version) {
        std::ostringstream dst;
        if (!_backport_for_mobile(src_path, dst, to_version)) {
          return py::bytes("");
        }
        return py::bytes(dst.str());
      });
  m.def(
      "_backport_for_mobile_from_buffer_to_buffer",
      [](const std::string& src_bytes, const int64_t to_version) {
        std::istringstream src(src_bytes);
        std::ostringstream dst;
        if (!_backport_for_mobile(src, dst, to_version)) {
          return py::bytes("");
        }
        return py::bytes(dst.str());
      });
1969
  // Mobile model introspection. Each query has a file-path form and a
  // bytes (`_from_buffer`) form.
  m.def("_get_model_bytecode_version", [](const std::string& path) {
    return _get_model_bytecode_version(path);
  });
  m.def(
      "_get_model_extra_files",
      [](const std::string& filename, const py::dict& py_extra_files) {
        // The loader starts from an empty map; entries already present in
        // `py_extra_files` are not passed to it. The loader's map is handed
        // to extra_files_to_python with `py_extra_files`, and that same dict
        // is returned.
        ExtraFilesMap loaded_files;
        c10::optional<at::Device> no_device;
        _load_for_mobile(filename, no_device, loaded_files);
        extra_files_to_python(loaded_files, py_extra_files);
        return py_extra_files;
      });
  m.def(
      "_get_model_bytecode_version_from_buffer",
      [](const std::string& bytes) {
        std::istringstream src(bytes);
        return _get_model_bytecode_version(src);
      });
  m.def(
      "_get_model_extra_files_from_buffer",
      [](const std::string& buffer, const py::dict& py_extra_files) {
        ExtraFilesMap loaded_files;
        c10::optional<at::Device> no_device;
        std::istringstream src(buffer);
        _load_for_mobile(src, no_device, loaded_files);
        extra_files_to_python(loaded_files, py_extra_files);
        return py_extra_files;
      });
  m.def("_get_mobile_model_contained_types", [](const std::string& path) {
    return _get_mobile_model_contained_types(path);
  });
  m.def(
      "_get_mobile_model_contained_types_from_buffer",
      [](const std::string& bytes) {
        std::istringstream src(bytes);
        return _get_mobile_model_contained_types(src);
      });
2007
  m.def("_nn_module_to_mobile", [](const Module& module) {
2008
    CompilationOptions options;
2009
    return jitModuleToMobile(module, options);
2010
  });
2011
  py::class_<OperatorInfo>(m, "OperatorInfo")
2012
      .def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
2013
  m.def("_get_model_ops_and_info", [](const std::string& filename) {
2014
    return _get_model_ops_and_info(filename);
2015
  });
2016
  m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
2017
    std::istringstream in(buffer);
2018
    return _get_model_ops_and_info(in);
2019
  });
2020
  m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
2021
    return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
2022
  });
2023
  m.def(
2024
      "_quantize_ondevice_ptq_dynamic",
2025
      [](mobile::Module& m, const std::string& method_name) {
2026
        mobile::quantization::PTQQuanizationHelper ptq_helper;
2027
        ptq_helper.quantize_dynamic(m, method_name);
2028
      });
2029

2030
  m.def("_jit_set_emit_hooks", setEmitHooks);
2031
  m.def("_jit_get_emit_hooks", getEmitHooks);
2032
  m.def("_jit_clear_class_registry", []() {
2033
    get_python_cu()->_clear_python_cu();
2034
  });
2035
  m.def(
2036
      "_debug_set_autodiff_subgraph_inlining",
2037
      debugSetAutodiffSubgraphInlining);
2038
  m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining);
2039
  m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining);
2040
  m.def("_propagate_shapes", _propagate_shapes);
2041
  m.def(
2042
      "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
2043
  m.def(
2044
      "_last_executed_optimized_graph",
2045
      []() { return lastExecutedOptimizedGraph(); },
2046
      "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
2047
  // Wraps an existing graph as a ScriptFunction owned by a fresh, private
  // CompilationUnit.
  m.def(
      "_create_function_from_graph",
      [](const std::string& qualname, std::shared_ptr<Graph> graph) {
        // TODO this should go in the global Python CU
        auto owning_cu = std::make_shared<CompilationUnit>();
        auto* created = owning_cu->create_function(
            c10::QualifiedName(qualname), std::move(graph));
        return StrongFunctionPtr(std::move(owning_cu), created);
      });
  m.def("_ivalue_tags_match", ivalue_tags_match);
  m.def("_ivalue_debug_python_object", [](py::object py_obj) {
    // Round-trips a Python object through an IValue for refcount debugging:
    //  1. toIValue wraps it; the IValue increfs via py::object.
    //  2. toPyObject borrows it back, which increfs again.
    //  3. The IValue goes out of scope on return and decrefs, so the
    //     returned object ends at its original refcount + 1.
    IValue boxed = toIValue(std::move(py_obj), PyObjectType::get());
    py::object unboxed = toPyObject(boxed);
    return unboxed;
  });
  m.def("_jit_debug_module_iterators", _jit_debug_module_iterators);
2067

2068
  py::class_<testing::FileCheck>(m, "FileCheck")
2069
      .def(py::init<>())
2070
      .def("check", &testing::FileCheck::check)
2071
      .def("check_not", &testing::FileCheck::check_not)
2072
      .def("check_same", &testing::FileCheck::check_same)
2073
      .def("check_next", &testing::FileCheck::check_next)
2074
      .def("check_count", &testing::FileCheck::check_count)
2075
      .def("check_dag", &testing::FileCheck::check_dag)
2076
      .def(
2077
          "check_source_highlighted",
2078
          &testing::FileCheck::check_source_highlighted)
2079
      .def("check_regex", &testing::FileCheck::check_regex)
2080
      .def(
2081
          "check_count",
2082
          [](testing::FileCheck& f,
2083
             const std::string& str,
2084
             size_t count,
2085
             bool exactly) { return f.check_count(str, count, exactly); },
2086
          "Check Count",
2087
          py::arg("str"),
2088
          py::arg("count"),
2089
          py::arg("exactly") = false)
2090
      .def(
2091
          "run",
2092
          [](testing::FileCheck& f, const std::string& str) {
2093
            return f.run(str);
2094
          })
2095
      .def(
2096
          "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
2097
      .def(
2098
          "run",
2099
          [](testing::FileCheck& f,
2100
             const std::string& input,
2101
             const std::string& output) { return f.run(input, output); },
2102
          "Run",
2103
          py::arg("checks_file"),
2104
          py::arg("test_file"))
2105
      .def(
2106
          "run",
2107
          [](testing::FileCheck& f, const std::string& input, const Graph& g) {
2108
            return f.run(input, g);
2109
          },
2110
          "Run",
2111
          py::arg("checks_file"),
2112
          py::arg("graph"));
2113

2114
  m.def(
2115
      "_logging_set_logger",
2116
      [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
2117
      py::return_value_policy::reference);
2118
  m.def("_set_graph_executor_optimize", [](bool optimize) {
2119
    setGraphExecutorOptimize(optimize);
2120
  });
2121

2122
  m.def(
2123
      "_get_graph_executor_optimize",
2124
      [](c10::optional<bool> new_setting = c10::nullopt) {
2125
        bool old_value = getGraphExecutorOptimize();
2126
        if (new_setting) {
2127
          setGraphExecutorOptimize(*new_setting);
2128
        }
2129
        return old_value;
2130
      },
2131
      py::arg("new_settings") = nullptr);
2132

2133
  m.def(
      "_enable_mobile_interface_call_export",
      &torch::jit::enableMobileInterfaceCallExport);

  // Instantiate a Module / plain Object of `class_type` in the shared
  // Python compilation unit.
  m.def("_create_module_with_type", [](const ClassTypePtr& class_type) {
    return Module(get_python_cu(), class_type);
  });
  m.def("_create_object_with_type", [](const ClassTypePtr& class_type) {
    return Object(get_python_cu(), class_type);
  });

  m.def("_export_opnames", [](Module& script_module) {
    return debugMakeList(torch::jit::export_opnames(script_module));
  });
2146

2147
  // Builder for ConcreteModuleType, constructed from a Python object and
  // populated attribute by attribute before `build()`.
  py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(
      m, "ConcreteModuleTypeBuilder")
      .def(py::init<py::object>())
      .def(
          "add_constant",
          [](ConcreteModuleTypeBuilder& builder,
             std::string name,
             py::object value) {
            builder.addConstant(std::move(name), std::move(value));
          })
      .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
      .def(
          "add_function_attribute",
          &ConcreteModuleTypeBuilder::addFunctionAttribute)
      .def(
          "add_builtin_function",
          &ConcreteModuleTypeBuilder::addBuiltinFunction)
      .def("add_forward_hook", &ConcreteModuleTypeBuilder::addForwardHook)
      .def(
          "add_forward_pre_hook", &ConcreteModuleTypeBuilder::addForwardPreHook)
      .def("add_module", &ConcreteModuleTypeBuilder::addModule)
      .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
      .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
      .def(
          "add_failed_attribute",
          &ConcreteModuleTypeBuilder::addFailedAttribute)
      .def(
          "add_ignored_attribute",
          &ConcreteModuleTypeBuilder::addIgnoredAttribute)
      .def(
          "add_ignored_attributes",
          [](ConcreteModuleTypeBuilder& builder,
             const std::vector<std::string>& names) {
            // Batch form of add_ignored_attribute.
            for (const std::string& attr_name : names) {
              builder.addIgnoredAttribute(attr_name);
            }
          })
      .def(
          "set_module_dict",
          [](ConcreteModuleTypeBuilder& builder) {
            builder.setIterableModuleKind(IterableModuleKind::DICT);
          })
      .def("build", &ConcreteModuleTypeBuilder::build)
      .def(
          "equals",
          [](const ConcreteModuleTypeBuilder& lhs,
             const ConcreteModuleTypeBuilder& rhs) { return lhs.equals(rhs); })
      .def(
          "set_module_list",
          [](ConcreteModuleTypeBuilder& builder) {
            builder.setIterableModuleKind(IterableModuleKind::LIST);
          })
      .def(
          "set_parameter_list",
          [](ConcreteModuleTypeBuilder& builder) {
            builder.setIterableModuleKind(IterableModuleKind::PARAMLIST);
          })
      .def("set_parameter_dict", [](ConcreteModuleTypeBuilder& builder) {
        builder.setIterableModuleKind(IterableModuleKind::PARAMDICT);
      });
2211

2212
  // Python bindings for a built (frozen) ConcreteModuleType, plus the entry
  // points that compile a module's methods, properties and hooks into the
  // compilation unit owning its class type.
  py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
      m, "ConcreteModuleType")
      .def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
      .def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
      .def_static("from_jit_type", &ConcreteModuleType::fromJitType)
      .def("get_constants", &ConcreteModuleType::getConstantsPy)
      .def("get_attributes", &ConcreteModuleType::getAttributesPy)
      .def("get_modules", &ConcreteModuleType::getModulesPy)
      .def("dump", &ConcreteModuleType::dump)
      .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
      // Two `equals` overloads: compare against another ConcreteModuleType,
      // or against a ConcreteModuleTypeBuilder.
      .def(
          "equals",
          [](const ConcreteModuleType& self, const ConcreteModuleType& other) {
            return self.equals(other);
          })
      .def(
          "equals",
          [](const ConcreteModuleType& self,
             const ConcreteModuleTypeBuilder& other) {
            return self.equals(other);
          })
      // Compiles `properties` and `methodDefs` into the compilation unit of
      // this type's ClassType, then rewrites each compiled method's schema so
      // it carries the matching entry of `defaults`.
      //
      // Parallel lists (must have equal lengths):
      //   properties <-> propertyRcbs
      //   methodDefs <-> methodRcbs <-> defaults
      .def(
          "_create_methods_and_properties",
          [](std::shared_ptr<ConcreteModuleType> concreteType,
             const std::vector<Property>& properties,
             const std::vector<ResolutionCallback>& propertyRcbs,
             const std::vector<Def>& methodDefs,
             const std::vector<ResolutionCallback>& methodRcbs,
             const std::vector<FunctionDefaults>& defaults) {
            TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
            TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
            // `defaults` is indexed in step with `methodDefs` below; without
            // this check a shorter list would be read past its end.
            TORCH_INTERNAL_ASSERT(methodDefs.size() == defaults.size());

            std::vector<ResolverPtr> methodResolvers, propertyResolvers;
            methodResolvers.reserve(methodRcbs.size());
            for (auto& callback : methodRcbs) {
              methodResolvers.push_back(pythonResolver(callback));
            }

            propertyResolvers.reserve(propertyRcbs.size());
            for (auto& callback : propertyRcbs) {
              propertyResolvers.push_back(pythonResolver(callback));
            }

            const auto& selfType =
                concreteType->getJitType()->expect<ClassType>();
            const auto& prefix = selfType->name().value();
            const auto self = ModuleSelf(std::move(concreteType));
            auto cu = selfType->compilation_unit();
            cu->define(
                prefix,
                properties,
                propertyResolvers,
                methodDefs,
                methodResolvers,
                &self);

            // Stitch the Python-side default arguments into each compiled
            // method's schema.
            for (const auto i : c10::irange(methodDefs.size())) {
              const Def& method_def = methodDefs[i];
              const auto method_name =
                  QualifiedName(prefix, method_def.name().name());
              auto& method = cu->get_function(method_name);
              method.setSchema(getSchemaWithNameAndDefaults(
                  method_def.range(),
                  method.getSchema(),
                  at::nullopt,
                  defaults[i]));
            }
          })
      // Compiles forward hooks and forward pre-hooks for this type's
      // ClassType. Each Def list is paired index-by-index with its
      // resolution callbacks.
      .def(
          "_create_hooks",
          [](std::shared_ptr<ConcreteModuleType> concreteType,
             const std::vector<Def>& hookDefs,
             const std::vector<ResolutionCallback>& hookRcbs,
             const std::vector<Def>& preHookDefs,
             const std::vector<ResolutionCallback>& preHookRcbs) {
            TORCH_INTERNAL_ASSERT(hookDefs.size() == hookRcbs.size());
            TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookRcbs.size());

            std::vector<ResolverPtr> hookResolvers, preHookResolvers;

            hookResolvers.reserve(hookRcbs.size());
            for (auto& callback : hookRcbs) {
              hookResolvers.push_back(pythonResolver(callback));
            }

            preHookResolvers.reserve(preHookRcbs.size());
            for (auto& callback : preHookRcbs) {
              preHookResolvers.push_back(pythonResolver(callback));
            }

            const auto& selfType =
                concreteType->getJitType()->expect<ClassType>();
            const auto& prefix = selfType->name().value();
            const auto self = ModuleSelf(std::move(concreteType));
            auto cu = selfType->compilation_unit();
            cu->define_hooks(
                prefix,
                hookDefs,
                hookResolvers,
                preHookDefs,
                preHookResolvers,
                &self);
          });
2318

2319
  // Resolves the type named `name` using a resolver built from the Python
  // resolution callback `rcb`.
  m.def(
      "_resolve_type",
      [](const std::string& name,
         const SourceRange& range,
         const ResolutionCallback& rcb) {
        const auto resolver = pythonResolver(rcb);
        return resolver->resolveType(name, range);
      });
2326
  // Resolves a type from the Python object `obj` using a resolver built from
  // the Python resolution callback `rcb`.
  m.def(
      "_resolve_type_from_object",
      [](const py::object& obj,
         const SourceRange& range,
         const ResolutionCallback& rcb) {
        const auto resolver = pythonResolver(rcb);
        return resolver->resolveTypeFromObject(obj, range);
      });
2333

2334
  // Exposes didFinishEmitModule to Python for an already-emitted Module.
  m.def("_run_emit_module_hook", [](const Module& emitted) {
    didFinishEmitModule(emitted);
  });
2336

2337
  // Python toggle forwarding directly to setShouldUseFormatWithStringTable.
  m.def("_set_should_use_format_with_string_table", &setShouldUseFormatWithStringTable);
2340

2341
  // NOLINTNEXTLINE(bugprone-unused-raii)
2342
  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
2343
      m, "LoggerBase");
2344
  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
2345
      .value("SUM", logging::LockingLogger::AggregationType::SUM)
2346
      .value("AVG", logging::LockingLogger::AggregationType::AVG)
2347
      .export_values();
2348
  py::class_<
2349
      logging::LockingLogger,
2350
      logging::LoggerBase,
2351
      std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
2352
      .def(py::init<>())
2353
      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
2354
      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
2355
  py::class_<
2356
      logging::NoopLogger,
2357
      logging::LoggerBase,
2358
      std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
2359
      .def(py::init<>());
2360
  // Reports whether `obj` is an instance of the bound torch::jit::Object type.
  m.def("_jit_is_script_object", [](const py::object& obj) -> bool {
    const bool is_script_object = py::isinstance<Object>(obj);
    return is_script_object;
  });
2363

2364
  // Maps the format reported by getFileFormat(path) to a short name:
  // "flatbuffer", "zipfile", or "invalid" for anything else.
  m.def("_get_file_format", [](const std::string& path) -> const char* {
    const auto format = getFileFormat(path);
    if (format == FileFormat::FlatbufferFileFormat) {
      return "flatbuffer";
    }
    if (format == FileFormat::ZipFileFormat) {
      return "zipfile";
    }
    return "invalid";
  });
2374

2375
  m.def(
2376
      "_save_parameters",
2377
      [](const std::map<std::string, at::Tensor>& map,
2378
         const std::string& filename,
2379
         bool use_flatbuffer = false) {
2380
        _save_parameters(map, filename, use_flatbuffer);
2381
      });
2382

2383
  // Loads a mobile::Module from the file at `filename`.
  m.def(
      "_load_mobile_module_from_file",
      [](const std::string& filename) {
        return torch::jit::load_mobile_module_from_file(filename);
      });
2386
  // Parses a mobile::Module from an in-memory serialized payload. The parser
  // receives the buffer produced by copyStr(bytes), not the std::string.
  m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    const auto size = bytes.size();
    auto buffer = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(buffer, size);
  });
2391
  // Loads a full JIT Module from the file at `filename`, starting from an
  // empty extra-files map.
  m.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files;
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
2395
  // Parses a full JIT Module from an in-memory payload. The parser receives
  // the buffer produced by copyStr(bytes) and an empty extra-files map.
  m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    const auto size = bytes.size();
    auto buffer = copyStr(bytes);
    ExtraFilesMap extra_files;
    return torch::jit::parse_and_initialize_jit_module(
        buffer, size, extra_files);
  });
2401
  m.def(
2402
      "_save_mobile_module",
2403
      [](const torch::jit::mobile::Module& module,
2404
         const std::string& filename,
2405
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2406
        return torch::jit::save_mobile_module(module, filename, _extra_files);
2407
      });
2408
  m.def(
2409
      "_save_jit_module",
2410
      [](const torch::jit::Module& module,
2411
         const std::string& filename,
2412
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2413
        return torch::jit::save_jit_module(module, filename, _extra_files);
2414
      });
2415
  m.def(
2416
      "_save_mobile_module_to_bytes",
2417
      [](const torch::jit::mobile::Module& module,
2418
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2419
        auto detached_buffer =
2420
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
2421
        return py::bytes(
2422
            reinterpret_cast<char*>(detached_buffer->data()),
2423
            detached_buffer->size());
2424
      });
2425
  m.def(
2426
      "_save_jit_module_to_bytes",
2427
      [](const torch::jit::Module& module,
2428
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2429
        auto detached_buffer =
2430
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
2431
        return py::bytes(
2432
            reinterpret_cast<char*>(detached_buffer->data()),
2433
            detached_buffer->size());
2434
      });
2435
  // Reads module metadata from a flatbuffer-serialized payload and returns it
  // as a dict with keys: bytecode_version, operator_version, function_names,
  // type_names, opname_to_num_args.
  //
  // `flatbuffer_content` is taken by value so a mutable char* into a private
  // copy can be handed to get_module_info_from_flatbuffer.
  m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
    // An empty string would leave the reader with only the terminating NUL
    // and make it read the root offset past the end of the buffer.
    TORCH_CHECK(
        !flatbuffer_content.empty(),
        "_get_module_info_from_flatbuffer: flatbuffer content is empty");
    py::gil_scoped_acquire acquire;
    py::dict result;
    mobile::ModuleInfo minfo =
        torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
    result["bytecode_version"] = minfo.bytecode_version;
    result["operator_version"] = minfo.operator_version;
    result["function_names"] = minfo.function_names;
    result["type_names"] = minfo.type_names;
    result["opname_to_num_args"] = minfo.opname_to_num_args;
    return result;
  });
2447

2448
  // Serializes `v` with torch::jit::pickle_save and returns the produced
  // bytes to Python as a `bytes` object.
  m.def("_pickle_save", [](IValue v) {
    const auto pickled = torch::jit::pickle_save(std::move(v));
    return py::bytes(pickled.data(), pickled.size());
  });
2452

2453
  // Register the ScriptDict and ScriptList bindings on the same module object.
  initScriptDictBindings(module);
  initScriptListBindings(module);
2455
}
2456

2457
} // namespace torch::jit
2458

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.