pytorch

Форк
0
/
init.cpp 
302 строки · 12.8 Кб
1
#include <onnx/onnx_pb.h>
2
#include <torch/csrc/onnx/back_compat.h>
3
#include <torch/csrc/onnx/init.h>
4
#include <torch/csrc/onnx/onnx.h>
5
#include <torch/version.h>
6

7
#include <torch/csrc/Exceptions.h>
8
#include <torch/csrc/jit/passes/onnx.h>
9
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
10
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
11
#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
12
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
13
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
14
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
15
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
16
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
17
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
18
#include <torch/csrc/jit/passes/onnx/naming.h>
19
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
20
#include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
21
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
22
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
23
#include <torch/csrc/jit/passes/onnx/peephole.h>
24
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
25
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
26
#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
27
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
28
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
29
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
30
#include <torch/csrc/jit/serialization/export.h>
31

32
namespace torch::onnx {
33

34
using namespace torch::jit;
35

36
/// Registers the ONNX-related Python bindings on `module`.
///
/// The function does three things, in order:
///   1. defines the `_jit_pass_onnx_*` / `_jit_onnx_*` functions and the
///      `_check_onnx_proto` helper directly on `module`;
///   2. creates the `_onnx` submodule and exposes the `TensorProtoDataType`,
///      `OperatorExportTypes` and `TrainingMode` enums on it;
///   3. sets the `PRODUCER_VERSION` and `_CAFFE2_ATEN_FALLBACK` attributes on
///      that submodule.
///
/// @param module Python module object that receives the bindings; it is
///               cast to `py::module` before use.
void initONNXBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // ONNX specific passes
  //
  // Several lambdas below take the params dict by non-const reference, hand
  // it to the pass, and then return it, so the Python caller receives the
  // dict as it is after the pass has run.
  m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
      .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
      .def("_jit_pass_onnx", ToONNX)
      .def(
          "_jit_pass_onnx_assign_output_shape",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 const std::vector<at::Tensor>& tensors,
                 const python::IODescriptor& desc,
                 bool onnx_shape_inference,
                 bool is_script,
                 int opset_version) {
                ONNXAssignOutputShape(
                    graph,
                    tensors,
                    desc,
                    onnx_shape_inference,
                    is_script,
                    opset_version);
              }))
      .def(
          "_jit_pass_onnx_function_substitution",
          wrap_pybind_function(ONNXFunctionCallSubstitution))
      .def(
          "_jit_pass_onnx_autograd_function_process",
          wrap_pybind_function(ONNXAutogradFunctionProcess))
      .def(
          "_jit_pass_onnx_peephole",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           int opset_version,
                                           bool fixed_batch_size) {
            return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
          }))
      .def(
          "_jit_pass_onnx_preprocess",
          ::torch::wrap_pybind_function(PreprocessForONNX))
      .def(
          "_jit_pass_onnx_eval_peephole",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                EvalPeepholeONNX(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_cast_all_constant_to_floating",
          ::torch::wrap_pybind_function(CastAllConstantToFloating))
      .def(
          "_jit_pass_onnx_constant_fold",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict,
                 int opset_version) {
                ConstantFoldONNX(
                    graph,
                    paramsDict,
                    opset_version); // overload resolution
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_eliminate_unused_items",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                EliminateUnusedItemsONNX(
                    graph->block(),
                    paramsDict); // overload resolution
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_scalar_type_analysis",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           bool lowprecision_cast,
                                           int opset_version) {
            return ScalarTypeAnalysisForONNX(
                graph, lowprecision_cast, opset_version);
          }),
          py::arg("graph"),
          py::arg("lowprecision_cast") = true,
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_remove_inplace_ops_for_onnx",
          ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
      // Shape/type inference is bound twice: once for a single node and once
      // for a whole graph. The lambdas select the matching
      // ONNXShapeTypeInference overload.
      .def(
          "_jit_pass_onnx_node_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](Node* n,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(n, params_dict, opset_version);
              }))
      .def(
          "_jit_pass_onnx_graph_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(graph, params_dict, opset_version);
              }),
          py::arg("graph"),
          py::arg("params_dict"),
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_set_dynamic_input_shape",
          ::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
      .def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
      .def(
          "_jit_pass_onnx_function_extraction",
          ::torch::wrap_pybind_function(
              torch::jit::onnx::ONNXFunctionExtraction))
      .def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
      .def(
          "_jit_pass_onnx_unpack_quantized_weights",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict,
                 bool caffe2) {
                UnpackQuantizedWeights(graph, paramsDict, caffe2);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_quantization_insert_permutes",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                insertPermutes(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_onnx_list_model_parameters",
          ::torch::wrap_pybind_function(
              [](Module& module) { return list_module_parameters(module); }))
      .def(
          "_jit_pass_prepare_division_for_onnx",
          ::torch::wrap_pybind_function(PrepareDivisionForONNX))
      .def(
          "_jit_onnx_convert_pattern_from_subblock",
          ::torch::wrap_pybind_function(ConvertPatternFromSubblock))
      .def(
          "_jit_pass_fixup_onnx_controlflow_node",
          ::torch::wrap_pybind_function(FixupONNXControlflowNode))
      .def(
          "_jit_pass_onnx_deduplicate_initializers",
          ::torch::wrap_pybind_function(
              // Unlike the sibling bindings, params_dict is taken by value
              // here; the local copy is passed to the pass and returned.
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue> params_dict,
                 bool is_train) {
                DeduplicateInitializers(graph, params_dict, is_train);
                return params_dict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_clear_scope_records",
          &torch::jit::onnx::ONNXClearScopeRecords)
      .def(
          "_jit_pass_onnx_track_scope_attributes",
          &torch::jit::onnx::ONNXTrackScopeAttributes)
      .def(
          "_jit_is_onnx_log_enabled",
          ::torch::jit::onnx::is_log_enabled,
          "Returns whether ONNX logging is enabled or disabled.")
      .def(
          "_jit_set_onnx_log_enabled",
          ::torch::jit::onnx::set_log_enabled,
          "Enables or disables ONNX logging.")
      .def(
          "_jit_set_onnx_log_output_stream",
          [](const std::string& stream_name = "stdout") -> void {
            std::shared_ptr<std::ostream> out;
            // std::cout / std::cerr are not owned by us, so the shared_ptr
            // gets a no-op deleter.
            if (stream_name == "stdout") {
              out = std::shared_ptr<std::ostream>(
                  &std::cout, [](std::ostream*) {});
            } else if (stream_name == "stderr") {
              out = std::shared_ptr<std::ostream>(
                  &std::cerr, [](std::ostream*) {});
            } else {
              std::cerr
                  << "ERROR: only `stdout` and `stderr` are supported as "
                  << "`stream_name`" << std::endl;
            }
            // NOTE(review): for an unsupported name `out` is still null at
            // this point; how set_log_output_stream treats a null stream is
            // not visible in this file -- confirm before relying on it.
            ::torch::jit::onnx::set_log_output_stream(out);
          },
          // A C++ default argument on the lambda is invisible to pybind11;
          // declare it here so Python callers can actually omit the argument.
          py::arg("stream_name") = "stdout",
          "Set specific file stream for ONNX logging.")
      .def(
          "_jit_onnx_log",
          [](const py::args& args) -> void {
            // Nothing is formatted or written unless logging is enabled.
            if (::torch::jit::onnx::is_log_enabled()) {
              auto& out = ::torch::jit::onnx::_get_log_output_stream();
              for (auto arg : args) {
                out << ::c10::str(arg);
              }
              out << std::endl;
            }
          },
          "Write `args` to the previously specified ONNX log stream.")
      .def(
          "_jit_pass_onnx_assign_scoped_names_for_node_and_value",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
          "Assign informative scoped names for nodes and values.")
      .def(
          "_jit_onnx_create_full_scope_name",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::ONNXScopeName::createFullScopeName),
          "Create a full scope name from class name and variable name.");

  m.def(
      "_check_onnx_proto",
      ::torch::wrap_pybind_function([](const std::string& proto_string) {
        check_onnx_proto(proto_string);
      }),
      py::arg("proto_string"));

  // Everything below lives on the `_onnx` submodule.
  auto onnx = m.def_submodule("_onnx");
  // Most values come from the ONNX protobuf enum; the FLOAT8* values are
  // taken from ::torch::onnx instead.
  py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
      .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
      .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
      .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
      .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
      .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
      .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
      .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
      .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
      .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
      .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
      .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
      .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
      .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
      .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
      .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
      .value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
      .value(
          "FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
      .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
      .value(
          "FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);

  py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
      .value("ONNX", OperatorExportTypes::ONNX)
      .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
      .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
      .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);

  py::enum_<TrainingMode>(onnx, "TrainingMode")
      .value("EVAL", TrainingMode::EVAL)
      .value("PRESERVE", TrainingMode::PRESERVE)
      .value("TRAINING", TrainingMode::TRAINING);

  onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);

  // Expose at runtime whether this build was compiled with BUILD_CAFFE2.
#ifdef BUILD_CAFFE2
  onnx.attr("_CAFFE2_ATEN_FALLBACK") = true;
#else
  onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
#endif
}
302
} // namespace torch::onnx
303

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.