#include <pybind11/pytypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/schema_info.h>

#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
// #include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/autocast.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/device_type_analysis.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
#include <torch/csrc/jit/passes/quantization/finalize.h>
#include <torch/csrc/jit/passes/quantization/fusion_passes.h>
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h>

#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <c10/util/signal_handler.h>
#include <caffe2/serialize/inline_container.h>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <pybind11/operators.h>

#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>

// Standard-library headers used below (restored; the extraction dropped a few
// lines here).
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
namespace torch::jit {

using c10::FunctionSchema;
using c10::SchemaArgType;
using c10::SchemaArgument;

using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using torch::utils::SchemaInfo;

namespace {

using autograd::variable_list;
bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
  // THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  // PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

static bool opAllowsNumbersAsTensors(c10::Symbol symbol) {
  return symbol.is_prims() || symbol.is_nvprims() ||
      (symbol.is_aten() &&
       torch::should_allow_numbers_as_tensors(symbol.toUnqualString()));
}

c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
  // Errors need to be caught here because toTypeInferredIValue errors out
  // on various object types, but we want it to work with all types.
  try {
    return toTypeInferredIValue(input);
  } catch (const c10::Error& e) {
    return c10::nullopt;
  }
}
} // anonymous namespace
#if !defined(USE_ROCM)
TORCH_API void runJITCPPTests();
#endif

void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto jit = m.def_submodule("_jit");

  // This is a static object, so we must leak the Python object
  // "release()" is used here to preserve 1 refcount on the
  // object, preventing it from ever being de-allocated by CPython.
  static py::handle exc =
      py::exception<JITException>(m, "JITException").release();
  py::register_exception_translator([](std::exception_ptr p) {
    try {
      std::rethrow_exception(p);
    } catch (const JITException& e) {
      // special handling of JITException, to set its python class name and msg
      py::gil_scoped_acquire acquire;
      const auto& className = e.getPythonClassName();
      const auto& originalMsg = e.getOriginalMsg();
      JITException::setCaughtOriginalMsg(originalMsg.value_or(""));
      JITException::setCaughtPythonClassName(className.value_or(""));
      // If we still had the py::exception<JITException> object, we could
      // just call it. But we must get a handle to leak it and there is no
      // way I can find to re-create it from the handle. So setting the
      // exception manually
      PyErr_SetString(exc.ptr(), e.what());
    }
  });

  m.def(
      "_get_caught_jit_exception_class_name",
      JITException::getCaughtPythonClassName);
  m.def(
      "_get_caught_jit_exception_original_msg",
      JITException::getCaughtOriginalMsg);
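  // Example (hypothetical Python usage; assumes the translated exception is
  // surfaced as torch.jit.Error, and uses the two helpers bound above):
  //
  //   try:
  //       scripted_fn(bad_input)
  //   except torch.jit.Error:
  //       cls = torch._C._get_caught_jit_exception_class_name()
  //       msg = torch._C._get_caught_jit_exception_original_msg()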
  py::class_<python::IODescriptor> iodescriptor(
      m,
      "IODescriptor"); // NOLINT(bugprone-unused-raii)

  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def(
          "_new_symbolic_shape_symbol",
          []() { return c10::ShapeSymbol::newSymbol().value(); })
      .def(
          "_jit_shape_compute_graph_for_node",
          [](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return c10::nullopt;
            }
            return shapeComputeGraphForSchema(n->schema());
          })
      .def(
          "_jit_decomposition_graph_for_node",
          [](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return c10::nullopt;
            }
            return GetDecomposition(n->schema());
          })
      .def("_jit_pass_run_decompositions", RunDecompositions)
      // using Node* here instead of Schema because looking up the schema
      // and passing it in from Python will have a different pointer than the
      // schema that is globally used for caching
      .def(
          "_jit_register_shape_compute_graph_for_node",
          [](Node* n, std::shared_ptr<Graph>& graph) {
            if (n->maybeSchema()) {
              const FunctionSchema& schema = n->schema();
              RegisterShapeComputeGraphForSchema(schema, graph);
            } else {
              TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
            }
          })
      .def(
          "_jit_register_decomposition_for_schema",
          [](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
            // because this is invoked by python, the function schema *
            // becomes different, and we need to find and reuse the
            // one that is used for caching
            auto op =
                findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
            RegisterDecomposition(op->schema(), graph);
          })
      .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, *graph->nodes().begin(), *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph, Node* beg) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, beg, *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          PropagateShapesAndBuildLargeShapeComputeGraph)
      .def("_jit_pass_integer_value_refinement", RefineIntegerValues)
      .def(
          "_jit_set_symbolic_shapes_test_mode",
          &setSymbolicShapeAnalysisTestMode)
      .def(
          "_jit_symbolic_shapes_test_mode_enabled",
          &symbolicShapeAnalysisTestModeEnabled)
      .def("_jit_pass_autocast", Autocast)
      .def("_jit_set_autocast_mode", &setAutocastMode)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_replace_old_ops_with_upgraders",
          [](std::shared_ptr<Graph>& g) {
            return ReplaceOldOperatorsWithUpgraders(g);
          })
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(g->block()); // overload resolution
          })
      .def(
          "_jit_pass_dce_allow_deleting_nodes_with_side_effects",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(
                g->block(),
                true,
                DCESideEffectPolicy::
                    ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload
                                                             // resolution
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
          "_jit_pass_fuse_quantized_add_relu",
          [](std::shared_ptr<Graph>& g) {
            return FuseQuantizedAddRelu(g); // overload resolution
          })
      .def(
          "_jit_pass_insert_observers",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                c10::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObservers(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_observer_method_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                c10::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObserversForOnDevicePTQ(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuant(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuantOnDevicePTQ(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](Module& module) { return InsertPrepackUnpack(module); })
      .def(
          "_jit_pass_quant_fusion",
          [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
      .def(
          "_jit_pass_fold_convbn",
          [](Module& module) { return FoldConvBatchNorm(module); })
      .def(
          "_jit_pass_dbr_quant_remove_redundant_aliases",
          [](Module& module) { return DBRQuantRemoveRedundantAliases(module); })
      .def(
          "_freeze_module",
          [](Module& module,
             std::vector<std::string>& preservedAttrs,
             bool freezeInterfaces,
             bool preserveParameters) {
            return freeze_module(
                module, preservedAttrs, freezeInterfaces, preserveParameters);
          },
          py::arg("module"),
          py::arg("preservedAttrs") = std::vector<std::string>(),
          py::arg("freezeInterfaces") = true,
          py::arg("preserveParameters") = false)
      .def("_jit_pass_concat_frozen_linear", &FrozenConcatLinear)
      .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
      .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
      .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
      .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
      .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
      .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
      .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)
      .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
      .def(
          "_jit_pass_optimize_for_inference",
          [](Module& module, std::vector<std::string> other_methods) {
            optimize_for_inference(module, other_methods);
          },
          py::arg("module"),
          py::arg("other_methods") = std::vector<std::string>())
      .def("_jit_pass_fuse_linear", &FuseLinear)
      .def(
          "_jit_pass_fuse_add_relu",
          [](std::shared_ptr<Graph>& graph) { FuseAddRelu(graph); })
      .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
      .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
      .def(
          "_jit_pass_swap_functional_linear",
          [](std::shared_ptr<Graph>& graph) { SwapFunctionalLinear(graph); })
      .def(
          "_jit_pass_swap_functional_linear",
          [](Module& module) { SwapFunctionalLinear(module); })
      .def(
          "_jit_pass_quant_finalize",
          [](Module& module,
             int quant_type_int,
             const std::vector<std::string>& preserved_attrs) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return Finalize(module, quant_type, preserved_attrs);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("preserved_attrs") = std::vector<std::string>())
      .def(
          "_jit_pass_quant_finalize_for_ondevice_ptq",
          [](Module& module,
             int quant_type_int,
             const std::string& method_name) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return FinalizeOnDevicePTQ(module, quant_type, method_name);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("method_name"))
      .def(
          "_jit_pass_pattern_based_rewrite",
          [](const Module& m) { return PatternBasedRewrite(m); })
      .def(
          "_jit_pass_custom_pattern_based_rewrite",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             const Module& m) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
            subgraph_rewriter.runOnModule(m);
          })
      .def(
          "_jit_pass_custom_pattern_based_rewrite_graph",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             std::shared_ptr<Graph> g,
             const std::vector<std::pair<std::string, std::string>>&
                 value_name_pairs) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(
                pattern, fused_node_name, value_name_pairs);
            subgraph_rewriter.runOnGraph(g);
          },
          py::arg("pattern"),
          py::arg("fused_node_name"),
          py::arg("g"),
          py::arg("value_name_pairs") =
              std::vector<std::pair<std::string, std::string>>())
      .def("_jit_pass_constant_pooling", ConstantPooling)
      // RemoveInplaceOps is used by CoreML so it must be removed with care.
      .def("_jit_pass_propagate_dtype", DtypePropagation)
      .def("_jit_pass_propagate_device", DeviceTypePropagation)
      .def(
          "_jit_pass_remove_inplace_ops",
          [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
      .def(
          "_jit_pass_create_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
      .def(
          "_jit_pass_remove_mutation",
          [](std::shared_ptr<Graph>& g) {
            RemoveListMutation(g);
            return RemoveTensorMutation(g);
          })
      .def(
          "_jit_pass_functional_to_inplace_activation",
          [](std::shared_ptr<Graph>& g) {
            return FunctionalToInplaceActivation(g);
          })
      .def(
          "_jit_pass_inplace_to_functional_activation",
          [](std::shared_ptr<Graph>& g) {
            return InplaceToFunctionalActivation(g);
          })
      .def(
          "_jit_pass_inline_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
      .def(
          "_jit_pass_peephole",
          [](const std::shared_ptr<Graph>& g, bool disable_shape_peepholes) {
            return PeepholeOptimize(g, disable_shape_peepholes);
          },
          py::arg("graph"),
          py::arg("disable_shape_peepholes") = false)
      .def(
          "_jit_pass_peephole_list_idioms",
          [](const std::shared_ptr<Graph>& g, bool refine_list_len) {
            return PeepholeOptimizeListIdioms(g, refine_list_len);
          },
          py::arg("graph"),
          py::arg("refine_list_len") = false)
      .def(
          "_jit_pass_refine_integer_values",
          [](std::shared_ptr<Graph>& g) { return RefineIntegerValues(g); })
      .def(
          "_jit_pass_fuse_addmm",
          [](std::shared_ptr<Graph>& g) { return FuseAddMM(g); })
      .def(
          "_jit_pass_canonicalize",
          [](const std::shared_ptr<Graph>& g, bool keep_unique_names = true) {
            return Canonicalize(g, keep_unique_names);
          },
          py::arg("graph"),
          py::arg("keep_unique_names") = true)
      .def("_jit_pass_lint", LintGraph)
      .def(
          "_jit_pass_complete_shape_analysis",
          [](const std::shared_ptr<Graph>& graph,
             const py::tuple& inputs,
             bool with_grad) {
            ArgumentSpecCreator arg_spec_creator(*graph);
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
            arg_spec_creator.specializeTypes(*graph, spec);
            // We only get partial specialization from the arg_spec_creator, but
            // we want full shape specialization. The alternative would be to
            // have a "complete type inference" function in ArgumentSpecCreator.
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            PropagateInputShapes(graph);
          })
      .def(
          "_jit_interpret_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            Code code(graph, "<on-demand-func>");
            InterpreterState(code).run(stack);
            return createPyObjectForStack(std::move(stack));
          },
          py::doc(
              "Interpret a JIT graph with given inputs without running any optimization passes on it"))
      .def(
          "_jit_trace_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            return TraceGraph(graph, stack);
          })
      .def(
          "_jit_trace_module",
          [](Module& model, const py::tuple& inputs) {
            auto graph = model.get_method("forward").graph();
            Stack stack;
            stack.reserve(inputs.size() + 1); // captures?
            push(stack, model._ivalue());
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto traced = TraceGraph(graph, stack);
            GRAPH_DUMP("Traced Graph", traced);

            // the easiest way to replace a graph in a module is
            // to remove all the nodes in the original graph
            // clone everything from the traced one
            graph->block()->clear();
            graph->block()->cloneFrom(traced->block(), nullptr);
            GRAPH_DUMP("Copied Graph", graph);
            return model;
          })
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_inline", Inline)
      .def(
          "_jit_pass_lower_graph",
          [](std::shared_ptr<Graph>& graph, const Module& self) {
            return LowerGraph(*graph, self._ivalue());
          })
      .def("_jit_pass_loop_unrolling", UnrollLoops)
      .def("_jit_pass_constant_loop_unrolling", UnrollConstantLoops)
      .def(
          "_jit_pass_constant_propagation_immutable_types",
          [](std::shared_ptr<Graph>& g) {
            return ConstantPropagationImmutableTypes(g);
          })
      .def(
          "_jit_pass_constant_propagation",
          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
          py::arg("graph"))
      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
      .def(
          "_jit_object_is_non_holding",
          [](Node& n) {
            return toIValue(n.output())->toObject()->is_weak_compilation_ref();
          })
      .def(
          "_jit_erase_non_input_shape_information",
          [](std::shared_ptr<Graph>& g) {
            std::vector<TypePtr> input_types;
            for (Value* v : g->inputs()) {
              if (auto tt = v->type()->cast<TensorType>()) {
                input_types.emplace_back(tt);
              } else {
                input_types.emplace_back(nullptr);
              }
            }
            EraseShapeInformation(g);
            for (size_t i = 0; i < input_types.size(); ++i) {
              if (input_types[i]) {
                g->inputs().at(i)->setType(input_types[i]);
              }
            }
          })
      .def(
          "_jit_pass_create_autodiff_subgraphs",
          [](const std::shared_ptr<Graph>& graph, py::object threshold) {
            if (threshold.is_none()) {
              CreateAutodiffSubgraphs(graph);
            } else {
              CreateAutodiffSubgraphs(graph, py::cast<int>(threshold));
            }
          },
          py::arg("graph"),
          py::arg("threshold") = py::none())
#if defined(BUILDING_TESTS) && !defined(USE_ROCM)
      .def(
          "_jit_run_cpp_tests",
          []() {
            // We have to release the GIL inside this method, because if we
            // happen to initialize the autograd engine in these tests, the
            // newly spawned worker threads will try to initialize their
            // PyThreadState*, and they need the GIL for this.
            pybind11::gil_scoped_release _no_gil;
            return runJITCPPTests();
          })
      .def("_jit_has_cpp_tests", []() { return true; })
      .def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
      .def("_jit_run_cpp_tests", []() { throw std::exception(); })
      .def("_jit_has_cpp_tests", []() { return false; })
      .def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
      .def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](const autograd::variable_list& vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
      .def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
      .def("_jit_pass_decompose_ops", DecomposeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
      .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
      .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
      .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
      .def(
          "_jit_differentiate",
          [](Graph& g) {
            // the python binding slightly differs in semantics
            // it makes a copy of the input Graph, and works on that
            // jit::differentiate mutates the input Graph
            auto g_clone = g.copy();
            return differentiate(g_clone);
          })
      .def(
          "_jit_check_alias_annotation",
          [](const std::shared_ptr<Graph>& g,
             const py::tuple& args,
             const std::string& unqualified_op_name) {
            auto stack = toTraceableStack(args);
            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
          })
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
      .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
      .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#else
      .def("_jit_set_llga_enabled", [](bool flag) { return false; })
      .def("_jit_llga_enabled", []() { return false; })
#endif
      .def(
          "_jit_set_tracer_state_warn",
          [](bool new_warn) {
            jit::tracer::getTracerStateWarnMode() = new_warn;
          })
      .def(
          "_jit_get_tracer_state_warn",
          []() {
            bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
            return current_tracer_warn;
          })
      .def(
          "_jit_set_nvfuser_skip_node_kind",
          [](const std::string& op_name, bool flip = true) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_set_nvfuser_skip_node_kind is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_set_nvfuser_enabled",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_set_nvfuser_enabled is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_nvfuser_can_be_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_can_be_enabled is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_set_nvfuser_single_node_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_set_nvfuser_single_node_mode is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_nvfuser_single_node_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_single_node_mode is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_set_nvfuser_horizontal_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_set_nvfuser_horizontal_mode is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_nvfuser_horizontal_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_horizontal_mode is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_set_nvfuser_guard_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_set_nvfuser_guard_mode is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_nvfuser_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_enabled is deprecated and a no-op");
            return false;
          })
      .def(
          "_jit_nvfuser_set_comparison_callback",
          [](bool, py::function) {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_set_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_clear_comparison_callback",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in torch script; use of _jit_nvfuser_clear_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_set_profiling_mode",
          [](bool profiling_flag) {
            bool oldState = getProfilingMode();
            getProfilingMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_profiling_executor",
          [](bool profiling_flag) {
            bool oldState = getExecutorMode();
            getExecutorMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_num_profiled_runs",
          [](size_t num) {
            size_t old_num = getNumProfiledRuns();
            getNumProfiledRuns() = num;
            return old_num;
          })
      .def(
          "_jit_get_num_profiled_runs",
          [] {
            // pybind can't automatically bind to atomic size_t
            size_t num_runs = getNumProfiledRuns();
            return num_runs;
          })
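      // Example (Python): make the profiling executor re-profile after two
      // runs; the previous value is returned so callers can restore it:
      //
      //   old = torch._C._jit_set_num_profiled_runs(2)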
      .def(
          "_jit_set_bailout_depth",
          [](size_t depth) {
            TORCH_WARN(
                "Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ",
                depth,
                ")");
            size_t old_depth = getBailoutDepth();
            FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
            setFusionStrategy(strat);
            return old_depth;
          })
      .def(
          "_jit_set_fusion_strategy",
          [](std::vector<std::pair<std::string, size_t>> strategy) {
            FusionStrategy vec_conv;
            for (const auto& pair : strategy) {
              if (pair.first == "STATIC") {
                vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
              } else if (pair.first == "DYNAMIC") {
                vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
              } else {
                TORCH_INTERNAL_ASSERT(
                    false,
                    "FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ",
                    pair.first);
              }
            }
            auto old_strategy = getFusionStrategy();
            auto strategy_ret =
                fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
                  return std::pair<std::string, size_t>(
                      behav.first == FusionBehavior::STATIC ? "STATIC"
                                                            : "DYNAMIC",
                      behav.second);
                });
            setFusionStrategy(vec_conv);
            return strategy_ret;
          })
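      // Example (Python; mirrors torch.jit.set_fusion_strategy): specialize
      // for up to 2 static shapes, then up to 10 dynamic-shape recompiles.
      // The previous strategy is returned as a list of (kind, depth) pairs:
      //
      //   prev = torch._C._jit_set_fusion_strategy(
      //       [("STATIC", 2), ("DYNAMIC", 10)])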
      .def(
          "_jit_set_inline_everything_mode",
          [](bool enabled) { getInlineEverythingMode() = enabled; })
      .def(
          "_jit_get_inline_everything_mode",
          []() { return getInlineEverythingMode(); })
      .def(
          "_jit_get_logging_option",
          []() { return ::torch::jit::get_jit_logging_levels(); })
      .def(
          "_jit_set_logging_option",
          [](std::string loggingOption) -> void {
            ::torch::jit::set_jit_logging_levels(loggingOption);
          })
      .def(
          "_jit_set_logging_stream",
          [](std::string stream_name) -> void {
            if (stream_name == "stdout") {
              ::torch::jit::set_jit_logging_output_stream(std::cout);
            } else if (stream_name == "stderr") {
              ::torch::jit::set_jit_logging_output_stream(std::cerr);
            } else {
              std::cerr << "ERROR: only `stdout` and `stderr` "
                        << "are supported as output options" << std::endl;
            }
          })
      .def(
          "_storage_id",
          [](const at::Tensor& ten) -> int64_t {
            return reinterpret_cast<int64_t>(
                ten.storage().unsafeGetStorageImpl());
          })
      .def(
          "_jit_try_infer_type",
          [](py::object obj) -> InferredType {
            return tryToInferType(std::move(obj));
          })
      .def(
          "_jit_get_te_cuda_pointwise_loop_levels",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels();
          })
      .def(
          "_jit_set_te_cuda_pointwise_loop_levels",
          [](int level) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels() = level;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_count",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_count",
          [](int block_count) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount() = block_count;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_size",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_size",
          [](int block_size) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize() = block_size;
          })
      .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
      .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
      .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
      .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
      .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
      .def(
          "_jit_set_texpr_dynamic_shape_enabled",
          &setTensorExprDynamicShapeFusionEnabled)
      .def(
          "_jit_texpr_dynamic_shape_enabled",
          &tensorExprDynamicShapeFusionEnabled)
      .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
      .def(
          "_jit_set_te_generate_block_code",
          [](bool gen_block_code) {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode() = gen_block_code;
          })
      .def(
          "_jit_get_te_generate_block_code",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode();
          })
      .def(
          "_jit_get_te_must_use_llvm_cpu",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEMustUseLLVMOnCPU();
          })
      .def(
          "_jit_set_te_must_use_llvm_cpu",
          [](bool use_llvm) {
            using namespace torch::jit::tensorexpr;
            getTEMustUseLLVMOnCPU() = use_llvm;
          })
      .def(
          "_jit_cat_wo_conditionals",
          [](bool optimize_cat) {
            using namespace torch::jit::tensorexpr;
            getCatWoConditionals() = optimize_cat;
          })
      .def(
          "_jit_opt_conditionals",
          [](bool opt_conds) {
            using namespace torch::jit::tensorexpr;
            getOptConditionals() = opt_conds;
          })
      .def(
          "_llvm_enabled",
          []() {
#ifdef TORCH_ENABLE_LLVM
            return true;
#else
            return false;
#endif
          })
      .def(
          "_jit_pass_fuse_tensorexprs",
          [](std::shared_ptr<Graph>& g) {
            FuseTensorExprs(g);
            RemoveTensorTypeSpecializations(g);
          })
      .def(
          "_jit_fuser_get_fused_kernel_code",
          [](Graph& g, const std::vector<at::Tensor>& inps) {
            return debugGetFusedKernelCode(g, inps);
          })
      .def(
          "_jit_pass_remove_dropout",
          [](script::Module& module) { return removeDropout(module); })
      .def(
          "_jit_pass_refine_tuple_types",
          [](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](std::shared_ptr<Graph>& graph) {
            return transformConv1dToConv2d(graph);
          })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](script::Module& module) {
            return transformConv1dToConv2d(module);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return insertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](script::Module& module) { return insertPrePackedOps(module); })
      .def(
          "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
          [](script::Module& module) {
            return fusePrePackedLinearConvWithClamp(module);
          })
      .def(
          "_jit_pass_fold_prepacking_ops",
          [](script::Module& module) { return FoldPrePackingOps(module); })
      .def(
          "_jit_pass_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return optimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_hack_do_not_use_clone_module_with_class",
          [](script::Module& module,
             std::vector<std::string>& ignored_methods,
             std::vector<std::string>& ignored_attributes) {
            const bool inplace = false;
            const std::unordered_set<std::string> ignored_methods_set(
                ignored_methods.begin(), ignored_methods.end());
            const std::unordered_set<std::string> ignored_attributes_set(
                ignored_attributes.begin(), ignored_attributes.end());
            return module.clone(
                inplace, ignored_methods_set, ignored_attributes_set);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return vulkanInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](script::Module& module) {
            return vulkanInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return vulkanFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_vulkan_fold_prepacking_ops",
          [](script::Module& module) {
            return vulkanFoldPrePackingOps(module);
          })
      .def(
          "_jit_pass_vulkan_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return vulkanOptimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return metalInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](script::Module& module) {
            return metalInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return metalFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_metal_fold_prepacking_ops",
          [](script::Module& module) { return metalFoldPrePackingOps(module); })
      .def(
          "_jit_pass_metal_optimize_for_mobile",
          [](script::Module& module,
             std::vector<std::string>& preserved_methods) {
            return metalOptimizeForMobile(module, preserved_methods);
          })
      .def(
          "_jit_pass_filter_non_tensor_arguments",
          [](std::map<std::string, IValue> params) {
            std::map<std::string, at::Tensor> retval;
            for (auto& kv : params) {
              if (kv.second.isTensor()) {
                retval[kv.first] = std::move(kv.second).toTensor();
              }
            }
            return retval;
          })
      .def("_jit_pass_batch_mm", BatchMM)
      .def(
          "_jit_decay_packed_param_input_types",
          [](Graph& g) {
            for (Value* i : g.inputs()) {
              if (i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
                // Dummy CompleteTensorType to appease ONNX validator.
                i->setType(TensorType::create(
                    at::kQInt8,
                    c10::kCPU,
                    std::vector<int64_t>{1},
                    std::vector<int64_t>{1},
                    c10::nullopt));
              }
            }
          })
      .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);
  // NB: This isn't actually used for regular PyTorch symbolic tracing;
  // XLA is what needs this
#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
#define SYMNODE_BINARY(n) \
  .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
#define SYMNODE_SIZES_STRIDES(n)                \
  .def(                                         \
      #n,                                       \
      [](c10::SymNode a,                        \
         c10::ArrayRef<c10::SymNode> sizes,     \
         c10::ArrayRef<c10::SymNode> strides) { \
        return a->n(sizes, strides);            \
      })
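  // For instance, SYMNODE_UNARY(clone) expands to
  //   .def("clone", [](c10::SymNode a) { return a->clone(); })
  // so each macro invocation below adds one method to the _SymNode class.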
  auto symnode_class =
      py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
          // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
          // Python is responsible for this
          SYMNODE_UNARY(clone)
          SYMNODE_UNARY(is_int)
          SYMNODE_UNARY(is_float)
          SYMNODE_UNARY(is_bool)
          SYMNODE_UNARY(bool_)
          SYMNODE_UNARY(sym_float)
          SYMNODE_BINARY(truediv)
          SYMNODE_BINARY(floordiv)
          SYMNODE_BINARY(sym_min)
          SYMNODE_BINARY(sym_max)
          SYMNODE_BINARY(sym_and)
          SYMNODE_BINARY(sym_or)
          SYMNODE_UNARY(sym_not)
          SYMNODE_UNARY(floor)
          SYMNODE_SIZES_STRIDES(is_contiguous)
          SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
          SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
          SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
          SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
          SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
          .def(
              "guard_int",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->guard_int(file, line);
              })
          .def(
              "guard_bool",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->guard_bool(file, line);
              })
          .def(
              "guard_float",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->guard_float(file, line);
              })
          .def(
              "expect_true",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->expect_true(file, line);
              })
          .def(
              "expect_size",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->expect_size(file, line);
              })
          .def(
              "guard_size_oblivious",
              [](c10::SymNode a, const char* file, int64_t line) {
                return a->guard_size_oblivious(file, line);
              })
          .def(
              "has_hint",
              [](c10::SymNode a) {
                return a->has_hint();
              })
          .def(
              "wrap_int",
              [](c10::SymNode a, int64_t b) {
                return a->wrap_int(b);
              })
          .def(
              "wrap_float",
              [](c10::SymNode a, double b) {
                return a->wrap_float(b);
              })
          .def(
              "wrap_bool",
              [](c10::SymNode a, bool b) {
                return a->wrap_bool(b);
              })
          .def(
              "__str__",
              [](c10::SymNode a) { return a->str(); })
          .def(
              "__repr__",
              [](c10::SymNode a) { return a->str(); })
          .def(
              "is_constant",
              [](const c10::SymNode& node) {
                return node->is_constant();
              })
          .def(
              "is_nested_int",
              [](const c10::SymNode& node) {
                return node->is_nested_int();
              })
          .def(
              "is_symbolic",
              [](const c10::SymNode& node) {
                return node->is_symbolic();
              })
          .def(
              "nested_int",
              [](const c10::SymNode& node) {
                return node->nested_int();
              })
          .def(
              "nested_int_coeff",
              [](const c10::SymNode& node) {
                return node->nested_int_coeff();
              });
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
      .def("__repr__", [](CompleteArgumentSpec& self) {
        std::ostringstream s;
        s << self;
        return s.str();
      });
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ArgumentSpec>(m, "ArgumentSpec");
  py::class_<Code>(m, "Code")
      .def(
          "grad_executor_states",
          [](Code& c) {
            std::vector<GraphExecutorState> states;
            for (auto& e : c.grad_executors()) {
              states.emplace_back(e->getDebugState());
            }
            return states;
          })
      .def(
          "differentiable_op_executor_states",
          [](Code& c) {
            std::vector<GraphExecutorState> states;
            for (auto& e : c.diff_graph_op_executors()) {
              if (e->isOptimized()) {
                states.emplace_back(e->getDebugState());
              } else {
                // we leave an empty entry for a node that doesn't have an
                // optimized plan
                states.emplace_back();
              }
            }
            return states;
          })
      .def("num_bailouts", [](Code& c) { return c.num_bailouts(); })
      .def("request_bailout", [](Code& c, size_t index) {
        c.request_bailout(index);
      });

  py::class_<ExecutionPlan>(m, "ExecutionPlan")
      .def_property_readonly("graph", [](ExecutionPlan& s) { return s.graph; })
      .def_property_readonly("code", [](ExecutionPlan& s) { return s.code; });

  py::class_<Gradient>(m, "Gradient")
      .def_property_readonly("f", [](Gradient& m) { return m.f; })
      .def_property_readonly("df", [](Gradient& m) { return m.df; })
      .def_property_readonly(
          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
      .def_property_readonly(
          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
      .def_property_readonly(
          "df_input_captured_inputs",
          [](Gradient& m) { return m.df_input_captured_inputs; })
      .def_property_readonly(
          "df_input_captured_outputs",
          [](Gradient& m) { return m.df_input_captured_outputs; })
      .def_property_readonly(
          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });

  py::class_<GraphExecutorState>(m, "GraphExecutorState")
      .def_property_readonly(
          "graph", [](GraphExecutorState& s) { return s.graph; })
      .def_property_readonly(
          "execution_plans",
          [](GraphExecutorState& s) { return s.execution_plans; })
      .def_property_readonly(
          "fallback", [](GraphExecutorState& s) { return s.fallback; });
  py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
      .def(py::init<std::string>())
      .def(py::init([](const py::object& buffer) {
        auto writer_func = [=](const void* data, size_t size) {
          // Writing an empty file is a noop
          if (size == 0) {
            return size;
          }
          py::gil_scoped_acquire acquire;
          auto memory_view = py::memoryview::from_memory(
              reinterpret_cast<const char*>(data), size);
          buffer.attr("write")(std::move(memory_view));
          return size;
        };
        return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
      }))
      .def(py::init<const std::function<size_t(const void*, size_t)>&>())
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             const char* data,
             size_t size) {
            // Since we don't know where the data come from, we cannot
            // release the GIL in this overload
            return self.writeRecord(name, data, size);
          })
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             py::bytes data,
             size_t size) {
            // It is not clear from the doc but according to CPython own code,
            // it is ok to use the result of PyBytes_AsString without the GIL
            // being held:
            // https://github.com/python/cpython/blob/e2a3e4b7488aff6fdc704a0f258bc315e96c1d6e/Objects/stringlib/join.h#L67
            const char* data_str = PyBytes_AsString(data.ptr());
            py::gil_scoped_release release;
            return self.writeRecord(name, data_str, size);
          })
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             c10::Storage data,
             size_t size) {
            // Reading Tensor data is always ok without the GIL held
            py::gil_scoped_release release;
            return self.writeRecord(
                name, reinterpret_cast<const char*>(data.data()), size);
          })
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             uintptr_t data,
             size_t size) {
            TORCH_WARN_ONCE(
                "write_record(): Passing Storage by data pointer is deprecated and will be an error in ",
                "the future, please pass the Storage object instead.");
            return self.writeRecord(
                name, reinterpret_cast<const char*>(data), size);
          })
      .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
      .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
      .def("archive_name", &PyTorchStreamWriter::archiveName)
      .def("serialization_id", &PyTorchStreamWriter::serializationId)
      .def(
          "get_all_written_records",
          &PyTorchStreamWriter::getAllWrittenRecords);
  py::enum_<MobileOptimizerType>(m, "_MobileOptimizerType")
      .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
      .value(
          "INSERT_FOLD_PREPACK_OPS",
          MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
      .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
      .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
      .value(
          "HOIST_CONV_PACKED_PARAMS",
          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
      .value(
          "VULKAN_AUTOMATIC_GPU_TRANSFER",
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER);
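  // Example (Python; this enum is consumed by the public
  // torch.utils.mobile_optimizer.optimize_for_mobile wrapper):
  //
  //   from torch.utils.mobile_optimizer import optimize_for_mobile
  //   blocklist = {torch._C._MobileOptimizerType.REMOVE_DROPOUT}
  //   opt = optimize_for_mobile(scripted, optimization_blocklist=blocklist)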
  // This allows PyTorchStreamReader to read from a Python buffer. It requires
  // that the buffer implement `seek()`, `tell()`, and `read()`.
  class BufferAdapter : public caffe2::serialize::ReadAdapterInterface {
   public:
    BufferAdapter(const py::object& buffer) : buffer_(buffer) {
      // Jump to the end of the buffer to get its size
      auto current = buffer.attr("tell")();
      start_offset_ = py::cast<size_t>(current);
      buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
      size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
      buffer.attr("seek")(current);

      // If we can read directly into a buffer, do that instead of an extra copy
      use_readinto_ = py::hasattr(buffer, "readinto");
    }

    size_t size() const override {
      return size_;
    }

    THPObjectPtr getMemview(void* buf, size_t n) const {
      THPObjectPtr memview(PyMemoryView_FromMemory(
          reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
      if (!memview) {
        throw python_error();
      }
      return memview;
    }

    size_t read(uint64_t pos, void* buf, size_t n, const char* what)
        const override {
      // Seek to desired position (NB: this has to be a Py_ssize_t or Python
      // throws a weird error)
      Py_ssize_t absolute_pos = start_offset_ + pos;
      buffer_.attr("seek")(absolute_pos);

      if (use_readinto_) {
        auto memview = getMemview(buf, n);
        auto res =
            PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
        if (res) {
          int64_t i = static_cast<int64_t>(PyLong_AsLongLong(res));
          Py_DECREF(res);
          if (i > 0) {
            return i;
          }
        }
      }

      // Read bytes into `buf` from the buffer
      std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
      std::copy(
          bytes.data(),
          bytes.data() + bytes.size(),
          reinterpret_cast<char*>(buf));
      return bytes.size();
    }

    py::object buffer_;
    size_t size_;
    size_t start_offset_;
    bool use_readinto_;
  };
  py::class_<PyTorchStreamReader, std::shared_ptr<PyTorchStreamReader>>(
      m, "PyTorchFileReader")
      .def(py::init<std::string>())
      .def(py::init([](const py::object& buffer) {
        auto adapter = std::make_unique<BufferAdapter>(buffer);
        return std::make_shared<PyTorchStreamReader>(std::move(adapter));
      }))
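      // Example (Python; any object with seek()/tell()/read() works via the
      // BufferAdapter above):
      //
      //   import io
      //   reader = torch._C.PyTorchFileReader(io.BytesIO(serialized_bytes))
      //   names = reader.get_all_records()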
      .def(
          "get_record",
          [](PyTorchStreamReader& self, const std::string& key) {
            auto [data, size] = self.getRecord(key);
            return py::bytes(reinterpret_cast<const char*>(data.get()), size);
          })
      .def(
          "has_record",
          [](PyTorchStreamReader& self, const std::string& key) {
            return self.hasRecord(key);
          })
      .def(
          "get_storage_from_record",
          [](PyTorchStreamReader& self,
             const std::string& key,
             size_t numel,
             py::object data_type_obj) {
            at::DataPtr data(std::get<0>(self.getRecord(key)));
            auto scalar_type =
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;

            c10::Storage storage(
                c10::Storage::use_byte_size_t(),
                numel * elementSize(scalar_type),
                std::move(data),
                /*allocator=*/nullptr,
                /*resizable=*/false);
            auto ptr =
                c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
                    std::move(storage),
                    at::DispatchKeySet(),
                    at::CPU(scalar_type).typeMeta());
            return at::Tensor(std::move(ptr));
          })
      .def("serialization_id", &PyTorchStreamReader::serializationId)
      .def(
          "get_all_records",
          [](PyTorchStreamReader& self) { return self.getAllRecords(); })
      .def(
          "get_record_offset",
          [](PyTorchStreamReader& self, const std::string& key) {
            return self.getRecordOffset(key);
          });

  // Used by torch.Package to coordinate deserialization of storages across
  // ScriptModules and eager modules
  py::class_<
      DeserializationStorageContext,
      std::shared_ptr<DeserializationStorageContext>>(
      m, "DeserializationStorageContext")
      .def(py::init<>())
      .def(
          "get_storage",
          [](DeserializationStorageContext& self,
             const std::string& name,
             py::object data_type_obj) {
            c10::Storage storage = self.getStorage(name);
            auto scalar_type =
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
            auto ptr =
                c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
                    std::move(storage),
                    at::DispatchKeySet(),
                    at::CPU(scalar_type).typeMeta());

            return at::Tensor(std::move(ptr));
          })
      .def(
          "add_storage",
          [](DeserializationStorageContext& self,
             const std::string& name,
             const at::Tensor& tensor) {
            return self.addStorage(name, tensor.storage());
          })
      .def("has_storage", &DeserializationStorageContext::hasStorage);
  m.def(
      "_get_schema",
      [](const std::string& op_name, const std::string& overload_name) {
        try {
          auto symbol = Symbol::fromQualString(op_name);
          auto operations = getAllOperatorsFor(symbol);
          for (const auto& op : operations) {
            if (op->schema().overload_name() == overload_name) {
              return op->schema();
            }
          }
          throw std::runtime_error("Found no matching schema");
        } catch (const c10::Error& e) {
          auto msg = torch::get_cpp_stacktraces_enabled()
              ? e.what()
              : e.what_without_backtrace();
          throw std::runtime_error(msg);
        }
      });

  m.def(
      "_get_operation_overload",
      [](const std::string& op_name, const std::string& overload_name) {
        try {
          auto symbol = Symbol::fromQualString(op_name);
          auto operations = getAllOperatorsFor(symbol);
          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
          for (const auto& op : operations) {
            if (op->schema().overload_name() == overload_name) {
              auto func =
                  py::cpp_function([op, symbol, allow_numbers_as_tensors](
                                       py::args args, py::kwargs kwargs) {
                    ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
                    return _get_operation_for_overload_or_packet(
                        {op}, symbol, args, kwargs, /*is_overload*/ true);
                  });
              auto func_dk = py::cpp_function(
                  [op, symbol, allow_numbers_as_tensors](
                      c10::DispatchKey dk_, py::args args, py::kwargs kwargs) {
                    c10::optional<c10::DispatchKey> dk =
                        c10::make_optional(dk_);
                    ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
                    return _get_operation_for_overload_or_packet(
                        {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
                  });
              return py::make_tuple(
                  func, func_dk, py::cast(op->getTags().vec()));
            }
          }
          throw std::runtime_error("Found no matching operator overload");
        } catch (const c10::Error& e) {
          auto msg = torch::get_cpp_stacktraces_enabled()
              ? e.what()
              : e.what_without_backtrace();
          throw std::runtime_error(msg);
        }
      });

  m.def(
      "_jit_resolve_packet",
      [](const char* op_name, py::args args, py::kwargs kwargs) {
        try {
          auto symbol = Symbol::fromQualString(op_name);
          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
          ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
          const auto overloads = getAllSortedOperatorsFor(symbol);
          auto opWithStack = getOpWithStack(overloads, args, kwargs);
          std::shared_ptr<Operator> overload = std::get<0>(opWithStack);
          auto result = overload->schema().overload_name();
          if (result.empty()) {
            result = "default";
          }
          return result;
        } catch (const c10::Error& e) {
          auto msg = torch::get_cpp_stacktraces_enabled()
              ? e.what()
              : e.what_without_backtrace();
          throw std::runtime_error(msg);
        }
      });
  m.def(
      "_jit_get_operation",
      [](const std::string& op_name) {
        try {
          auto symbol = Symbol::fromQualString(op_name);
          const auto sortedOps = getAllSortedOperatorsFor(symbol);
          if (sortedOps.empty()) {
            return py::make_tuple(py::none(), py::none());
          }

          std::ostringstream docstring;
          docstring << "Automatically bound operator '" << op_name
                    << "' with schema(s):\n";

          for (const auto& op : sortedOps) {
            docstring << "  " << op->schema() << "\n";
          }

          py::list overload_names;
          for (const auto& op : sortedOps) {
            overload_names.append(py::str(op->schema().overload_name()));
          }

          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);

          auto func = py::cpp_function(
              [sortedOps, symbol, allow_numbers_as_tensors](
                  py::args args, py::kwargs kwargs) {
                ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
                return _get_operation_for_overload_or_packet(
                    sortedOps, symbol, args, kwargs, false);
              },
              py::name(symbol.toUnqualString()),
              py::doc(docstring.str().c_str()));
          return py::make_tuple(func, overload_names);
        } catch (const c10::Error& e) {
          auto msg = torch::get_cpp_stacktraces_enabled()
              ? e.what()
              : e.what_without_backtrace();
          throw std::runtime_error(msg);
        }
      },
      py::arg("qualified_name"));
  m.def(
      "parse_ir",
      [](const std::string& input, bool parse_tensor_constants) {
        auto graph = std::make_shared<Graph>();
        parseIR(input, &*graph, parse_tensor_constants);
        return graph;
      },
      py::arg("input"),
      py::arg("parse_tensor_constants") = false);
  m.def("parse_schema", parseSchema);
1747
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
1748
std::ostringstream s;
1749
auto type = unifyTypeList(types, s);
1751
throw std::runtime_error(s.str());
1753
return type.value();
1755
py::enum_<SchemaArgType>(m, "_SchemaArgType")
1756
.value("input", SchemaArgType::input)
1757
.value("output", SchemaArgType::output);
1758
py::class_<SchemaArgument>(m, "_SchemaArgument")
1759
.def(py::init<SchemaArgType, size_t>())
1760
.def_readwrite("type", &SchemaArgument::type)
1761
.def_readwrite("index", &SchemaArgument::index);
1762
py::class_<SchemaInfo>(m, "_SchemaInfo")
1763
.def(py::init<FunctionSchema>())
1764
.def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); })
1767
[](SchemaInfo& self, const SchemaArgument& argument) {
1768
return self.is_mutable(argument);
1772
[](SchemaInfo& self, const std::string& name) {
1773
return self.has_argument(name);
1777
[](SchemaInfo& self, const std::string& name) {
1778
return self.is_mutable(name);
1782
[](SchemaInfo& self,
1783
const SchemaArgument& lhs,
1784
const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); })
1786
"may_contain_alias",
1787
[](SchemaInfo& self,
1788
const SchemaArgument& lhs,
1789
const SchemaArgument& rhs) {
1790
return self.may_contain_alias(lhs, rhs);
1793
"add_argument_value",
1794
[](SchemaInfo& self,
1795
const std::string& name,
1796
const py::object& value) {
1797
c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
1799
// For normalization purposes there is an inconsistency within
1800
// torch.fx that turns all arguments named "self" into "input".
1801
// Thus this check ensures that those arguments are checked
1803
if (name == "input" && !self.hasInputArgumentNamed("input")) {
1804
self.addArgumentValue("self", *i_value);
1806
self.addArgumentValue(name, *i_value);
1810
.def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
1811
std::unordered_map<std::string, IValue> value_map;
1812
for (const auto& key_pair : values) {
1813
IValue key = toTypeInferredIValue(key_pair.first);
1814
TORCH_INTERNAL_ASSERT(
1816
"Add argument value keys types should be strings.");
1817
c10::optional<IValue> value =
1818
toTypeInferredIValueOptional(key_pair.second);
1820
// For normalization purposes there is an inconsistency within
1822
// turns all arguments named "self" into "input". Thus this check
1823
// ensures that those arguments are checked correctly.
1824
if (key.toStringRef() == "input" &&
1825
!self.hasInputArgumentNamed("input")) {
1826
self.addArgumentValue("self", *value);
1828
value_map[key.toStringRef()] = *value;
1832
self.addArgumentValues(value_map);
1834
py::class_<FunctionSchema>(m, "FunctionSchema")
1835
.def_property_readonly(
1836
"name", [](FunctionSchema& self) { return self.name(); })
1837
.def_property_readonly(
1839
[](FunctionSchema& self) { return self.overload_name(); })
1840
.def_property_readonly(
1841
"arguments", [](FunctionSchema& self) { return self.arguments(); })
1842
.def_property_readonly(
1843
"returns", [](FunctionSchema& self) { return self.returns(); })
1845
"is_backward_compatible_with",
1846
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
1847
return self.isBackwardCompatibleWith(old_schema);
1850
"check_forward_compatible_with",
1851
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
1852
std::ostringstream out;
1853
auto result = self.isForwardCompatibleWith(old_schema, out);
1854
return std::make_pair(result, out.str());
1858
[](const FunctionSchema& self, const FunctionSchema& other) {
1859
return self == other;
1863
[](const FunctionSchema& self) {
1864
return std::hash<FunctionSchema>{}(self);
1868
[](FunctionSchema& self) {
1869
std::stringstream ss;
1873
.def_property_readonly(
1874
"is_mutable", [](FunctionSchema& self) { return self.is_mutable(); });
1875
  py::class_<Argument>(m, "Argument")
      .def_property_readonly("name", [](Argument& self) { return self.name(); })
      .def_property_readonly("type", [](Argument& self) { return self.type(); })
      .def_property_readonly(
          "real_type", [](Argument& self) { return self.real_type(); })
      .def_property_readonly(
          "N",
          [](Argument& self) -> py::object {
            return (self.N()) ? py::cast(*self.N()) : py::none();
          })
      .def_property_readonly(
          "default_value",
          [](Argument& self) -> py::object {
            if (!self.default_value()) {
              return py::none();
            }
            IValue v = *self.default_value();
            return toPyObject(std::move(v));
          })
      .def(
          "has_default_value",
          [](Argument& self) -> py::bool_ {
            return self.default_value().has_value();
          })
      .def_property_readonly(
          "alias_info", [](Argument& self) { return self.alias_info(); })
      .def_property_readonly(
          "is_out", [](Argument& self) { return self.is_out(); })
      .def_property_readonly("kwarg_only", [](Argument& self) -> bool {
        return self.kwarg_only();
      });
  py::class_<AliasInfo>(m, "_AliasInfo")
      .def_property_readonly(
          "is_write", [](AliasInfo& self) { return self.isWrite(); })
      .def_property_readonly(
          "before_set",
          [](AliasInfo& self) {
            std::set<py::str> before_set_python;
            for (const auto& set : self.beforeSets()) {
              before_set_python.insert(py::str(set.toUnqualString()));
            }
            return before_set_python;
          })
      .def_property_readonly("after_set", [](AliasInfo& self) {
        std::set<py::str> after_set_python;
        for (const auto& set : self.afterSets()) {
          after_set_python.insert(py::str(set.toUnqualString()));
        }
        return after_set_python;
      });
m.def("_jit_get_all_schemas", []() {
1926
const std::vector<std::shared_ptr<Operator>>& operations =
1928
return fmap(operations, [](const std::shared_ptr<Operator>& op) {
1929
return op->schema();
1932
m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck);
1933
m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
1934
auto symbol = Symbol::fromQualString(qualified_name);
1935
const auto& operations = getAllOperatorsFor(symbol);
1936
return fmap(operations, [](const std::shared_ptr<Operator>& op) {
1937
return op->schema();
1940
m.def("_is_tracing", []() { return jit::tracer::isTracing(); });
1942
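  // A minimal Python-side sketch of the schema-introspection bindings above;
  // the fields read here are the FunctionSchema/Argument properties bound
  // earlier in this file (illustrative only):
  //
  //   for schema in torch._C._jit_get_schemas_for_operator("aten::add"):
  //       print(schema.name, schema.overload_name)
  //       for arg in schema.arguments:
  //           print(" ", arg.name, arg.type, arg.kwarg_only)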
  py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
      m, "Future")
      .def(py::init([](std::vector<c10::Device> devices = {}) {
        return std::make_shared<PythonFutureWrapper>(
            c10::make_intrusive<c10::ivalue::Future>(
                PyObjectType::get(), std::move(devices)));
      }))
      .def(
          "done",
          // Intentionally not releasing GIL
          &PythonFutureWrapper::done)
      .def(
          "value",
          &PythonFutureWrapper::value,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "wait",
          &PythonFutureWrapper::wait,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "then",
          &PythonFutureWrapper::then,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "add_done_callback",
          &PythonFutureWrapper::add_done_callback,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_result",
          // Intentionally not releasing GIL
          &PythonFutureWrapper::markCompleted)
      .def(
          "_set_unwrap_func",
          // Intentionally not releasing GIL as this just does an assign
          [](PythonFutureWrapper& self, py::function unwrapFunc) {
            auto functionGuard =
                std::make_shared<torch::jit::PythonFunctionGuard>(
                    std::move(unwrapFunc));

            std::function<void(py::object)> pf =
                [functionGuard(std::move(functionGuard))](
                    const py::object& inp) {
                  return functionGuard->func_(inp);
                };
            self.unwrap_func = std::move(pf);
          })
      .def(
          py::pickle(
              /* __getstate__ */
              [](const PythonFutureWrapper& /* unused */) {
                TORCH_CHECK(false, "Can not pickle torch.futures.Future");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy Pybind API's
                // requirement.
                return py::make_tuple();
              },
              /* __setstate__ */
              [](const py::tuple& /* unused */) { // NOLINT
                TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy PyBind's API
                // requirement.
                return nullptr;
              }),
          py::call_guard<py::gil_scoped_release>());
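  // The class above backs torch.futures.Future. A minimal usage sketch of the
  // bound methods (done/value/wait/then/add_done_callback/set_result):
  //
  //   fut = torch.futures.Future()
  //   fut.add_done_callback(lambda f: print("done:", f.value()))
  //   fut.set_result(42)
  //   assert fut.done() and fut.wait() == 42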
  py::class_<PythonAwaitWrapper, std::shared_ptr<PythonAwaitWrapper>>(
      m, "_Await")
      .def(
          "wait",
          &PythonAwaitWrapper::wait,
          py::call_guard<py::gil_scoped_release>())
      .def("fn", &PythonAwaitWrapper::fn)
      .def("args", &PythonAwaitWrapper::args)
      .def("type", &PythonAwaitWrapper::type)
      .def("is_nowait", &PythonAwaitWrapper::is_nowait)
      .def(
          "__getattr__",
          [](PythonAwaitWrapper& self, const std::string& name) -> py::object {
            // In eager mode allow Await[W] to be used as W, redirecting getattr
            // to the result of the delayed function.
            return py::getattr(self.wait(), name.c_str(), py::none());
          })
      .def(
          py::pickle(
              /* __getstate__ */
              [](const PythonAwaitWrapper& /* unused */) {
                TORCH_CHECK(false, "Can not pickle torch.jit._Await");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy Pybind API's
                // requirement.
                return py::make_tuple();
              },
              /* __setstate__ */
              [](const py::tuple& /* unused */) { // NOLINT
                TORCH_CHECK(false, "Can not unpickle torch.jit._Await");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy PyBind's API
                // requirement.
                return nullptr;
              }),
          py::call_guard<py::gil_scoped_release>());
m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
2046
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2047
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2049
// Only return true if we are certain that self and other are aliasing.
2050
if (!self_value || !other_value) {
2053
return self_value->isAliasOf(*other_value);
2055
m.def("_overlaps", [](const py::object& self, const py::object& other) {
2056
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2057
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2059
// Only return true if we are certain that self and other are overlapping.
2060
if (!self_value || !other_value) {
2063
return self_value->overlaps(*other_value);
2065
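  // Both helpers above return false whenever type inference fails on either
  // argument, so a true result always means a certain answer. Illustrative
  // use of these internal bindings (sketch only):
  //
  //   x = torch.randn(4)
  //   torch._C._is_alias_of(x, x[1:])        # True: a view aliases its base
  //   torch._C._overlaps(x, torch.randn(4))  # False: distinct storages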
m.def("_awaitable", [](const py::args& args, const py::kwargs& kwargs) {
2066
AT_ASSERT(args.size() >= 1);
2067
py::tuple args_tup(args.size() - 1);
2068
for (const auto i : c10::irange(1, args.size())) {
2069
args_tup[i - 1] = args[i];
2071
return std::make_shared<PythonAwaitWrapper>(
2072
py::cast<py::function>(args[0]), std::move(args_tup));
2074
m.def("_awaitable_nowait", [](py::handle input) {
2075
return std::make_shared<PythonAwaitWrapper>(std::move(input));
2078
"_awaitable_wait", [](const std::shared_ptr<PythonAwaitWrapper>& py_aw) {
2079
TORCH_CHECK(py_aw, "Await can't be None");
2080
return py_aw->wait();
2082
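  // A minimal sketch of the awaitable helpers above as surfaced through the
  // internal torch.jit API (names assumed from torch/jit; illustrative only):
  //
  //   aw = torch.jit._awaitable(lambda x: x * 2, 3)  # delays the call
  //   torch.jit._awaitable_wait(aw)                  # runs it, returns 6
  //   nw = torch.jit._awaitable_nowait(6)            # wraps a ready value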
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
2083
AT_ASSERT(!args.empty());
2085
py::function f = py::cast<py::function>(args[0]);
2086
py::tuple args_tup(args.size() - 1);
2088
for (const auto i : c10::irange(1, args.size())) {
2089
args_tup[i - 1] = args[i];
2092
if (jit::tracer::isTracing()) {
2093
auto graph = jit::tracer::getTracingState()->graph;
2094
auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
2095
auto body_block = fork_node->addBlock();
2097
Value* node_output = nullptr;
2098
py::object py_func_output;
2099
// Insert new trace ops into the fork op's sub-block
2100
WithInsertPoint guard(body_block);
2101
IValue output_ivalue;
2103
tracer::WithNestedTracingFrame env_guard;
2105
// Run the user-supplied function
2106
py_func_output = f(*args_tup, **kwargs);
2108
// Convert the output of the user-supplied function to IValue. The type
2109
// information of this IValue is used both to record the correct type in
2111
output_ivalue = toTypeInferredIValue(py_func_output);
2112
Value* out_val = jit::tracer::getValueTrace(output_ivalue);
2113
body_block->registerOutput(out_val);
2115
fork_node->output()->setType(FutureType::create(out_val->type()));
2119
c10::make_intrusive<c10::ivalue::Future>(output_ivalue.type());
2121
// Record the ivalue in the tracer
2122
jit::tracer::setValueTrace(retval, node_output);
2124
// stuff the ivalue output in the Future
2125
retval->markCompleted(output_ivalue);
2127
return std::make_shared<PythonFutureWrapper>(retval);
2129
auto result = toTypeInferredIValue(f(*args_tup, **kwargs));
2130
auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
2131
retval->markCompleted(std::move(result));
2132
return std::make_shared<PythonFutureWrapper>(retval);
2136
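  // The binding above implements torch.jit.fork; in eager mode it simply runs
  // the function and returns a completed Future, while under tracing it also
  // records a prim::TracedFork node. A minimal sketch, paired with the "wait"
  // binding below:
  //
  //   fut = torch.jit.fork(lambda a, b: a + b, torch.ones(2), torch.ones(2))
  //   result = torch.jit.wait(fut)  # tensor([2., 2.])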
m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
2137
TORCH_CHECK(fut, "Future can't be None");
2143
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
2144
-> std::shared_ptr<jit::PythonFutureWrapper> {
2145
auto typePtr = futures.empty() || futures[0] == nullptr
2147
: futures[0]->fut->elementType();
2148
c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
2149
c10::FutureType::create(typePtr));
2150
asList.reserve(futures.size());
2151
for (const auto& f : futures) {
2152
TORCH_CHECK(f, "Future can't be None");
2153
asList.push_back(f->fut);
2155
return std::make_shared<jit::PythonFutureWrapper>(
2156
c10::collectAll(asList),
2157
/* unwrap_func */ [futures](const py::object& /*unused*/) {
2158
// Throw errors when calling wait() on the returned Future if
2159
// any of the original futures would throw.
2160
// NB: PythonFutureWrapper takes an unwrap_func which serves as a
2161
// callback to evalute the value in the Future. RPC uses this
2162
// unwrap_func to check whether the returned py::object is a
2163
// RemoteException object, and re-throw the exception if it is.
2164
// By extracting the c10::ivalue::Future from PythonFutureWrapper
2165
// the unwrap_func on the original PythonFutureWrapper objects are
2166
// discarded, and hence it will return the RemoteException as an
2167
// object instead of re-throwing it.
2168
for (auto& fut : futures) {
2173
py::call_guard<py::gil_scoped_release>());
2175
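  // _collect_all backs torch.futures.collect_all: it returns a Future that
  // completes once every input Future completes. A minimal sketch:
  //
  //   futs = [torch.jit.fork(torch.neg, torch.ones(2)) for _ in range(3)]
  //   combined = torch.futures.collect_all(futs)
  //   values = [f.value() for f in combined.wait()]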
m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) {
2176
toIValue(std::move(obj), type);
2179
#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2180
m.def("_set_print_stack_traces_on_fatal_signal", [](bool print) {
2181
c10::FatalSignalHandler::getInstance().setPrintStackTracesOnFatalSignal(
2184
#endif // defined(C10_SUPPORTS_SIGNAL_HANDLER)
2186
  initPythonCustomClassBindings(module);
  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  initTreeViewBindings(module);
  initJitScriptBindings(module);
  initJitBackendBindings(module);
  initStaticModuleBindings(module);
  initTensorExprBindings(module);
  // initNvFuserPythonBindings(module);

  setPrintHandler([](const std::string& str) {
    py::gil_scoped_acquire acquire;
    try {
      auto _stdout = py::module::import("sys").attr("stdout");
      _stdout.attr("write")(str);
    } catch (py::error_already_set& e) {
      throw std::runtime_error(e.what());
    }
  });

  // On exit we need to reset the print handler to the default one, because
  // otherwise the prim::Print() instruction won't work for JIT modules.
  auto atexit = py::module_::import("atexit");
  atexit.attr("register")(
      py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); }));
}

} // namespace torch::jit