1
#include <onnx/onnx_pb.h>
2
#include <torch/csrc/onnx/back_compat.h>
3
#include <torch/csrc/onnx/init.h>
4
#include <torch/csrc/onnx/onnx.h>
5
#include <torch/version.h>
7
#include <torch/csrc/Exceptions.h>
8
#include <torch/csrc/jit/passes/onnx.h>
9
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
10
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
11
#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
12
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
13
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
14
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
15
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
16
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
17
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
18
#include <torch/csrc/jit/passes/onnx/naming.h>
19
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
20
#include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
21
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
22
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
23
#include <torch/csrc/jit/passes/onnx/peephole.h>
24
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
25
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
26
#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
27
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
28
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
29
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
30
#include <torch/csrc/jit/serialization/export.h>
32
namespace torch::onnx {
34
using namespace torch::jit;
36
// Registers the torch.onnx C++ bindings on the given Python module:
// ONNX-specific JIT graph passes (`_jit_pass_onnx_*`), ONNX logging
// controls, and the `_onnx` submodule enums (TensorProtoDataType,
// OperatorExportTypes, TrainingMode) plus producer/fallback attributes.
//
// NOTE(review): this region is a garbled extraction of the original
// source. The bare integer lines interleaved below appear to be the
// original file's line numbers fused in as separate lines, and gaps in
// that numbering (e.g. 49 -> 52, 52 -> 61) indicate statements that were
// dropped — several `.def(` openers, lambda bodies, closing braces, and
// preprocessor guards are missing. Recover the canonical upstream file
// before editing the code itself; comments only are added here.
void initONNXBindings(PyObject* module) {
37
// Wrap the raw CPython module pointer in a pybind11 module handle so
// bindings can be attached with m.def(...).
auto m = py::handle(module).cast<py::module>();
39
// ONNX specific passes
40
m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
41
.def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
42
.def("_jit_pass_onnx", ToONNX)
44
"_jit_pass_onnx_assign_output_shape",
45
::torch::wrap_pybind_function(
46
// Assigns concrete output shapes/types to the exported graph from
// example output tensors. NOTE(review): trailing lambda parameters
// and the call's closing tokens were lost in extraction.
[](std::shared_ptr<Graph>& graph,
47
const std::vector<at::Tensor>& tensors,
48
const python::IODescriptor& desc,
49
bool onnx_shape_inference,
52
ONNXAssignOutputShape(
61
"_jit_pass_onnx_function_substitution",
62
wrap_pybind_function(ONNXFunctionCallSubstitution))
64
"_jit_pass_onnx_autograd_function_process",
65
wrap_pybind_function(ONNXAutogradFunctionProcess))
67
"_jit_pass_onnx_peephole",
68
// Peephole optimizations over the ONNX graph; opset_version is
// presumably a lambda parameter on a line lost in extraction.
::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
70
bool fixed_batch_size) {
71
return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
74
"_jit_pass_onnx_preprocess",
75
::torch::wrap_pybind_function(PreprocessForONNX))
77
"_jit_pass_onnx_eval_peephole",
78
::torch::wrap_pybind_function(
79
[](std::shared_ptr<Graph>& graph,
80
std::map<std::string, IValue>& paramsDict) {
81
EvalPeepholeONNX(graph, paramsDict);
84
pybind11::return_value_policy::move)
86
"_jit_pass_onnx_cast_all_constant_to_floating",
87
::torch::wrap_pybind_function(CastAllConstantToFloating))
89
"_jit_pass_onnx_constant_fold",
90
::torch::wrap_pybind_function(
91
[](std::shared_ptr<Graph>& graph,
92
std::map<std::string, IValue>& paramsDict,
97
opset_version); // overload resolution
100
pybind11::return_value_policy::move)
102
"_jit_pass_onnx_eliminate_unused_items",
103
::torch::wrap_pybind_function(
104
[](std::shared_ptr<Graph>& graph,
105
std::map<std::string, IValue>& paramsDict) {
106
EliminateUnusedItemsONNX(
108
paramsDict); // overload resolution
111
pybind11::return_value_policy::move)
113
"_jit_pass_onnx_scalar_type_analysis",
114
::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
115
bool lowprecision_cast,
117
return ScalarTypeAnalysisForONNX(
118
graph, lowprecision_cast, opset_version);
121
// Default: low-precision casts are enabled unless the caller opts out.
py::arg("lowprecision_cast") = true,
122
py::arg("opset_version"))
124
"_jit_pass_onnx_remove_inplace_ops_for_onnx",
125
::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
127
// Shape/type inference: per-node and whole-graph variants.
"_jit_pass_onnx_node_shape_type_inference",
128
::torch::wrap_pybind_function(
130
std::map<std::string, IValue>& params_dict,
132
ONNXShapeTypeInference(n, params_dict, opset_version);
135
"_jit_pass_onnx_graph_shape_type_inference",
136
::torch::wrap_pybind_function(
137
[](std::shared_ptr<Graph>& graph,
138
std::map<std::string, IValue>& params_dict,
140
ONNXShapeTypeInference(graph, params_dict, opset_version);
143
py::arg("params_dict"),
144
py::arg("opset_version"))
146
"_jit_pass_onnx_set_dynamic_input_shape",
147
::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
148
.def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
150
"_jit_pass_onnx_function_extraction",
151
::torch::wrap_pybind_function(
152
torch::jit::onnx::ONNXFunctionExtraction))
153
.def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
155
// Quantization-related passes.
"_jit_pass_onnx_unpack_quantized_weights",
156
::torch::wrap_pybind_function(
157
[](std::shared_ptr<Graph>& graph,
158
std::map<std::string, IValue>& paramsDict,
160
UnpackQuantizedWeights(graph, paramsDict, caffe2);
163
pybind11::return_value_policy::move)
165
"_jit_pass_onnx_quantization_insert_permutes",
166
::torch::wrap_pybind_function(
167
[](std::shared_ptr<Graph>& graph,
168
std::map<std::string, IValue>& paramsDict) {
169
insertPermutes(graph, paramsDict);
172
pybind11::return_value_policy::move)
174
"_jit_onnx_list_model_parameters",
175
::torch::wrap_pybind_function(
176
[](Module& module) { return list_module_parameters(module); }))
178
"_jit_pass_prepare_division_for_onnx",
179
::torch::wrap_pybind_function(PrepareDivisionForONNX))
181
"_jit_onnx_convert_pattern_from_subblock",
182
::torch::wrap_pybind_function(ConvertPatternFromSubblock))
184
"_jit_pass_fixup_onnx_controlflow_node",
185
::torch::wrap_pybind_function(FixupONNXControlflowNode))
187
"_jit_pass_onnx_deduplicate_initializers",
188
::torch::wrap_pybind_function(
189
[](std::shared_ptr<Graph>& graph,
190
// NOTE(review): params_dict is taken by value here (unlike the
// by-reference maps above) — confirm against upstream whether
// that copy is intentional.
std::map<std::string, IValue> params_dict,
192
DeduplicateInitializers(graph, params_dict, is_train);
195
pybind11::return_value_policy::move)
197
"_jit_pass_onnx_clear_scope_records",
198
&torch::jit::onnx::ONNXClearScopeRecords)
200
"_jit_pass_onnx_track_scope_attributes",
201
&torch::jit::onnx::ONNXTrackScopeAttributes)
203
// ONNX logging controls: query/toggle logging and select the stream.
"_jit_is_onnx_log_enabled",
204
::torch::jit::onnx::is_log_enabled,
205
"Returns whether ONNX logging is enabled or disabled.")
207
"_jit_set_onnx_log_enabled",
208
::torch::jit::onnx::set_log_enabled,
209
"Enables or disables ONNX logging.")
211
"_jit_set_onnx_log_output_stream",
212
[](const std::string& stream_name = "stdout") -> void {
213
std::shared_ptr<std::ostream> out;
214
if (stream_name == "stdout") {
215
// No-op deleter: std::cout/std::cerr are globals and must not be
// destroyed when the shared_ptr is released.
out = std::shared_ptr<std::ostream>(
216
&std::cout, [](std::ostream*) {});
217
} else if (stream_name == "stderr") {
218
out = std::shared_ptr<std::ostream>(
219
&std::cerr, [](std::ostream*) {});
221
std::cerr << "ERROR: only `stdout` and `stderr`"
222
<< "are supported as `stream_name`" << std::endl;
224
::torch::jit::onnx::set_log_output_stream(out);
226
"Set specific file stream for ONNX logging.")
229
// Log writer: stringifies each Python arg into the active log stream,
// but only when logging is enabled.
[](const py::args& args) -> void {
230
if (::torch::jit::onnx::is_log_enabled()) {
231
auto& out = ::torch::jit::onnx::_get_log_output_stream();
232
for (auto arg : args) {
233
out << ::c10::str(arg);
238
"Write `args` to the previously specified ONNX log stream.")
240
"_jit_pass_onnx_assign_scoped_names_for_node_and_value",
241
::torch::wrap_pybind_function(
242
::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
243
"Assign informative scoped names for nodes and values.")
245
"_jit_onnx_create_full_scope_name",
246
::torch::wrap_pybind_function(
247
::torch::jit::onnx::ONNXScopeName::createFullScopeName),
248
"Create a full scope name from class name and variable name.");
252
// Validates a serialized ONNX model proto string; the surrounding
// m.def(...) opener was lost in extraction.
::torch::wrap_pybind_function([](const std::string& proto_string) {
253
check_onnx_proto(proto_string);
255
py::arg("proto_string"));
257
// The `_onnx` submodule mirrors enums from the ONNX protobuf schema.
auto onnx = m.def_submodule("_onnx");
258
py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
259
.value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
260
.value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
261
.value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
262
.value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
263
.value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
264
.value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
265
.value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
266
.value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
267
.value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
268
.value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
269
.value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
270
.value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
271
.value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
272
.value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
273
.value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
274
.value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
275
.value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
276
// float8 values come from torch::onnx (back_compat.h is included
// above) rather than ONNX_NAMESPACE — presumably so builds against
// older onnx protobufs that lack float8 still compile; confirm.
.value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
278
"FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
279
.value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
281
"FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);
283
// How ATen ops that have no ONNX equivalent are handled at export time.
py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
284
.value("ONNX", OperatorExportTypes::ONNX)
285
.value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
286
.value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
287
.value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);
289
py::enum_<TrainingMode>(onnx, "TrainingMode")
290
.value("EVAL", TrainingMode::EVAL)
291
.value("PRESERVE", TrainingMode::PRESERVE)
292
.value("TRAINING", TrainingMode::TRAINING);
294
// Producer version exported into ONNX model metadata.
onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
297
// NOTE(review): `_CAFFE2_ATEN_FALLBACK` is assigned twice (true, then
// false). The original almost certainly guarded these with a
// preprocessor conditional (e.g. a BUILD_CAFFE2 check) that was lost in
// extraction — confirm against upstream before relying on this value.
onnx.attr("_CAFFE2_ATEN_FALLBACK") = true;
299
onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
302
302
} // namespace torch::onnx