#include <pybind11/pytypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/schema_info.h>

#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
// #include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/autocast.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/device_type_analysis.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
#include <torch/csrc/jit/passes/quantization/finalize.h>
#include <torch/csrc/jit/passes/quantization/fusion_passes.h>
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h>

#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <c10/util/signal_handler.h>
#include <caffe2/serialize/inline_container.h>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <pybind11/operators.h>

#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

namespace torch::jit {

using c10::AliasInfo;
using c10::Argument;
using c10::FunctionSchema;
using c10::SchemaArgType;
using c10::SchemaArgument;
using c10::SymNode;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using torch::utils::SchemaInfo;

namespace {

using autograd::variable_list;

bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
  // THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  // PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

static bool opAllowsNumbersAsTensors(c10::Symbol symbol) {
  return symbol.is_prims() || symbol.is_nvprims() ||
      (symbol.is_aten() &&
       torch::should_allow_numbers_as_tensors(symbol.toUnqualString()));
}

c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
  // Errors need to be caught here because toTypeInferredIValue errors out
  // on various object types, but we want it to work with all types.
  try {
    return toTypeInferredIValue(input);
  } catch (const c10::Error& e) {
    return c10::nullopt;
  }
}
} // anonymous namespace

#if !defined(USE_ROCM)
TORCH_API void runJITCPPTests();
#endif

void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto jit = m.def_submodule("_jit");

  // This is a static object, so we must leak the Python object
  // "release()" is used here to preserve 1 refcount on the
  // object, preventing it from ever being de-allocated by CPython.
  static py::handle exc =
      py::exception<JITException>(m, "JITException").release();

  py::register_exception_translator([](std::exception_ptr p) {
    try {
      if (p) {
        std::rethrow_exception(p);
      }
    } catch (const JITException& e) {
      // special handling of JITException, to set its python class name and msg
      py::gil_scoped_acquire acquire;
      const auto& className = e.getPythonClassName();
      const auto& originalMsg = e.getOriginalMsg();
      JITException::setCaughtOriginalMsg(originalMsg.value_or(""));
      JITException::setCaughtPythonClassName(className.value_or(""));
      // If we still had the py::exception<JITException> object, we could
      // just call it. But we must get a handle to leak it and there is no
      // way I can find to re-create it from the handle. So setting the
      // exception manually
      PyErr_SetString(exc.ptr(), e.what());
    }
  });
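  // Note: the class name and original message captured in the translator above
  // are surfaced to Python through the two _get_caught_jit_exception_* getters
  // bound immediately below.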

  m.def(
      "_get_caught_jit_exception_class_name",
      JITException::getCaughtPythonClassName);
  m.def(
      "_get_caught_jit_exception_original_msg",
      JITException::getCaughtOriginalMsg);

  py::class_<python::IODescriptor> iodescriptor(
      m,
      "IODescriptor"); // NOLINT(bugprone-unused-raii)

  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def(
          "_new_symbolic_shape_symbol",
          []() { return c10::ShapeSymbol::newSymbol().value(); })
      .def(
          "_jit_shape_compute_graph_for_node",
          [](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return c10::nullopt;
            }
            return shapeComputeGraphForSchema(n->schema());
          })
      .def(
          "_jit_decomposition_graph_for_node",
          [](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return c10::nullopt;
            }
            return GetDecomposition(n->schema());
          })
      .def("_jit_pass_run_decompositions", RunDecompositions)
      // using Node* here instead of Schema because looking up the schema
      // and passing it in from Python will have a different pointer than the
      // schema that is globally used for caching
      .def(
          "_jit_register_shape_compute_graph_for_node",
          [](Node* n, std::shared_ptr<Graph>& graph) {
            if (n->maybeSchema()) {
              const FunctionSchema& schema = n->schema();
              RegisterShapeComputeGraphForSchema(schema, graph);
            } else {
              TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
            }
          })
      .def(
          "_jit_register_decomposition_for_schema",
          [](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
            // because this is invoked by python, the function schema *
            // becomes different, and we need to find and reuse the
            // one that is used for caching
            auto op =
                findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
            RegisterDecomposition(op->schema(), graph);
          })
      .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, *graph->nodes().begin(), *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph, Node* beg) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, beg, *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          PropagateShapesAndBuildLargeShapeComputeGraph)
      .def("_jit_pass_integer_value_refinement", RefineIntegerValues)
      .def(
          "_jit_set_symbolic_shapes_test_mode",
          &setSymbolicShapeAnalysisTestMode)
      .def(
          "_jit_symbolic_shapes_test_mode_enabled",
          &symbolicShapeAnalysisTestModeEnabled)
      .def("_jit_pass_autocast", Autocast)
      .def("_jit_set_autocast_mode", &setAutocastMode)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_replace_old_ops_with_upgraders",
          [](std::shared_ptr<Graph>& g) {
            return ReplaceOldOperatorsWithUpgraders(g);
          })
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(g->block()); // overload resolution
          })
      .def(
          "_jit_pass_dce_allow_deleting_nodes_with_side_effects",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(
                g->block(),
                true,
                DCESideEffectPolicy::
                    ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload
                                                             // resolution
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
          "_jit_pass_fuse_quantized_add_relu",
          [](std::shared_ptr<Graph>& g) {
            return FuseQuantizedAddRelu(g); // overload resolution
          })
      .def(
          "_jit_pass_insert_observers",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                c10::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObservers(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_observer_method_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                c10::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObserversForOnDevicePTQ(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuant(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuantOnDevicePTQ(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](Module& module) { return InsertPrepackUnpack(module); })
      .def(
          "_jit_pass_quant_fusion",
          [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
      .def(
          "_jit_pass_fold_convbn",
          [](Module& module) { return FoldConvBatchNorm(module); })
      .def(
          "_jit_pass_dbr_quant_remove_redundant_aliases",
          [](Module& module) { return DBRQuantRemoveRedundantAliases(module); })
      .def(
          "_freeze_module",
          [](Module& module,
             std::vector<std::string>& preservedAttrs,
             bool freezeInterfaces,
             bool preserveParameters) {
            return freeze_module(
                module, preservedAttrs, freezeInterfaces, preserveParameters);
          },
          py::arg("module"),
          py::arg("preservedAttrs") = std::vector<std::string>(),
          py::arg("freezeInterfaces") = true,
          py::arg("preserveParameters") = false)
      .def("_jit_pass_concat_frozen_linear", &FrozenConcatLinear)
      .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
      .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
      .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
      .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
      .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
      .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
      .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)
      .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
      .def(
          "_jit_pass_optimize_for_inference",
          [](Module& module, std::vector<std::string> other_methods) {
            optimize_for_inference(module, other_methods);
          },
          py::arg("module"),
          py::arg("other_methods") = std::vector<std::string>())
      .def("_jit_pass_fuse_linear", &FuseLinear)
      .def(
          "_jit_pass_fuse_add_relu",
          [](std::shared_ptr<Graph>& graph) { FuseAddRelu(graph); })
      .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
      .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
      .def(
          "_jit_pass_swap_functional_linear",
          [](std::shared_ptr<Graph>& graph) { SwapFunctionalLinear(graph); })
      .def(
          "_jit_pass_swap_functional_linear",
          [](Module& module) { SwapFunctionalLinear(module); })
      .def(
          "_jit_pass_quant_finalize",
          [](Module& module,
             int quant_type_int,
             const std::vector<std::string>& preserved_attrs) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return Finalize(module, quant_type, preserved_attrs);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("preserved_attrs") = std::vector<std::string>())
      .def(
          "_jit_pass_quant_finalize_for_ondevice_ptq",
          [](Module& module,
             int quant_type_int,
             const std::string& method_name) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return FinalizeOnDevicePTQ(module, quant_type, method_name);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("preserved_attrs") = std::vector<std::string>())
      .def(
          "_jit_pass_pattern_based_rewrite",
          [](const Module& m) { return PatternBasedRewrite(m); })
      .def(
          "_jit_pass_custom_pattern_based_rewrite",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             const Module& m) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
            subgraph_rewriter.runOnModule(m);
          })
      .def(
          "_jit_pass_custom_pattern_based_rewrite_graph",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             std::shared_ptr<Graph> g,
             const std::vector<std::pair<std::string, std::string>>&
                 value_name_pairs) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(
                pattern, fused_node_name, value_name_pairs);
            subgraph_rewriter.runOnGraph(g);
          },
          py::arg("pattern"),
          py::arg("fused_node_name"),
          py::arg("g"),
          py::arg("value_name_pairs") =
              std::vector<std::pair<std::string, std::string>>())
      .def("_jit_pass_constant_pooling", ConstantPooling)
      // RemoveInplaceOps is used by CoreML so it must be removed with care.
      .def("_jit_pass_propagate_dtype", DtypePropagation)
      .def("_jit_pass_propagate_device", DeviceTypePropagation)
      .def(
          "_jit_pass_remove_inplace_ops",
          [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
      .def(
          "_jit_pass_create_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
      .def(
          "_jit_pass_remove_mutation",
          [](std::shared_ptr<Graph>& g) {
            RemoveListMutation(g);
            return RemoveTensorMutation(g);
          })
      .def(
          "_jit_pass_functional_to_inplace_activation",
          [](std::shared_ptr<Graph>& g) {
            return FunctionalToInplaceActivation(g);
          })
      .def(
          "_jit_pass_inplace_to_functional_activation",
          [](std::shared_ptr<Graph>& g) {
            return InplaceToFunctionalActivation(g);
          })
      .def(
          "_jit_pass_inline_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
      .def(
          "_jit_pass_peephole",
          [](const std::shared_ptr<Graph>& g, bool disable_shape_peepholes) {
            return PeepholeOptimize(g, disable_shape_peepholes);
          },
          py::arg("graph"),
          py::arg("disable_shape_peepholes") = false)
      .def(
          "_jit_pass_peephole_list_idioms",
          [](const std::shared_ptr<Graph>& g, bool refine_list_len) {
            return PeepholeOptimizeListIdioms(g, refine_list_len);
          },
          py::arg("graph"),
          py::arg("refine_list_len") = false)
      .def(
          "_jit_pass_refine_integer_values",
          [](std::shared_ptr<Graph>& g) { return RefineIntegerValues(g); })
      .def(
          "_jit_pass_fuse_addmm",
          [](std::shared_ptr<Graph>& g) { return FuseAddMM(g); })
      .def(
          "_jit_pass_canonicalize",
          [](const std::shared_ptr<Graph>& g, bool keep_unique_names = true) {
            return Canonicalize(g, keep_unique_names);
          },
          py::arg("graph"),
          py::arg("keep_unique_names") = true)
      .def("_jit_pass_lint", LintGraph)
      .def(
          "_jit_pass_complete_shape_analysis",
          [](const std::shared_ptr<Graph>& graph,
             const py::tuple& inputs,
             bool with_grad) {
            ArgumentSpecCreator arg_spec_creator(*graph);
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
            arg_spec_creator.specializeTypes(*graph, spec);
            // We only get partial specialization from the arg_spec_creator, but
            // we want full shape specialization. The alternative would be to
            // have a "complete type inference" function in ArgumentSpecCreator.
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            PropagateInputShapes(graph);
          })
      .def(
          "_jit_interpret_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            Code code(graph, "<on-demand-func>");
            InterpreterState(code).run(stack);
            return createPyObjectForStack(std::move(stack));
          },
          py::doc(
              "Interpret a JIT graph with given inputs without running any optimization passes on it"))
      .def(
          "_jit_trace_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            return TraceGraph(graph, stack);
          })
      .def(
          "_jit_trace_module",
          [](Module& model, const py::tuple& inputs) {
            auto graph = model.get_method("forward").graph();
            Stack stack;
            stack.reserve(inputs.size() + 1); // captures?
            push(stack, model._ivalue());
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto traced = TraceGraph(graph, stack);
            GRAPH_DUMP("Traced Graph", traced);

            // the easiest way to replace a graph in a module is
            // to remove all the nodes in the original graph and
            // clone everything from the traced one
            graph->block()->clear();
            graph->block()->cloneFrom(traced->block(), nullptr);
            GRAPH_DUMP("Copied Graph", graph);
          })
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_inline", Inline)
      .def(
          "_jit_pass_lower_graph",
          [](std::shared_ptr<Graph>& graph, const Module& self) {
            return LowerGraph(*graph, self._ivalue());
          })
      .def("_jit_pass_loop_unrolling", UnrollLoops)
      .def("_jit_pass_constant_loop_unrolling", UnrollConstantLoops)
      .def(
          "_jit_pass_constant_propagation_immutable_types",
          [](std::shared_ptr<Graph>& g) {
            return ConstantPropagationImmutableTypes(g);
          })
      .def(
          "_jit_pass_constant_propagation",
          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
          py::arg("graph"))
      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
      .def(
          "_jit_object_is_non_holding",
          [](Node& n) {
            return toIValue(n.output())->toObject()->is_weak_compilation_ref();
          })
      .def(
          "_jit_erase_non_input_shape_information",
          [](std::shared_ptr<Graph>& g) {
            std::vector<TypePtr> input_types;
            for (Value* v : g->inputs()) {
              if (auto tt = v->type()->cast<TensorType>()) {
                input_types.emplace_back(tt);
              } else {
                input_types.emplace_back(nullptr);
              }
            }
            EraseShapeInformation(g);
            for (size_t i = 0; i < input_types.size(); ++i) {
              if (input_types[i]) {
                g->inputs().at(i)->setType(input_types[i]);
              }
            }
          })
      .def(
          "_jit_pass_create_autodiff_subgraphs",
          [](const std::shared_ptr<Graph>& graph, py::object threshold) {
            if (threshold.is_none()) {
              CreateAutodiffSubgraphs(graph);
            } else {
              CreateAutodiffSubgraphs(graph, py::cast<int>(threshold));
            }
          },
          py::arg("graph"),
          py::arg("threshold") = py::none())
#if defined(BUILDING_TESTS) && !defined(USE_ROCM)
      .def(
          "_jit_run_cpp_tests",
          []() {
            // We have to release the GIL inside this method, because if we
            // happen to initialize the autograd engine in these tests, the
            // newly spawned worker threads will try to initialize their
            // PyThreadState*, and they need the GIL for this.
            pybind11::gil_scoped_release _no_gil;
            return runJITCPPTests();
          })
      .def("_jit_has_cpp_tests", []() { return true; })
      .def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
      .def("_jit_run_cpp_tests", []() { throw std::exception(); })
      .def("_jit_has_cpp_tests", []() { return false; })
      .def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
      .def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](const autograd::variable_list& vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
      .def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
      .def("_jit_pass_decompose_ops", DecomposeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
      .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
      .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
      .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
      .def(
          "_jit_differentiate",
          [](Graph& g) {
            // the python binding slightly differs in semantics
            // it makes a copy of the input Graph, and works on that
            // jit::differentiate mutates the input Graph
            auto g_clone = g.copy();
            return differentiate(g_clone);
          })
      .def(
          "_jit_check_alias_annotation",
          [](const std::shared_ptr<Graph>& g,
             const py::tuple& args,
             const std::string& unqualified_op_name) {
            auto stack = toTraceableStack(args);
            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
          })
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
      .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
      .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#else
      .def("_jit_set_llga_enabled", [](bool flag) { return false; })
      .def("_jit_llga_enabled", []() { return false; })
#endif
      .def(
          "_jit_set_tracer_state_warn",
          [](bool new_warn) {
            jit::tracer::getTracerStateWarnMode() = new_warn;
          })
      .def(
          "_jit_get_tracer_state_warn",
          []() {
            bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
            return current_tracer_warn;
          })
      .def(
          "_jit_set_nvfuser_skip_node_kind",
          [](const std::string& op_name, bool flip = true) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_skip_node_kind is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_enabled",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_can_be_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_can_be_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_single_node_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_single_node_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_single_node_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_single_node_mode is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_horizontal_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_horizontal_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_horizontal_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_horizontal_mode is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_guard_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_guard_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_set_comparison_callback",
          [](bool, py::function) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_set_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_clear_comparison_callback",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_clear_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_set_profiling_mode",
          [](bool profiling_flag) {
            bool oldState = getProfilingMode();
            getProfilingMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_profiling_executor",
          [](bool profiling_flag) {
            bool oldState = getExecutorMode();
            getExecutorMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_num_profiled_runs",
          [](size_t num) {
            size_t old_num = getNumProfiledRuns();
            getNumProfiledRuns() = num;
            return old_num;
          })
      .def(
          "_jit_get_num_profiled_runs",
          [] {
            // pybind can't automatically bind to atomic size_t
            size_t num_runs = getNumProfiledRuns();
            return num_runs;
          })
      .def(
          "_jit_set_bailout_depth",
          [](size_t depth) {
            TORCH_WARN(
                "Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ",
                depth,
                ")");
            size_t old_depth = getBailoutDepth();
            FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
            setFusionStrategy(strat);
            return old_depth;
          })
      .def(
          "_jit_set_fusion_strategy",
          [](std::vector<std::pair<std::string, size_t>> strategy) {
            FusionStrategy vec_conv;
            for (const auto& pair : strategy) {
              if (pair.first == "STATIC") {
                vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
              } else if (pair.first == "DYNAMIC") {
                vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
              } else {
                TORCH_INTERNAL_ASSERT(
                    false,
                    "FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ",
                    pair.first);
              }
            }
            auto old_strategy = getFusionStrategy();
            auto strat =
                fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
                  return std::pair<std::string, size_t>(
                      behav.first == FusionBehavior::STATIC ? "STATIC"
                                                            : "DYNAMIC",
                      behav.second);
                });
            setFusionStrategy(vec_conv);
            return strat;
          })
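      // Illustrative Python-side call (assuming the binding is exposed as
      // torch._C._jit_set_fusion_strategy):
      //   old = torch._C._jit_set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])
      // The previously active strategy is returned in the same
      // (name, depth) pair format.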
      .def(
          "_jit_set_inline_everything_mode",
          [](bool enabled) { getInlineEverythingMode() = enabled; })
      .def(
          "_jit_get_inline_everything_mode",
          []() { return getInlineEverythingMode(); })
      .def(
          "_jit_get_logging_option",
          []() { return ::torch::jit::get_jit_logging_levels(); })
      .def(
          "_jit_set_logging_option",
          [](std::string loggingOption) -> void {
            ::torch::jit::set_jit_logging_levels(loggingOption);
          })
      .def(
          "_jit_set_logging_stream",
          [](std::string stream_name) -> void {
            if (stream_name == "stdout") {
              ::torch::jit::set_jit_logging_output_stream(std::cout);
            } else if (stream_name == "stderr") {
              ::torch::jit::set_jit_logging_output_stream(std::cerr);
            } else {
              std::cerr << "ERROR: only `stdout` and `stderr` "
                        << "are supported as output options" << std::endl;
            }
          })
      .def(
          "_storage_id",
          [](const at::Tensor& ten) -> int64_t {
            return reinterpret_cast<int64_t>(
                ten.storage().unsafeGetStorageImpl());
          })
      .def(
          "_jit_try_infer_type",
          [](py::object obj) -> InferredType {
            return tryToInferType(std::move(obj));
          })
      .def(
          "_jit_get_te_cuda_pointwise_loop_levels",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels();
          })
      .def(
          "_jit_set_te_cuda_pointwise_loop_levels",
          [](int level) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels() = level;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_count",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_count",
          [](int block_count) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount() = block_count;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_size",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_size",
          [](int block_size) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize() = block_size;
          })
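      // The six bindings above are tuning/debug knobs for TensorExpr CUDA
      // pointwise kernels; each setter writes through the corresponding global
      // accessor and (where a value is returned) reports the newly set value.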
      .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
      .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
      .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
      .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
      .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
      .def(
          "_jit_set_texpr_dynamic_shape_enabled",
          &setTensorExprDynamicShapeFusionEnabled)
      .def(
          "_jit_texpr_dynamic_shape_enabled",
          &tensorExprDynamicShapeFusionEnabled)
      .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
      .def(
          "_jit_set_te_generate_block_code",
          [](bool gen_block_code) {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode() = gen_block_code;
          })
      .def(
          "_jit_get_te_generate_block_code",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode();
          })
      .def(
          "_jit_get_te_must_use_llvm_cpu",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEMustUseLLVMOnCPU();
          })
      .def(
          "_jit_set_te_must_use_llvm_cpu",
          [](bool use_llvm) {
            using namespace torch::jit::tensorexpr;
            getTEMustUseLLVMOnCPU() = use_llvm;
          })
      .def(
          "_jit_cat_wo_conditionals",
          [](bool optimize_cat) {
            using namespace torch::jit::tensorexpr;
            getCatWoConditionals() = optimize_cat;
          })
      .def(
          "_jit_opt_conditionals",
          [](bool opt_conds) {
            using namespace torch::jit::tensorexpr;
            getOptConditionals() = opt_conds;
          })
      .def(
          "_llvm_enabled",
          []() {
#ifdef TORCH_ENABLE_LLVM
            return true;
#else
            return false;
#endif
          })
      .def(
          "_jit_pass_fuse_tensorexprs",
          [](std::shared_ptr<Graph>& g) {
            FuseTensorExprs(g);
            RemoveTensorTypeSpecializations(g);
          })
      .def(
          "_jit_fuser_get_fused_kernel_code",
          [](Graph& g, const std::vector<at::Tensor>& inps) {
            return debugGetFusedKernelCode(g, inps);
          })
      .def(
          "_jit_pass_remove_dropout",
          [](script::Module& module) { return removeDropout(module); })
      .def(
          "_jit_pass_refine_tuple_types",
          [](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](std::shared_ptr<Graph>& graph) {
            return transformConv1dToConv2d(graph);
          })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](script::Module& module) {
            return transformConv1dToConv2d(module);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return insertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](script::Module& module) { return insertPrePackedOps(module); })
      .def(
          "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
          [](script::Module& module) {
            return fusePrePackedLinearConvWithClamp(module);
          })
      .def(
          "_jit_pass_fold_prepacking_ops",
          [](script::Module& module) { return FoldPrePackingOps(module); })
      .def(
          "_jit_pass_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return optimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_hack_do_not_use_clone_module_with_class",
          [](script::Module& module,
             std::vector<std::string>& ignored_methods,
             std::vector<std::string>& ignored_attributes) {
            const bool inplace = false;
            const std::unordered_set<std::string> ignored_methods_set(
                ignored_methods.begin(), ignored_methods.end());
            const std::unordered_set<std::string> ignored_attributes_set(
                ignored_attributes.begin(), ignored_attributes.end());
            return module.clone(
                inplace, ignored_methods_set, ignored_attributes_set);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return vulkanInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](script::Module& module) {
            return vulkanInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return vulkanFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_vulkan_fold_prepacking_ops",
          [](script::Module& module) {
            return vulkanFoldPrePackingOps(module);
          })
      .def(
          "_jit_pass_vulkan_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return vulkanOptimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return metalInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](script::Module& module) {
            return metalInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return metalFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_metal_fold_prepacking_ops",
          [](script::Module& module) { return metalFoldPrePackingOps(module); })
      .def(
          "_jit_pass_metal_optimize_for_mobile",
          [](script::Module& module,
             std::vector<std::string>& preserved_methods) {
            return metalOptimizeForMobile(module, preserved_methods);
          })
      .def(
          "_jit_pass_filter_non_tensor_arguments",
          [](std::map<std::string, IValue> params) {
            std::map<std::string, at::Tensor> retval;
            for (auto& kv : params) {
              if (kv.second.isTensor()) {
                retval[kv.first] = std::move(kv.second).toTensor();
              }
            }
            return retval;
          })
      .def("_jit_pass_batch_mm", BatchMM)
      .def(
          "_jit_decay_packed_param_input_types",
          [](Graph& g) {
            for (Value* i : g.inputs()) {
              if (i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
                // Dummy CompleteTensorType to appease ONNX validator.
                i->setType(TensorType::create(
                    at::kQInt8,
                    c10::kCPU,
                    std::vector<int64_t>{1},
                    std::vector<int64_t>{1},
                    c10::nullopt));
              }
            }
          })
      .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);
1173

1174
  // NB: This isn't actually used for regular PyTorch symbolic tracing;
1175
  // XLA is what needs this
1176
#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
1177
#define SYMNODE_BINARY(n) \
1178
  .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
1179
#define SYMNODE_SIZES_STRIDES(n)                \
1180
  .def(                                         \
1181
      #n,                                       \
1182
      [](c10::SymNode a,                        \
1183
         c10::ArrayRef<c10::SymNode> sizes,     \
1184
         c10::ArrayRef<c10::SymNode> strides) { \
1185
        return a->n(sizes, strides);            \
1186
      })
1187
  auto symnode_class =
1188
      py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
1189
      // clang-format off
1190
      // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
1191
      // Python is responsible for this
1192
      SYMNODE_UNARY(clone)
1193
      SYMNODE_UNARY(is_int)
1194
      SYMNODE_UNARY(is_float)
1195
      SYMNODE_UNARY(is_bool)
1196
      SYMNODE_UNARY(bool_)
1197
      SYMNODE_UNARY(int_)
1198
      SYMNODE_UNARY(sym_float)
1199
      SYMNODE_BINARY(add)
1200
      SYMNODE_BINARY(sub)
1201
      SYMNODE_BINARY(mul)
1202
      SYMNODE_BINARY(truediv)
1203
      SYMNODE_BINARY(pow)
1204
      SYMNODE_BINARY(floordiv)
1205
      SYMNODE_BINARY(mod)
1206
      SYMNODE_BINARY(eq)
1207
      SYMNODE_BINARY(ne)
1208
      SYMNODE_BINARY(gt)
1209
      SYMNODE_BINARY(lt)
1210
      SYMNODE_BINARY(le)
1211
      SYMNODE_BINARY(ge)
1212
      SYMNODE_BINARY(sym_min)
1213
      SYMNODE_BINARY(sym_max)
1214
      SYMNODE_BINARY(sym_and)
1215
      SYMNODE_BINARY(sym_or)
1216
      SYMNODE_UNARY(sym_not)
1217
      SYMNODE_UNARY(ceil)
1218
      SYMNODE_UNARY(floor)
1219
      SYMNODE_UNARY(neg)
1220
      SYMNODE_SIZES_STRIDES(is_contiguous)
1221
      SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
1222
      SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
1223
      SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
1224
      SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
1225
      SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
1226
      .def(
1227
          "guard_int",
1228
          [](c10::SymNode a, const char* file, int64_t line) {
1229
            return a->guard_int(file, line);
1230
          })
1231
      .def(
1232
          "guard_bool",
1233
          [](c10::SymNode a, const char* file, int64_t line) {
1234
            return a->guard_bool(file, line);
1235
          })
1236
      .def(
1237
          "guard_float",
1238
          [](c10::SymNode a, const char* file, int64_t line) {
1239
            return a->guard_float(file, line);
1240
          })
1241
      .def(
1242
          "expect_true",
1243
          [](c10::SymNode a, const char* file, int64_t line) {
1244
            return a->expect_true(file, line);
1245
          })
1246
      .def(
1247
          "expect_size",
1248
          [](c10::SymNode a, const char* file, int64_t line) {
1249
            return a->expect_size(file, line);
1250
          })
1251
      .def(
1252
          "guard_size_oblivious",
1253
          [](c10::SymNode a, const char* file, int64_t line) {
1254
            return a->guard_size_oblivious(file, line);
1255
          })
1256
      .def(
1257
          "has_hint",
1258
          [](c10::SymNode a) {
1259
            return a->has_hint();
1260
          })
1261
      .def(
1262
          "wrap_int",
1263
          [](c10::SymNode a, int64_t b) {
1264
            return a->wrap_int(b);
1265
          })
1266
      .def(
1267
          "wrap_float",
1268
          [](c10::SymNode a, double b) {
1269
            return a->wrap_float(b);
1270
          })
1271
      .def(
1272
          "wrap_bool",
1273
          [](c10::SymNode a, bool b) {
1274
            return a->wrap_bool(b);
1275
          })
1276
      .def(
1277
          "__str__",
1278
          [](c10::SymNode a) { return a->str(); })
1279
      .def(
1280
          "__repr__",
1281
          [](c10::SymNode a) { return a->str(); })
1282
      .def(
1283
          "is_constant",
1284
          [](const c10::SymNode& node){
1285
            return node->is_constant();
1286
          })
1287
      .def(
1288
          "is_nested_int",
1289
          [](const c10::SymNode& node) {
1290
            return node->is_nested_int();
1291
          })
1292
      .def(
1293
          "is_symbolic",
1294
          [](const c10::SymNode& node) {
1295
            return node->is_symbolic();
1296
          })
1297
      .def(
1298
          "nested_int",
1299
          [](const c10::SymNode& node) {
1300
            return node->nested_int();
1301
          })
1302
      .def(
1303
          "nested_int_coeff",
1304
          [](const c10::SymNode& node) {
1305
            return node->nested_int_coeff();
1306
          });
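  // The bindings above simply forward to the corresponding c10::SymNodeImpl
  // virtual methods. The guard_*/expect_* entry points additionally take the
  // (file, line) of the Python call site, presumably so the symbolic-shapes
  // machinery can report where a particular guard was introduced.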
1307

1308
  // clang-format on
1309

1310
  // NOLINTNEXTLINE(bugprone-unused-raii)
1311
  py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
1312
      .def("__repr__", [](CompleteArgumentSpec& self) {
1313
        std::ostringstream s;
1314
        s << self;
1315
        return s.str();
1316
      });
1317
  // NOLINTNEXTLINE(bugprone-unused-raii)
1318
  py::class_<ArgumentSpec>(m, "ArgumentSpec");
1319
  py::class_<Code>(m, "Code")
1320
      .def(
1321
          "grad_executor_states",
1322
          [](Code& c) {
1323
            std::vector<GraphExecutorState> states;
1324
            for (auto& e : c.grad_executors()) {
1325
              states.emplace_back(e->getDebugState());
1326
            }
1327
            return states;
1328
          })
1329
      .def(
1330
          "differentiable_op_executor_states",
1331
          [](Code& c) {
1332
            std::vector<GraphExecutorState> states;
1333
            for (auto& e : c.diff_graph_op_executors()) {
1334
              if (e->isOptimized()) {
1335
                states.emplace_back(e->getDebugState());
1336
              } else {
1337
                // we leave an empty entry for a node that doesn't have an
1338
                // optimized plan
1339
                states.emplace_back();
1340
              }
1341
            }
1342
            return states;
1343
          })
1344
      .def("num_bailouts", [](Code& c) { return c.num_bailouts(); })
1345
      .def("request_bailout", [](Code& c, size_t index) {
1346
        c.request_bailout(index);
1347
      });
1348

1349
  py::class_<ExecutionPlan>(m, "ExecutionPlan")
1350
      .def_property_readonly("graph", [](ExecutionPlan& s) { return s.graph; })
1351
      .def_property_readonly("code", [](ExecutionPlan& s) { return s.code; });
1352

1353
  py::class_<Gradient>(m, "Gradient")
1354
      .def_property_readonly("f", [](Gradient& m) { return m.f; })
1355
      .def_property_readonly("df", [](Gradient& m) { return m.df; })
1356
      .def_property_readonly(
1357
          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
1358
      .def_property_readonly(
1359
          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
1360
      .def_property_readonly(
1361
          "df_input_captured_inputs",
1362
          [](Gradient& m) { return m.df_input_captured_inputs; })
1363
      .def_property_readonly(
1364
          "df_input_captured_outputs",
1365
          [](Gradient& m) { return m.df_input_captured_outputs; })
1366
      .def_property_readonly(
1367
          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
1368

1369
  py::class_<GraphExecutorState>(m, "GraphExecutorState")
1370
      .def_property_readonly(
1371
          "graph", [](GraphExecutorState& s) { return s.graph; })
1372
      .def_property_readonly(
1373
          "execution_plans",
1374
          [](GraphExecutorState& s) { return s.execution_plans; })
1375
      .def_property_readonly(
1376
          "fallback", [](GraphExecutorState& s) { return s.fallback; });
1377

1378
  py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
1379
      .def(py::init<std::string>())
1380
      .def(py::init([](const py::object& buffer) {
1381
        auto writer_func = [=](const void* data, size_t size) {
1382
          // Writing an empty file is a noop
1383
          if (size == 0) {
1384
            return size;
1385
          }
1386
          py::gil_scoped_acquire acquire;
1387
          auto memory_view = py::memoryview::from_memory(
1388
              reinterpret_cast<const char*>(data), size);
1389
          buffer.attr("write")(std::move(memory_view));
1390
          return size;
1391
        };
1392
        return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
1393
      }))
1394
      .def(py::init<const std::function<size_t(const void*, size_t)>&>())
1395
      .def(
1396
          "write_record",
1397
          [](PyTorchStreamWriter& self,
1398
             const std::string& name,
1399
             const char* data,
1400
             size_t size) {
1401
            // Since we don't know where the data comes from, we cannot
1402
            // release the GIL in this overload
1403
            return self.writeRecord(name, data, size);
1404
          })
1405
      .def(
1406
          "write_record",
1407
          [](PyTorchStreamWriter& self,
1408
             const std::string& name,
1409
             py::bytes data,
1410
             size_t size) {
1411
            // It is not clear from the docs, but according to CPython's own code,
1412
            // it is ok to use the result of PyBytes_AsString without the GIL
1413
            // being held
1414
            // https://github.com/python/cpython/blob/e2a3e4b7488aff6fdc704a0f258bc315e96c1d6e/Objects/stringlib/join.h#L67
1415
            const char* data_str = PyBytes_AsString(data.ptr());
1416
            py::gil_scoped_release release;
1417
            return self.writeRecord(name, data_str, size);
1418
          })
1419
      .def(
1420
          "write_record",
1421
          [](PyTorchStreamWriter& self,
1422
             const std::string& name,
1423
             c10::Storage data,
1424
             size_t size) {
1425
            // Reading Tensor data is always ok without the GIL held
1426
            py::gil_scoped_release release;
1427
            return self.writeRecord(
1428
                name, reinterpret_cast<const char*>(data.data()), size);
1429
          })
1430
      .def(
1431
          "write_record",
1432
          [](PyTorchStreamWriter& self,
1433
             const std::string& name,
1434
             uintptr_t data,
1435
             size_t size) {
1436
            TORCH_WARN_ONCE(
1437
                "write_record(): Passing Storage by data pointer is deprecated and will be an error in ",
1438
                "the future, please pass the Storage object instead.");
1439
            return self.writeRecord(
1440
                name, reinterpret_cast<const char*>(data), size);
1441
          })
1442
      .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
1443
      .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
1444
      .def("archive_name", &PyTorchStreamWriter::archiveName)
1445
      .def("serialization_id", &PyTorchStreamWriter::serializationId)
1446
      .def(
1447
          "get_all_written_records",
1448
          &PyTorchStreamWriter::getAllWrittenRecords);
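  // Illustrative Python-side use of the writer binding (a sketch, assuming it
  // is exposed as torch._C.PyTorchFileWriter and that an io.BytesIO buffer is
  // used):
  //   buf = io.BytesIO()
  //   writer = torch._C.PyTorchFileWriter(buf)
  //   writer.write_record("archive/data.pkl", b"payload", 7)
  //   writer.write_end_of_file()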
1449

1450
  py::enum_<MobileOptimizerType>(m, "_MobileOptimizerType")
1451
      .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
1452
      .value(
1453
          "INSERT_FOLD_PREPACK_OPS",
1454
          MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
1455
      .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
1456
      .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
1457
      .value(
1458
          "HOIST_CONV_PACKED_PARAMS",
1459
          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
1460
      .value(
1461
          "VULKAN_AUTOMATIC_GPU_TRANSFER",
1462
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER);
1463

1464
  // This allows PyTorchStreamReader to read from a Python buffer. It requires
1465
  // that the buffer implement `seek()`, `tell()`, and `read()`.
1466
  class BufferAdapter : public caffe2::serialize::ReadAdapterInterface {
1467
   public:
1468
    BufferAdapter(const py::object& buffer) : buffer_(buffer) {
1469
      // Jump to the end of the buffer to get its size
1470
      auto current = buffer.attr("tell")();
1471
      start_offset_ = py::cast<size_t>(current);
1472
      buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
1473
      size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
1474
      buffer.attr("seek")(current);
1475

1476
      // If we can read directly into a buffer, do that instead of an extra copy
1477
      use_readinto_ = py::hasattr(buffer, "readinto");
1478
    }
1479

1480
    size_t size() const override {
1481
      return size_;
1482
    }
1483

1484
    THPObjectPtr getMemview(void* buf, size_t n) const {
1485
      THPObjectPtr memview(PyMemoryView_FromMemory(
1486
          reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
1487
      if (!memview) {
1488
        throw python_error();
1489
      }
1490
      return memview;
1491
    }
1492

1493
    size_t read(uint64_t pos, void* buf, size_t n, const char* what)
1494
        const override {
1495
      // Seek to desired position (NB: this has to be a Py_ssize_t or Python
1496
      // throws a weird error)
1497
      Py_ssize_t absolute_pos = start_offset_ + pos;
1498
      buffer_.attr("seek")(absolute_pos);
1499

1500
      if (use_readinto_) {
1501
        auto memview = getMemview(buf, n);
1502
        auto res =
1503
            PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
1504
        if (res) {
1505
          int64_t i = static_cast<int64_t>(PyLong_AsLongLong(res));
1506
          Py_DECREF(res);
1507
          if (i > 0) {
1508
            return i;
1509
          }
1510
        }
1511
      }
1512

1513
      // Read bytes into `buf` from the buffer
1514
      std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
1515
      std::copy(
1516
          bytes.data(),
1517
          bytes.data() + bytes.size(),
1518
          reinterpret_cast<char*>(buf));
1519
      return bytes.size();
1520
    }
1521

1522
    py::object buffer_;
1523
    size_t size_;
1524
    size_t start_offset_;
1525
    bool use_readinto_;
1526
  };
1527

1528
  py::class_<PyTorchStreamReader, std::shared_ptr<PyTorchStreamReader>>(
1529
      m, "PyTorchFileReader")
1530
      .def(py::init<std::string>())
1531
      .def(py::init([](const py::object& buffer) {
1532
        auto adapter = std::make_unique<BufferAdapter>(buffer);
1533
        return std::make_shared<PyTorchStreamReader>(std::move(adapter));
1534
      }))
1535
      .def(
1536
          "get_record",
1537
          [](PyTorchStreamReader& self, const std::string& key) {
1538
            auto [data, size] = self.getRecord(key);
1539
            return py::bytes(reinterpret_cast<const char*>(data.get()), size);
1540
          })
1541
      .def(
1542
          "has_record",
1543
          [](PyTorchStreamReader& self, const std::string& key) {
1544
            return self.hasRecord(key);
1545
          })
1546
      .def(
1547
          "get_storage_from_record",
1548
          [](PyTorchStreamReader& self,
1549
             const std::string& key,
1550
             size_t numel,
1551
             py::object data_type_obj) {
1552
            at::DataPtr data(std::get<0>(self.getRecord(key)));
1553
            auto scalar_type =
1554
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1555

1556
            c10::Storage storage(
1557
                c10::Storage::use_byte_size_t(),
1558
                numel * elementSize(scalar_type),
1559
                std::move(data),
1560
                /*allocator=*/nullptr,
1561
                /*resizable=*/false);
1562
            auto ptr =
1563
                c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1564
                    std::move(storage),
1565
                    at::DispatchKeySet(),
1566
                    at::CPU(scalar_type).typeMeta());
1567
            return at::Tensor(std::move(ptr));
1568
          })
1569
      .def("serialization_id", &PyTorchStreamReader::serializationId)
1570
      .def(
1571
          "get_all_records",
1572
          [](PyTorchStreamReader& self) { return self.getAllRecords(); })
1573
      .def(
1574
          "get_record_offset",
1575
          [](PyTorchStreamReader& self, const std::string& key) {
1576
            return self.getRecordOffset(key);
1577
          });
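  // Illustrative Python-side use of the reader binding (a sketch, assuming it
  // is exposed as torch._C.PyTorchFileReader; the buffer only needs
  // seek()/tell()/read() as noted above):
  //   reader = torch._C.PyTorchFileReader(buf)
  //   if reader.has_record("archive/data.pkl"):
  //       payload = reader.get_record("archive/data.pkl")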
1578

1579
  // Used by torch.Package to coordinate deserialization of storages across
1580
  // ScriptModules and eager modules
1581
  py::class_<
1582
      DeserializationStorageContext,
1583
      std::shared_ptr<DeserializationStorageContext>>(
1584
      m, "DeserializationStorageContext")
1585
      .def(py::init<>())
1586
      .def(
1587
          "get_storage",
1588
          [](DeserializationStorageContext& self,
1589
             const std::string& name,
1590
             py::object data_type_obj) {
1591
            c10::Storage storage = self.getStorage(name);
1592
            auto scalar_type =
1593
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1594
            auto ptr =
1595
                c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1596
                    std::move(storage),
1597
                    at::DispatchKeySet(),
1598
                    at::CPU(scalar_type).typeMeta());
1599

1600
            return at::Tensor(std::move(ptr));
1601
          })
1602
      .def(
1603
          "add_storage",
1604
          [](DeserializationStorageContext& self,
1605
             const std::string& name,
1606
             const at::Tensor& tensor) {
1607
            return self.addStorage(name, tensor.storage());
1608
          })
1609
      .def("has_storage", &DeserializationStorageContext::hasStorage);
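  // In short: add_storage(name, tensor) registers tensor.storage() under a
  // string key, and get_storage(name, dtype) later rewraps that same storage
  // as a CPU tensor of the requested dtype, so a storage shared between a
  // ScriptModule and eager tensors only needs to be deserialized once.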
1610

1611
  m.def(
1612
      "_get_schema",
1613
      [](const std::string& op_name, const std::string& overload_name) {
1614
        try {
1615
          auto symbol = Symbol::fromQualString(op_name);
1616
          auto operations = getAllOperatorsFor(symbol);
1617
          for (const auto& op : operations) {
1618
            if (op->schema().overload_name() == overload_name) {
1619
              return op->schema();
1620
            }
1621
          }
1622
          throw std::runtime_error("Found no matching schema");
1623
        } catch (const c10::Error& e) {
1624
          auto msg = torch::get_cpp_stacktraces_enabled()
1625
              ? e.what()
1626
              : e.what_without_backtrace();
1627
          throw std::runtime_error(msg);
1628
        }
1629
      });
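  // Illustrative usage (a sketch, assuming the binding is reachable as
  // torch._C._get_schema):
  //   schema = torch._C._get_schema("aten::add", "Tensor")
  //   # e.g. "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"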
1630

1631
  m.def(
1632
      "_get_operation_overload",
1633
      [](const std::string& op_name, const std::string& overload_name) {
1634
        try {
1635
          auto symbol = Symbol::fromQualString(op_name);
1636
          auto operations = getAllOperatorsFor(symbol);
1637
          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1638
          for (const auto& op : operations) {
1639
            if (op->schema().overload_name() == overload_name) {
1640
              auto func =
1641
                  py::cpp_function([op, symbol, allow_numbers_as_tensors](
1642
                                       py::args args, py::kwargs kwargs) {
1643
                    ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1644
                    return _get_operation_for_overload_or_packet(
1645
                        {op}, symbol, args, kwargs, /*is_overload*/ true);
1646
                  });
1647
              auto func_dk = py::cpp_function(
1648
                  [op, symbol, allow_numbers_as_tensors](
1649
                      c10::DispatchKey dk_, py::args args, py::kwargs kwargs) {
1650
                    c10::optional<c10::DispatchKey> dk =
1651
                        c10::make_optional(dk_);
1652
                    ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1653
                    return _get_operation_for_overload_or_packet(
1654
                        {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
1655
                  });
1656
              return py::make_tuple(
1657
                  func, func_dk, py::cast(op->getTags().vec()));
1658
            }
1659
          }
1660
          throw std::runtime_error("Found no matching operator overload");
1661
        } catch (const c10::Error& e) {
1662
          auto msg = torch::get_cpp_stacktraces_enabled()
1663
              ? e.what()
1664
              : e.what_without_backtrace();
1665
          throw std::runtime_error(msg);
1666
        }
1667
      });
1668

1669
  m.def(
1670
      "_jit_resolve_packet",
1671
      [](const char* op_name, py::args args, py::kwargs kwargs) {
1672
        try {
1673
          auto symbol = Symbol::fromQualString(op_name);
1674
          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1675
          ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1676
          const auto overloads = getAllSortedOperatorsFor(symbol);
1677
          auto opWithStack = getOpWithStack(overloads, args, kwargs);
1678
          std::shared_ptr<Operator> overload = std::get<0>(opWithStack);
1679
          auto result = overload->schema().overload_name();
1680
          if (result == "") {
1681
            result = "default";
1682
          }
1683
          return result;
1684
        } catch (const c10::Error& e) {
1685
          auto msg = torch::get_cpp_stacktraces_enabled()
1686
              ? e.what()
1687
              : e.what_without_backtrace();
1688
          throw std::runtime_error(msg);
1689
        }
1690
      });
1691

1692
  m.def(
1693
      "_jit_get_operation",
1694
      [](const std::string& op_name) {
1695
        try {
1696
          auto symbol = Symbol::fromQualString(op_name);
1697
          const auto sortedOps = getAllSortedOperatorsFor(symbol);
1698
          if (sortedOps.empty()) {
1699
            // No such operator
1700
            return py::make_tuple(py::none(), py::none());
1701
          }
1702

1703
          std::ostringstream docstring;
1704
          docstring << "Automatically bound operator '" << op_name
1705
                    << "' with schema(s):\n";
1706

1707
          for (const auto& op : sortedOps) {
1708
            docstring << "  " << op->schema() << "\n";
1709
          }
1710

1711
          py::list overload_names;
1712
          for (const auto& op : sortedOps) {
1713
            overload_names.append(py::str(op->schema().overload_name()));
1714
          }
1715

1716
          bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1717

1718
          auto func = py::cpp_function(
1719
              [sortedOps, symbol, allow_numbers_as_tensors](
1720
                  py::args args, py::kwargs kwargs) {
1721
                ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1722
                return _get_operation_for_overload_or_packet(
1723
                    sortedOps, symbol, args, kwargs, false);
1724
              },
1725
              py::name(symbol.toUnqualString()),
1726
              py::doc(docstring.str().c_str()));
1727
          return py::make_tuple(func, overload_names);
1728
        } catch (const c10::Error& e) {
1729
          auto msg = torch::get_cpp_stacktraces_enabled()
1730
              ? e.what()
1731
              : e.what_without_backtrace();
1732
          throw std::runtime_error(msg);
1733
        }
1734
      },
1735
      py::arg("qualified_name"));
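  // _jit_get_operation returns a (callable, overload_names) pair; the callable
  // picks among all overloads of the qualified name based on the Python
  // arguments it receives. Sketch (assuming torch._C exposure):
  //   op, overload_names = torch._C._jit_get_operation("aten::relu")
  //   y = op(torch.randn(3))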
1736

1737
  m.def(
1738
      "parse_ir",
1739
      [](const std::string& input, bool parse_tensor_constants) {
1740
        auto graph = std::make_shared<Graph>();
1741
        parseIR(input, &*graph, parse_tensor_constants);
1742
        return graph;
1743
      },
1744
      py::arg("input"),
1745
      py::arg("parse_tensor_constants") = false);
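  // Illustrative usage of parse_ir (a sketch, assuming torch._C.parse_ir):
  //   g = torch._C.parse_ir("graph(%x : Tensor):\n  return (%x)")
  //   print(g)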
1746
  m.def("parse_schema", parseSchema);
1747
  m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
1748
    std::ostringstream s;
1749
    auto type = unifyTypeList(types, s);
1750
    if (!type) {
1751
      throw std::runtime_error(s.str());
1752
    }
1753
    return type.value();
1754
  });
1755
  py::enum_<SchemaArgType>(m, "_SchemaArgType")
1756
      .value("input", SchemaArgType::input)
1757
      .value("output", SchemaArgType::output);
1758
  py::class_<SchemaArgument>(m, "_SchemaArgument")
1759
      .def(py::init<SchemaArgType, size_t>())
1760
      .def_readwrite("type", &SchemaArgument::type)
1761
      .def_readwrite("index", &SchemaArgument::index);
1762
  py::class_<SchemaInfo>(m, "_SchemaInfo")
1763
      .def(py::init<FunctionSchema>())
1764
      .def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); })
1765
      .def(
1766
          "is_mutable",
1767
          [](SchemaInfo& self, const SchemaArgument& argument) {
1768
            return self.is_mutable(argument);
1769
          })
1770
      .def(
1771
          "has_argument",
1772
          [](SchemaInfo& self, const std::string& name) {
1773
            return self.has_argument(name);
1774
          })
1775
      .def(
1776
          "is_mutable",
1777
          [](SchemaInfo& self, const std::string& name) {
1778
            return self.is_mutable(name);
1779
          })
1780
      .def(
1781
          "may_alias",
1782
          [](SchemaInfo& self,
1783
             const SchemaArgument& lhs,
1784
             const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); })
1785
      .def(
1786
          "may_contain_alias",
1787
          [](SchemaInfo& self,
1788
             const SchemaArgument& lhs,
1789
             const SchemaArgument& rhs) {
1790
            return self.may_contain_alias(lhs, rhs);
1791
          })
1792
      .def(
1793
          "add_argument_value",
1794
          [](SchemaInfo& self,
1795
             const std::string& name,
1796
             const py::object& value) {
1797
            c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
1798
            if (i_value) {
1799
              // For normalization purposes there is an inconsistency within
1800
              // torch.fx that turns all arguments named "self" into "input".
1801
              // Thus this check ensures that those arguments are checked
1802
              // correctly.
1803
              if (name == "input" && !self.hasInputArgumentNamed("input")) {
1804
                self.addArgumentValue("self", *i_value);
1805
              } else {
1806
                self.addArgumentValue(name, *i_value);
1807
              }
1808
            }
1809
          })
1810
      .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
1811
        std::unordered_map<std::string, IValue> value_map;
1812
        for (const auto& key_pair : values) {
1813
          IValue key = toTypeInferredIValue(key_pair.first);
1814
          TORCH_INTERNAL_ASSERT(
1815
              key.isString(),
1816
              "Add argument value keys types should be strings.");
1817
          c10::optional<IValue> value =
1818
              toTypeInferredIValueOptional(key_pair.second);
1819
          if (value) {
1820
            // For normalization purposes there is an inconsistency within
1821
            // torch.fx that
1822
            // turns all arguments named "self" into "input". Thus this check
1823
            // ensures that those arguments are checked correctly.
1824
            if (key.toStringRef() == "input" &&
1825
                !self.hasInputArgumentNamed("input")) {
1826
              self.addArgumentValue("self", *value);
1827
            } else {
1828
              value_map[key.toStringRef()] = *value;
1829
            }
1830
          }
1831
        }
1832
        self.addArgumentValues(value_map);
1833
      });
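  // _SchemaInfo layers alias and mutation analysis on top of a raw
  // FunctionSchema; registering concrete argument values through
  // add_argument_value(s) lets the is_mutable/may_alias queries take runtime
  // values into account.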
1834
  py::class_<FunctionSchema>(m, "FunctionSchema")
1835
      .def_property_readonly(
1836
          "name", [](FunctionSchema& self) { return self.name(); })
1837
      .def_property_readonly(
1838
          "overload_name",
1839
          [](FunctionSchema& self) { return self.overload_name(); })
1840
      .def_property_readonly(
1841
          "arguments", [](FunctionSchema& self) { return self.arguments(); })
1842
      .def_property_readonly(
1843
          "returns", [](FunctionSchema& self) { return self.returns(); })
1844
      .def(
1845
          "is_backward_compatible_with",
1846
          [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1847
            return self.isBackwardCompatibleWith(old_schema);
1848
          })
1849
      .def(
1850
          "check_forward_compatible_with",
1851
          [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1852
            std::ostringstream out;
1853
            auto result = self.isForwardCompatibleWith(old_schema, out);
1854
            return std::make_pair(result, out.str());
1855
          })
1856
      .def(
1857
          "__eq__",
1858
          [](const FunctionSchema& self, const FunctionSchema& other) {
1859
            return self == other;
1860
          })
1861
      .def(
1862
          "__hash__",
1863
          [](const FunctionSchema& self) {
1864
            return std::hash<FunctionSchema>{}(self);
1865
          })
1866
      .def(
1867
          "__str__",
1868
          [](FunctionSchema& self) {
1869
            std::stringstream ss;
1870
            ss << self;
1871
            return ss.str();
1872
          })
1873
      .def_property_readonly(
1874
          "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); });
1875
  py::class_<Argument>(m, "Argument")
1876
      .def_property_readonly("name", [](Argument& self) { return self.name(); })
1877
      .def_property_readonly("type", [](Argument& self) { return self.type(); })
1878
      .def_property_readonly(
1879
          "real_type", [](Argument& self) { return self.real_type(); })
1880
      .def_property_readonly(
1881
          "N",
1882
          [](Argument& self) -> py::object {
1883
            return (self.N()) ? py::cast(*self.N()) : py::none();
1884
          })
1885
      .def_property_readonly(
1886
          "default_value",
1887
          [](Argument& self) -> py::object {
1888
            if (!self.default_value()) {
1889
              return py::none();
1890
            }
1891
            IValue v = *self.default_value();
1892
            return toPyObject(std::move(v));
1893
          })
1894
      .def(
1895
          "has_default_value",
1896
          [](Argument& self) -> py::bool_ {
1897
            return self.default_value().has_value();
1898
          })
1899
      .def_property_readonly(
1900
          "alias_info", [](Argument& self) { return self.alias_info(); })
1901
      .def_property_readonly(
1902
          "is_out", [](Argument& self) { return self.is_out(); })
1903
      .def_property_readonly("kwarg_only", [](Argument& self) -> bool {
1904
        return self.kwarg_only();
1905
      });
1906
  py::class_<AliasInfo>(m, "_AliasInfo")
1907
      .def_property_readonly(
1908
          "is_write", [](AliasInfo& self) { return self.isWrite(); })
1909
      .def_property_readonly(
1910
          "before_set",
1911
          [](AliasInfo& self) {
1912
            std::set<py::str> before_set_python;
1913
            for (const auto& set : self.beforeSets()) {
1914
              before_set_python.insert(py::str(set.toUnqualString()));
1915
            }
1916
            return before_set_python;
1917
          })
1918
      .def_property_readonly("after_set", [](AliasInfo& self) {
1919
        std::set<py::str> after_set_python;
1920
        for (const auto& set : self.afterSets()) {
1921
          after_set_python.insert(py::str(set.toUnqualString()));
1922
        }
1923
        return after_set_python;
1924
      });
1925
  m.def("_jit_get_all_schemas", []() {
1926
    const std::vector<std::shared_ptr<Operator>>& operations =
1927
        getAllOperators();
1928
    return fmap(operations, [](const std::shared_ptr<Operator>& op) {
1929
      return op->schema();
1930
    });
1931
  });
1932
  m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck);
1933
  m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
1934
    auto symbol = Symbol::fromQualString(qualified_name);
1935
    const auto& operations = getAllOperatorsFor(symbol);
1936
    return fmap(operations, [](const std::shared_ptr<Operator>& op) {
1937
      return op->schema();
1938
    });
1939
  });
1940
  m.def("_is_tracing", []() { return jit::tracer::isTracing(); });
1941

1942
  py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
1943
      m, "Future")
1944
      .def(py::init([](std::vector<c10::Device> devices = {}) {
1945
        return std::make_shared<PythonFutureWrapper>(
1946
            c10::make_intrusive<c10::ivalue::Future>(
1947
                PyObjectType::get(), std::move(devices)));
1948
      }))
1949
      .def(
1950
          "done",
1951
          // Intentionally not releasing GIL
1952
          &PythonFutureWrapper::done)
1953
      .def(
1954
          "value",
1955
          &PythonFutureWrapper::value,
1956
          py::call_guard<py::gil_scoped_release>())
1957
      .def(
1958
          "wait",
1959
          &PythonFutureWrapper::wait,
1960
          py::call_guard<py::gil_scoped_release>())
1961
      .def(
1962
          "then",
1963
          &PythonFutureWrapper::then,
1964
          py::call_guard<py::gil_scoped_release>())
1965
      .def(
1966
          "add_done_callback",
1967
          &PythonFutureWrapper::add_done_callback,
1968
          py::call_guard<py::gil_scoped_release>())
1969
      .def(
1970
          "set_result",
1971
          // Intentionally not releasing GIL
1972
          &PythonFutureWrapper::markCompleted)
1973
      .def(
1974
          "_set_unwrap_func",
1975
          // Intentionally not releasing GIL as this just does an assign
1976
          [](PythonFutureWrapper& self, py::function unwrapFunc) {
1977
            auto functionGuard =
1978
                std::make_shared<torch::jit::PythonFunctionGuard>(
1979
                    std::move(unwrapFunc));
1980

1981
            std::function<void(py::object)> pf =
1982
                [functionGuard(std::move(functionGuard))](
1983
                    const py::object& inp) {
1984
                  return functionGuard->func_(inp);
1985
                };
1986
            self.unwrap_func = std::move(pf);
1987
          })
1988
      .def(
1989
          py::pickle(
1990
              /* __getstate__ */
1991
              [](const PythonFutureWrapper& /* unused */) {
1992
                TORCH_CHECK(false, "Can not pickle torch.futures.Future");
1993
                // Note that this return has no meaning since we always
1994
                // throw, it's only here to satisfy Pybind API's
1995
                // requirement.
1996
                return py::make_tuple();
1997
              },
1998
              /* __setstate__ */
1999
              [](const py::tuple& /* unused */) { // NOLINT
2000
                TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
2001
                // Note that this return has no meaning since we always
2002
                // throw, it's only here to satisfy PyBind's API
2003
                // requirement.
2004
                return nullptr;
2005
              }),
2006
          py::call_guard<py::gil_scoped_release>());
2007

2008
  py::class_<PythonAwaitWrapper, std::shared_ptr<PythonAwaitWrapper>>(
2009
      m, "_Await")
2010
      .def(
2011
          "wait",
2012
          &PythonAwaitWrapper::wait,
2013
          py::call_guard<py::gil_scoped_release>())
2014
      .def("fn", &PythonAwaitWrapper::fn)
2015
      .def("args", &PythonAwaitWrapper::args)
2016
      .def("type", &PythonAwaitWrapper::type)
2017
      .def("is_nowait", &PythonAwaitWrapper::is_nowait)
2018
      .def(
2019
          "__getattr__",
2020
          [](PythonAwaitWrapper& self, const std::string& name) -> py::object {
2021
            // In eager mode allow Await[W] to be used as W, redirecting getattr
2022
            // to the result of the delayed function.
2023
            return py::getattr(self.wait(), name.c_str(), py::none());
2024
          })
2025
      .def(
2026
          py::pickle(
2027
              /* __getstate__ */
2028
              [](const PythonAwaitWrapper& /* unused */) {
2029
                TORCH_CHECK(false, "Can not pickle torch.jit._Await");
2030
                // Note that this return has no meaning since we always
2031
                // throw, it's only here to satisfy Pybind API's
2032
                // requirement.
2033
                return py::make_tuple();
2034
              },
2035
              /* __setstate__ */
2036
              [](const py::tuple& /* unused */) { // NOLINT
2037
                TORCH_CHECK(false, "Can not unpickle torch.jit._Await");
2038
                // Note that this return has no meaning since we always
2039
                // throw, it's only here to satisfy PyBind's API
2040
                // requirement.
2041
                return nullptr;
2042
              }),
2043
          py::call_guard<py::gil_scoped_release>());
2044

2045
  m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
2046
    c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2047
    c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2048

2049
    // Only return true if we are certain that self and other are aliasing.
2050
    if (!self_value || !other_value) {
2051
      return false;
2052
    }
2053
    return self_value->isAliasOf(*other_value);
2054
  });
2055
  m.def("_overlaps", [](const py::object& self, const py::object& other) {
2056
    c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2057
    c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2058

2059
    // Only return true if we are certain that self and other are overlapping.
2060
    if (!self_value || !other_value) {
2061
      return false;
2062
    }
2063
    return self_value->overlaps(*other_value);
2064
  });
2065
  m.def("_awaitable", [](const py::args& args, const py::kwargs& kwargs) {
2066
    AT_ASSERT(args.size() >= 1);
2067
    py::tuple args_tup(args.size() - 1);
2068
    for (const auto i : c10::irange(1, args.size())) {
2069
      args_tup[i - 1] = args[i];
2070
    }
2071
    return std::make_shared<PythonAwaitWrapper>(
2072
        py::cast<py::function>(args[0]), std::move(args_tup));
2073
  });
2074
  m.def("_awaitable_nowait", [](py::handle input) {
2075
    return std::make_shared<PythonAwaitWrapper>(std::move(input));
2076
  });
2077
  m.def(
2078
      "_awaitable_wait", [](const std::shared_ptr<PythonAwaitWrapper>& py_aw) {
2079
        TORCH_CHECK(py_aw, "Await can't be None");
2080
        return py_aw->wait();
2081
      });
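  // The three helpers above cover the Await surface: _awaitable(fn, *args)
  // wraps a delayed call, _awaitable_nowait(value) wraps an already-computed
  // value, and _awaitable_wait(aw) forces the result. Sketch (assuming
  // torch._C exposure):
  //   aw = torch._C._awaitable(lambda: 1 + 1)
  //   torch._C._awaitable_wait(aw)  # -> 2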
2082
  m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
2083
    AT_ASSERT(!args.empty());
2084

2085
    py::function f = py::cast<py::function>(args[0]);
2086
    py::tuple args_tup(args.size() - 1);
2087

2088
    for (const auto i : c10::irange(1, args.size())) {
2089
      args_tup[i - 1] = args[i];
2090
    }
2091

2092
    if (jit::tracer::isTracing()) {
2093
      auto graph = jit::tracer::getTracingState()->graph;
2094
      auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
2095
      auto body_block = fork_node->addBlock();
2096

2097
      Value* node_output = nullptr;
2098
      py::object py_func_output;
2099
      // Insert new trace ops into the fork op's sub-block
2100
      WithInsertPoint guard(body_block);
2101
      IValue output_ivalue;
2102
      {
2103
        tracer::WithNestedTracingFrame env_guard;
2104

2105
        // Run the user-supplied function
2106
        py_func_output = f(*args_tup, **kwargs);
2107

2108
        // Convert the output of the user-supplied function to IValue. The type
2109
        // information of this IValue is used to record the correct type in
2110
        // the trace.
2111
        output_ivalue = toTypeInferredIValue(py_func_output);
2112
        Value* out_val = jit::tracer::getValueTrace(output_ivalue);
2113
        body_block->registerOutput(out_val);
2114
        node_output =
2115
            fork_node->output()->setType(FutureType::create(out_val->type()));
2116
      }
2117

2118
      auto retval =
2119
          c10::make_intrusive<c10::ivalue::Future>(output_ivalue.type());
2120

2121
      // Record the ivalue in the tracer
2122
      jit::tracer::setValueTrace(retval, node_output);
2123

2124
      // stuff the ivalue output in the Future
2125
      retval->markCompleted(output_ivalue);
2126

2127
      return std::make_shared<PythonFutureWrapper>(retval);
2128
    } else {
2129
      auto result = toTypeInferredIValue(f(*args_tup, **kwargs));
2130
      auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
2131
      retval->markCompleted(std::move(result));
2132
      return std::make_shared<PythonFutureWrapper>(retval);
2133
    }
2134
  });
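  // Note on the non-tracing branch above: the callable runs synchronously and
  // the returned Future is already completed, so this eager binding does not
  // itself provide asynchrony. Under tracing, the call is additionally
  // recorded as a prim::TracedFork node with its own sub-block.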
2135

2136
  m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
2137
    TORCH_CHECK(fut, "Future can't be None");
2138
    return fut->wait();
2139
  });
2140

2141
  m.def(
2142
      "_collect_all",
2143
      [](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
2144
          -> std::shared_ptr<jit::PythonFutureWrapper> {
2145
        auto typePtr = futures.empty() || futures[0] == nullptr
2146
            ? AnyType::get()
2147
            : futures[0]->fut->elementType();
2148
        c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
2149
            c10::FutureType::create(typePtr));
2150
        asList.reserve(futures.size());
2151
        for (const auto& f : futures) {
2152
          TORCH_CHECK(f, "Future can't be None");
2153
          asList.push_back(f->fut);
2154
        }
2155
        return std::make_shared<jit::PythonFutureWrapper>(
2156
            c10::collectAll(asList),
2157
            /* unwrap_func */ [futures](const py::object& /*unused*/) {
2158
              // Throw errors when calling wait() on the returned Future if
2159
              // any of the original futures would throw.
2160
              // NB: PythonFutureWrapper takes an unwrap_func which serves as a
2161
              // callback to evaluate the value in the Future. RPC uses this
2162
              // unwrap_func to check whether the returned py::object is a
2163
              // RemoteException object, and re-throw the exception if it is.
2164
              // By extracting the c10::ivalue::Future from PythonFutureWrapper
2165
              // the unwrap_func of each original PythonFutureWrapper object is
2166
              // discarded, and hence it will return the RemoteException as an
2167
              // object instead of re-throwing it.
2168
              for (auto& fut : futures) {
2169
                fut->wait();
2170
              }
2171
            });
2172
      },
2173
      py::call_guard<py::gil_scoped_release>());
2174

2175
  m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) {
2176
    toIValue(std::move(obj), type);
2177
  });
2178

2179
#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2180
  m.def("_set_print_stack_traces_on_fatal_signal", [](bool print) {
2181
    c10::FatalSignalHandler::getInstance().setPrintStackTracesOnFatalSignal(
2182
        print);
2183
  });
2184
#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2185

2186
  initPythonCustomClassBindings(module);
2187
  initPythonIRBindings(module);
2188
  tracer::initPythonTracerBindings(module);
2189
  initTreeViewBindings(module);
2190
  initJitScriptBindings(module);
2191
  initJitBackendBindings(module);
2192
  initStaticModuleBindings(module);
2193
  initTensorExprBindings(module);
2194
  // initNvFuserPythonBindings(module);
2195

2196
  setPrintHandler([](const std::string& str) {
2197
    py::gil_scoped_acquire acquire;
2198
    try {
2199
      auto _stdout = py::module::import("sys").attr("stdout");
2200
      _stdout.attr("write")(str);
2201
    } catch (py::error_already_set& e) {
2202
      throw std::runtime_error(e.what());
2203
    }
2204
  });
2205

2206
  // On exit we need to reset the print handler to the default one,
2207
  // because otherwise prim::Print() instruction won't work for JIT modules.
2208
  auto atexit = py::module_::import("atexit");
2209
  atexit.attr("register")(
2210
      py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); }));
2211
}
2212

2213
} // namespace torch::jit
2214
