// onnxruntime — WinML adapter model implementation
// (pasted file metadata: 1056 lines · 35.5 KB)
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#pragma once
#include "adapter/pch.h"

#include "winml_adapter_model.h"

#include "winml_adapter_c_api.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/ort_apis.h"
#include "winml_adapter_apis.h"
#include "core/framework/error_code_helper.h"
#include "core/common/common.h"

#include <io.h>
#include <fcntl.h>

#include <limits>
#include <sstream>
#include <string_view>
#include <unordered_set>

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "core/framework/onnxruntime_typeinfo.h"

#include "onnx/defs/schema.h"
#include "core/framework/tensor_type_and_shape.h"

#include "onnx/onnx-ml.pb.h"
25
26namespace winmla = Windows::AI::MachineLearning::Adapter;
27
28static std::vector<const char*> GetInitializers(const ONNX_NAMESPACE::ModelProto& model_proto) {
29std::vector<const char*> initializers;
30auto& graph = model_proto.graph();
31auto& graph_initializers = graph.initializer();
32for (auto& initializer : graph_initializers) {
33initializers.push_back(initializer.name().c_str());
34}
35return initializers;
36}
37
38static std::vector<const ONNX_NAMESPACE::ValueInfoProto*> GetInputsWithoutInitializers(
39const ONNX_NAMESPACE::ModelProto& model_proto
40) {
41auto initializers = GetInitializers(model_proto);
42
43std::vector<const ONNX_NAMESPACE::ValueInfoProto*> inputs_without_initializers;
44auto& graph = model_proto.graph();
45auto& inputs = graph.input();
46for (auto& input : inputs) {
47if (input.has_name() && input.has_type()) {
48auto found_it = std::find_if(std::begin(initializers), std::end(initializers), [&](auto& initializer) {
49return std::strcmp(initializer, input.name().c_str()) == 0;
50});
51
52auto is_initializer = found_it != std::end(initializers);
53if (!is_initializer) {
54inputs_without_initializers.push_back(&input);
55}
56}
57}
58return inputs_without_initializers;
59}
60
61static std::vector<const ONNX_NAMESPACE::ValueInfoProto*> GetOutputs(const ONNX_NAMESPACE::ModelProto& model_proto) {
62std::vector<const ONNX_NAMESPACE::ValueInfoProto*> outputs_with_name;
63auto& graph = model_proto.graph();
64auto& outputs = graph.output();
65for (auto& output : outputs) {
66if (output.has_name() && output.has_type()) {
67outputs_with_name.push_back(&output);
68}
69}
70return outputs_with_name;
71}
72
73class ModelInfo {
74public:
75ModelInfo(const ONNX_NAMESPACE::ModelProto* model_proto) { Initialize(model_proto); }
76
77public:
78// model metadata
79std::string author_;
80std::string name_;
81std::string domain_;
82std::string description_;
83int64_t version_;
84std::vector<std::pair<std::string, std::string>> model_metadata_;
85std::vector<const ONNX_NAMESPACE::ValueInfoProto*> input_features_;
86std::vector<const ONNX_NAMESPACE::ValueInfoProto*> output_features_;
87bool requires_float16_support_;
88
89private:
90void Initialize(const ONNX_NAMESPACE::ModelProto* model_proto) {
91for (auto& prop : model_proto->metadata_props()) {
92model_metadata_.push_back(std::make_pair(prop.key(), prop.value()));
93}
94
95input_features_ = GetInputsWithoutInitializers(*model_proto);
96output_features_ = ::GetOutputs(*model_proto);
97
98auto has_producer_name = model_proto->has_producer_name();
99author_ = has_producer_name ? model_proto->producer_name() : "";
100
101auto has_domain = model_proto->has_domain();
102domain_ = has_domain ? model_proto->domain() : "";
103
104auto has_graph = model_proto->has_graph();
105auto graph_has_name = model_proto->graph().has_name();
106auto is_name_available = has_graph && graph_has_name;
107name_ = is_name_available ? model_proto->graph().name() : "";
108
109auto has_description = model_proto->has_doc_string();
110description_ = has_description ? model_proto->doc_string() : "";
111
112auto has_version = model_proto->has_model_version();
113version_ = has_version ? model_proto->model_version() : 0;
114}
115};
116
// Takes ownership of the ModelProto and derives the cached ModelInfo from it.
// NOTE: relies on member declaration order — model_proto_ must be declared
// before model_info_ in the class so it is already constructed when
// ModelInfo reads it here.
OrtModel::OrtModel(std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto)
  : model_proto_(std::move(model_proto)),
    model_info_(std::make_unique<ModelInfo>(model_proto_.get())) {
}
121
122// factory methods for creating an ort model from a path
123static OrtStatus* CreateModelProto(const char* path, std::unique_ptr<ONNX_NAMESPACE::ModelProto>& out) {
124int file_descriptor;
125
126auto path_str = std::string(path);
127auto wide_path = onnxruntime::ToWideString(path_str);
128
129_set_errno(0); // clear errno
130_wsopen_s(
131&file_descriptor, wide_path.c_str(), O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE
132);
133
134errno_t err = 0;
135_get_errno(&err);
136if (err == ENOENT) {
137return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!");
138}
139
140if (0 > file_descriptor) {
141return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!");
142}
143
144google::protobuf::io::FileInputStream stream(file_descriptor);
145stream.SetCloseOnDelete(true);
146
147auto model_proto = std::unique_ptr<ONNX_NAMESPACE::ModelProto>(new ONNX_NAMESPACE::ModelProto());
148
149auto parse_succeeded = model_proto->ParseFromZeroCopyStream(&stream);
150if (!parse_succeeded) {
151return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model file!");
152}
153
154out = std::move(model_proto);
155
156return S_OK;
157}
158
159OrtStatus* OrtModel::CreateEmptyModel(int64_t opset, OrtModel** model) {
160auto model_proto = std::unique_ptr<ONNX_NAMESPACE::ModelProto>(new ONNX_NAMESPACE::ModelProto());
161auto opsetimportproto = model_proto->add_opset_import();
162opsetimportproto->set_version(opset);
163model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
164return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
165}
166
167OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model) {
168ORT_UNUSED_PARAMETER(len);
169
170std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto;
171
172if (auto status = CreateModelProto(path, model_proto)) {
173return status;
174}
175
176return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
177}
178
179OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** model) {
180auto model_proto = std::unique_ptr<ONNX_NAMESPACE::ModelProto>(new ONNX_NAMESPACE::ModelProto());
181
182auto parse_succeeded = model_proto->ParseFromArray(data, static_cast<int>(len));
183if (!parse_succeeded) {
184return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model stream!");
185}
186
187return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
188}
189
190OrtStatus* OrtModel::CreateOrtModelFromProto(
191std::unique_ptr<ONNX_NAMESPACE::ModelProto>&& model_proto, OrtModel** model
192) {
193*model = new (std::nothrow) OrtModel(std::move(model_proto));
194if (*model == nullptr) {
195return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Engine failed to create a model!");
196}
197
198return nullptr;
199}
200
201const ModelInfo* OrtModel::UseModelInfo() const {
202return model_info_.get();
203}
204
205ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const {
206return model_proto_.get();
207}
208
209std::unique_ptr<ONNX_NAMESPACE::ModelProto> OrtModel::DetachModelProto() {
210return std::move(model_proto_);
211}
212
213void OrtModel::RefreshModelInfo() {
214auto new_info = std::make_unique<ModelInfo>(model_proto_.get());
215model_info_->author_ = std::move(new_info->author_);
216model_info_->description_ = std::move(new_info->description_);
217model_info_->domain_ = std::move(new_info->domain_);
218model_info_->input_features_ = std::move(new_info->input_features_);
219model_info_->model_metadata_ = std::move(new_info->model_metadata_);
220model_info_->name_ = std::move(new_info->name_);
221model_info_->output_features_ = std::move(new_info->output_features_);
222model_info_->requires_float16_support_ = std::move(new_info->requires_float16_support_);
223model_info_->version_ = std::move(new_info->version_);
224}
225
226ORT_API_STATUS_IMPL(
227winmla::CreateModelFromPath, _In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out
228) {
229API_IMPL_BEGIN
230if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) {
231return status;
232}
233return nullptr;
234API_IMPL_END
235}
236
237ORT_API_STATUS_IMPL(winmla::CreateModelFromData, _In_opt_ void* data, _In_ size_t size, _Outptr_ OrtModel** out) {
238API_IMPL_BEGIN
239if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) {
240return status;
241}
242return nullptr;
243API_IMPL_END
244}
245
246ORT_API_STATUS_IMPL(winmla::CloneModel, _In_ const OrtModel* in, _Outptr_ OrtModel** out) {
247API_IMPL_BEGIN
248auto model_proto_copy = std::make_unique<ONNX_NAMESPACE::ModelProto>(*in->UseModelProto());
249if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) {
250return status;
251}
252return nullptr;
253API_IMPL_END
254}
255
256ORT_API_STATUS_IMPL(winmla::SaveModel, _In_ const OrtModel* in, _In_ const wchar_t* const file_name, _In_ size_t len) {
257API_IMPL_BEGIN
258int fd;
259std::wstring file_path = file_name;
260Status status = onnxruntime::Env::Default().FileOpenWr(file_path, fd);
261if (fd < 0) {
262return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "File not found!");
263}
264
265auto model_proto = in->UseModelProto();
266google::protobuf::io::FileOutputStream output(fd);
267const bool success = model_proto->SerializeToZeroCopyStream(&output) && output.Flush();
268if (!success) {
269return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Failed to serialize model!");
270}
271output.Close();
272return nullptr;
273API_IMPL_END
274}
275
276ORT_API_STATUS_IMPL(
277winmla::ModelGetAuthor, _In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len
278) {
279API_IMPL_BEGIN
280*author = model->UseModelInfo()->author_.c_str();
281*len = model->UseModelInfo()->author_.size();
282return nullptr;
283API_IMPL_END
284}
285
286ORT_API_STATUS_IMPL(
287winmla::ModelGetName, _In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len
288) {
289API_IMPL_BEGIN
290*name = model->UseModelInfo()->name_.c_str();
291*len = model->UseModelInfo()->name_.size();
292return nullptr;
293API_IMPL_END
294}
295
296ORT_API_STATUS_IMPL(winmla::ModelSetName, _In_ const OrtModel* model, _In_ const char* const name) {
297API_IMPL_BEGIN
298auto model_proto = model->UseModelProto();
299ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
300graph.set_name(name);
301return nullptr;
302API_IMPL_END
303}
304
305ORT_API_STATUS_IMPL(
306winmla::ModelGetDomain, _In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len
307) {
308API_IMPL_BEGIN
309*domain = model->UseModelInfo()->domain_.c_str();
310*len = model->UseModelInfo()->domain_.size();
311return nullptr;
312API_IMPL_END
313}
314
315ORT_API_STATUS_IMPL(
316winmla::ModelGetDescription, _In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len
317) {
318API_IMPL_BEGIN
319*description = model->UseModelInfo()->description_.c_str();
320*len = model->UseModelInfo()->description_.size();
321return nullptr;
322API_IMPL_END
323}
324
325ORT_API_STATUS_IMPL(winmla::ModelGetVersion, _In_ const OrtModel* model, _Out_ int64_t* version) {
326API_IMPL_BEGIN
327*version = model->UseModelInfo()->version_;
328return nullptr;
329API_IMPL_END
330}
331
332ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, _In_ const OrtModel* model, _Out_ size_t* count) {
333API_IMPL_BEGIN
334*count = model->UseModelInfo()->model_metadata_.size();
335return nullptr;
336API_IMPL_END
337}
338
339ORT_API_STATUS_IMPL(
340winmla::ModelGetMetadata,
341_In_ const OrtModel* model,
342_In_ size_t count,
343_Out_ const char** const key,
344_Out_ size_t* key_len,
345_Out_ const char** const value,
346_Out_ size_t* value_len
347) {
348API_IMPL_BEGIN
349*key = model->UseModelInfo()->model_metadata_[count].first.c_str();
350*key_len = model->UseModelInfo()->model_metadata_[count].first.size();
351*value = model->UseModelInfo()->model_metadata_[count].second.c_str();
352*value_len = model->UseModelInfo()->model_metadata_[count].second.size();
353return nullptr;
354API_IMPL_END
355}
356
357ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, _In_ const OrtModel* model, _Out_ size_t* count) {
358API_IMPL_BEGIN
359*count = model->UseModelInfo()->input_features_.size();
360return nullptr;
361API_IMPL_END
362}
363
364ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, _In_ const OrtModel* model, _Out_ size_t* count) {
365API_IMPL_BEGIN
366*count = model->UseModelInfo()->output_features_.size();
367return nullptr;
368API_IMPL_END
369}
370
371ORT_API_STATUS_IMPL(
372winmla::ModelGetInputName,
373_In_ const OrtModel* model,
374_In_ size_t index,
375_Out_ const char** input_name,
376_Out_ size_t* count
377) {
378API_IMPL_BEGIN
379*input_name = model->UseModelInfo()->input_features_[index]->name().c_str();
380*count = model->UseModelInfo()->input_features_[index]->name().size();
381return nullptr;
382API_IMPL_END
383}
384
385ORT_API_STATUS_IMPL(
386winmla::ModelGetOutputName,
387_In_ const OrtModel* model,
388_In_ size_t index,
389_Out_ const char** output_name,
390_Out_ size_t* count
391) {
392API_IMPL_BEGIN
393*output_name = model->UseModelInfo()->output_features_[index]->name().c_str();
394*count = model->UseModelInfo()->output_features_[index]->name().size();
395return nullptr;
396API_IMPL_END
397}
398
399ORT_API_STATUS_IMPL(
400winmla::ModelGetInputDescription,
401_In_ const OrtModel* model,
402_In_ size_t index,
403_Out_ const char** input_description,
404_Out_ size_t* count
405) {
406API_IMPL_BEGIN
407*input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str();
408*count = model->UseModelInfo()->input_features_[index]->doc_string().size();
409return nullptr;
410API_IMPL_END
411}
412
413ORT_API_STATUS_IMPL(
414winmla::ModelGetOutputDescription,
415_In_ const OrtModel* model,
416_In_ size_t index,
417_Out_ const char** output_description,
418_Out_ size_t* count
419) {
420API_IMPL_BEGIN
421*output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str();
422*count = model->UseModelInfo()->output_features_[index]->doc_string().size();
423return nullptr;
424API_IMPL_END
425}
426
427ORT_API_STATUS_IMPL(
428winmla::ModelGetInputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
429) {
430API_IMPL_BEGIN
431auto info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->input_features_[index]->type());
432*type_info = info.release();
433return nullptr;
434API_IMPL_END
435}
436
437ORT_API_STATUS_IMPL(
438winmla::ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
439) {
440API_IMPL_BEGIN
441auto info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->output_features_[index]->type());
442*type_info = info.release();
443return nullptr;
444API_IMPL_END
445}
446
447ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, _In_ const OrtModel* model) {
448API_IMPL_BEGIN
449auto model_info = model->UseModelInfo();
450auto model_proto = model->UseModelProto();
451auto& graph = model_proto->graph();
452
453// The model will not contain fp16 operations if:
454// 1. The model has no fp16 inputs
455// 2. The model has no fp16 initializers
456// 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator
457// 4. The model does not have any fp16 outputs
458
459// 1. Ensure that The model has no fp16 inputs
460for (auto input : model_info->input_features_) {
461auto& type = input->type();
462if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
463auto& tensor_type = type.tensor_type();
464if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
465std::stringstream error_message;
466error_message << "The model contains a 16-bit input (" << input->name()
467<< "), but the current device does not support 16-bit float.";
468return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
469}
470}
471}
472
473// 2. Ensure that the model has no fp16 initializers
474for (int i = 0; i < graph.node_size(); i++) {
475auto node = graph.node(i);
476if (node.op_type() == "Cast" && node.domain().empty()) {
477for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) {
478auto attribute = node.attribute(attribIndex);
479if (attribute.name() == "to") {
480if (attribute.i() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) {
481std::stringstream error_message;
482error_message << "The model contains a 16-bit input (" << node.name().c_str()
483<< "), but the current device does not support 16-bit float.";
484return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
485}
486}
487}
488}
489}
490
491// 3. Ensure that the model does not create any fp16 intermediary
492// tensors via the Cast (to float16) operator
493for (int i = 0; i < graph.initializer_size(); i++) {
494auto initializer = graph.initializer(i);
495if (initializer.data_type() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) {
496std::stringstream error_message;
497error_message << "The model contains a 16-bit input (" << initializer.name().c_str()
498<< "), but the current device does not support 16-bit float.";
499return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
500}
501}
502
503// 4. Ensure that the model does not have any fp16 outputs
504for (auto output : model_info->output_features_) {
505auto& type = output->type();
506if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
507auto& tensor_type = type.tensor_type();
508if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
509std::stringstream error_message;
510error_message << "The model contains a 16-bit input (" << output->name()
511<< "), but the current device does not support 16-bit float.";
512return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
513}
514}
515}
516return nullptr;
517API_IMPL_END
518}
519
520ORT_API_STATUS_IMPL(winmla::CreateModel, _In_ int64_t opset, _Outptr_ OrtModel** out) {
521API_IMPL_BEGIN
522return OrtModel::CreateEmptyModel(opset, out);
523API_IMPL_END
524}
525
526static ONNX_NAMESPACE::TensorProto_DataType ONNXTensorElementDataTypeToTensorProto_DataType(
527ONNXTensorElementDataType type
528) {
529switch (type) {
530case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
531return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
532case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
533return ONNX_NAMESPACE::TensorProto_DataType_UINT8;
534case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
535return ONNX_NAMESPACE::TensorProto_DataType_INT8;
536case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
537return ONNX_NAMESPACE::TensorProto_DataType_UINT16;
538case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
539return ONNX_NAMESPACE::TensorProto_DataType_INT16;
540case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
541return ONNX_NAMESPACE::TensorProto_DataType_INT32;
542case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
543return ONNX_NAMESPACE::TensorProto_DataType_INT64;
544case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
545return ONNX_NAMESPACE::TensorProto_DataType_STRING;
546case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
547return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
548case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
549return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
550case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
551return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
552case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
553return ONNX_NAMESPACE::TensorProto_DataType_UINT32;
554case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
555return ONNX_NAMESPACE::TensorProto_DataType_UINT64;
556default:
557return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
558}
559}
560
561static void CreateTypeProto_Tensor(
562ONNX_NAMESPACE::TypeProto_Tensor* mutable_tensor_type,
563const char* const name,
564const int64_t* shape,
565size_t shape_len,
566ONNX_NAMESPACE::TensorProto_DataType data_type
567) {
568mutable_tensor_type->set_elem_type(data_type);
569
570size_t dim_param = 0;
571for (size_t i = 0; i < shape_len; i++) {
572if (shape[i] == -1) {
573std::ostringstream str;
574str << name << dim_param++;
575mutable_tensor_type->mutable_shape()->add_dim()->set_dim_param(str.str().c_str(), 1);
576} else {
577mutable_tensor_type->mutable_shape()->add_dim()->set_dim_value(shape[i]);
578}
579}
580
581if (shape_len > 0) {
582mutable_tensor_type->mutable_shape()->mutable_dim(0)->set_denotation("DATA_BATCH");
583}
584}
585
586ORT_API_STATUS_IMPL(
587winmla::ModelAddInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info
588) {
589API_IMPL_BEGIN
590auto model_proto = model->UseModelProto();
591ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
592ONNX_NAMESPACE::ValueInfoProto& input = *graph.add_input();
593input.set_name(input_name);
594
595if (info->type == ONNXType::ONNX_TYPE_TENSOR) {
596auto num_dims = info->data->shape.NumDimensions();
597CreateTypeProto_Tensor(
598input.mutable_type()->mutable_tensor_type(),
599input_name,
600(num_dims == 0) ? nullptr : &info->data->shape[0],
601num_dims,
602ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
603);
604}
605return nullptr;
606API_IMPL_END
607}
608
609ORT_API_STATUS_IMPL(
610winmla::ModelAddConstantInput,
611_In_ OrtModel* model,
612_In_ const char* const input_name,
613_In_ OrtTypeInfo* info,
614_In_ OrtValue* value
615) {
616API_IMPL_BEGIN
617auto model_proto = model->UseModelProto();
618ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
619ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer();
620input.set_name(input_name);
621
622auto num_dims = info->data->shape.NumDimensions();
623for (size_t i = 0; i < num_dims; i++) {
624input.add_dims(info->data->shape[i]);
625}
626
627input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type));
628auto tensor = value->GetMutable<onnxruntime::Tensor>();
629input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes());
630
631return nullptr;
632API_IMPL_END
633}
634
635ORT_API_STATUS_IMPL(
636winmla::ModelAddOutput, _In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info
637) {
638API_IMPL_BEGIN
639auto model_proto = model->UseModelProto();
640ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
641ONNX_NAMESPACE::ValueInfoProto& output = *graph.add_output();
642output.set_name(output_name);
643
644if (info->type == ONNXType::ONNX_TYPE_TENSOR) {
645CreateTypeProto_Tensor(
646output.mutable_type()->mutable_tensor_type(),
647output_name,
648&info->data->shape[0],
649info->data->shape.NumDimensions(),
650ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
651);
652}
653return nullptr;
654API_IMPL_END
655}
656
657static const onnx::OpSchema* GetSchema(const char* const op_type, int64_t opset, const char* const op_domain) {
658std::string domain = onnx::ONNX_DOMAIN;
659if (op_domain) {
660domain = op_domain;
661}
662
663auto registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
664return registry->GetSchema(op_type, static_cast<int>(opset), domain);
665}
666
// Appends a node (op_type/op_name/op_domain) to the model's graph, wires up
// the given input/output names, and translates each attribute OrtValue into
// the AttributeProto form declared by the operator's schema.
// NOTE(review): `schema` is dereferenced without a null check — an unknown
// op_type/opset/domain combination would crash here; confirm callers
// pre-validate. `all_attributes.at(...)` throws for attribute names the
// schema does not declare.
ORT_API_STATUS_IMPL(
  winmla::ModelAddOperator,
  _In_ OrtModel* model,
  _In_ const char* const op_type,
  _In_ const char* const op_name,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _In_ const char* const* input_names,
  _In_ size_t num_inputs,
  _In_ const char* const* output_names,
  _In_ size_t num_outputs,
  _In_ const char* const* attribute_names,
  _In_ OrtValue** attribute_values,
  _In_ size_t num_attributes
) {
  API_IMPL_BEGIN
  auto model_proto = model->UseModelProto();
  ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
  onnx::NodeProto& node = *graph.add_node();
  node.set_op_type(op_type);
  node.set_name(op_name);
  node.set_domain(op_domain);

  // The schema drives the attribute typing performed below.
  auto schema = GetSchema(op_type, opset, op_domain);
  auto all_attributes = schema->attributes();

  for (size_t i = 0; i < num_attributes; i++) {
    // Each attribute value arrives as an OrtValue wrapping a Tensor.
    auto tensor = attribute_values[i]->GetMutable<onnxruntime::Tensor>();

    auto attr = node.add_attribute();
    attr->set_name(attribute_names[i]);
    auto& schema_attribute_definition = all_attributes.at(attribute_names[i]);
    attr->set_type(schema_attribute_definition.type);

    // Convert the tensor payload according to the schema-declared type.
    switch (schema_attribute_definition.type) {
      case onnx::AttributeProto_AttributeType_INT: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single int64 value!");
        }
        auto raw_data = tensor->DataRaw();
        attr->set_i(*reinterpret_cast<const int64_t*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_FLOAT: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single float value!");
        }
        auto raw_data = tensor->DataRaw();
        attr->set_f(*reinterpret_cast<const float*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_STRING: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single string value!");
        }
        // NOTE(review): assumes the tensor's raw storage holds a std::string
        // object (ORT string tensors store std::string elements) — confirm.
        auto raw_data = tensor->DataRaw();
        attr->set_s(*reinterpret_cast<const std::string*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_INTS: {
        auto raw_data = tensor->DataRaw();
        for (int j = 0; j < tensor->Shape().Size(); j++) {
          attr->add_ints(*(reinterpret_cast<const int64_t*>(raw_data) + j));
        }
        break;
      }
      case onnx::AttributeProto_AttributeType_FLOATS: {
        auto raw_data = tensor->DataRaw();
        for (int j = 0; j < tensor->Shape().Size(); j++) {
          attr->add_floats(*(reinterpret_cast<const float*>(raw_data) + j));
        }
        break;
      }
      case onnx::AttributeProto_AttributeType_TENSOR: {
        auto tensor_proto = attr->add_tensors();
        auto prim_type = tensor->DataType()->AsPrimitiveDataType();
        if (prim_type == nullptr) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Undefined tensor type!");
        }
        tensor_proto->set_data_type(prim_type->GetDataType());
        tensor_proto->set_raw_data(tensor->DataRaw(), tensor->SizeInBytes());
        break;
      }
      // Other attribute kinds (STRINGS, GRAPH, ...) are silently ignored here.
    }
  }

  // Inputs: a null name adds an empty entry (an omitted optional input).
  for (size_t i = 0; i < num_inputs; i++) {
    auto name = input_names[i];
    if (name != nullptr) {
      node.add_input(name);
    } else {
      node.add_input();
    }
  }

  // Outputs: a null name is mapped to the placeholder name "unused".
  for (size_t i = 0; i < num_outputs; i++) {
    auto name = output_names[i];
    if (name != nullptr) {
      node.add_output(name);
    } else {
      node.add_output("unused");
    }
  }
  return nullptr;
  API_IMPL_END
}
773
774ORT_API_STATUS_IMPL(
775winmla::ModelGetOpsetVersion, _In_ OrtModel* model, _In_ const char* const domain, _Out_ int32_t* version
776) {
777API_IMPL_BEGIN
778auto model_proto = model->UseModelProto();
779
780*version = -1;
781auto size = static_cast<int>(model_proto->opset_import_size());
782for (int i = 0; i < size; i++) {
783auto& current_opset = model_proto->opset_import(i);
784auto& current_domain = current_opset.domain();
785if (_strnicmp(domain, current_domain.c_str(), current_domain.size()) == 0) {
786*version = static_cast<int32_t>(current_opset.version());
787break;
788}
789}
790
791return nullptr;
792API_IMPL_END
793}
794
// C API: destroy an OrtModel created by this adapter (null is a no-op for
// delete).
ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) {
  delete ptr;
}
798
799#include "core/framework/onnxruntime_typeinfo.h"
800#include "core/framework/tensor_type_and_shape.h"
801
802ORT_API_STATUS_IMPL(
803winmla::CreateTensorTypeInfo,
804_In_ const int64_t* dim_values,
805size_t dim_count,
806ONNXTensorElementDataType type,
807_Out_ OrtTypeInfo** ort_type_info
808) {
809API_IMPL_BEGIN
810auto tensor_shape = onnxruntime::TensorShape(dim_values, dim_count);
811auto type_and_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(type, std::move(tensor_shape), nullptr);
812*ort_type_info = OrtTypeInfo::MakePtr(ONNX_TYPE_TENSOR, std::move(type_and_shape)).release();
813return nullptr;
814API_IMPL_END
815}
816
// C API: not implemented — returns success without producing a type info.
// NOTE(review): *type_info is left untouched, so callers must not read it;
// confirm whether callers pre-initialize the pointer.
ORT_API_STATUS_IMPL(winmla::CreateSequenceTypeInfo, _Out_ OrtTypeInfo** type_info) {
  API_IMPL_BEGIN
  return nullptr;
  API_IMPL_END
}

// C API: not implemented — returns success without producing a type info.
// NOTE(review): *type_info is left untouched; same caveat as
// CreateSequenceTypeInfo above.
ORT_API_STATUS_IMPL(winmla::CreateMapTypeInfo, _Out_ OrtTypeInfo** type_info) {
  API_IMPL_BEGIN
  return nullptr;
  API_IMPL_END
}
828
829ORT_API_STATUS_IMPL(
830winmla::OperatorGetNumInputs,
831_In_ const char* const op_type,
832_In_ int64_t opset,
833_In_ const char* const op_domain,
834_Out_ size_t* num_inputs
835) {
836API_IMPL_BEGIN
837auto schema = GetSchema(op_type, opset, op_domain);
838*num_inputs = schema->inputs().size();
839return nullptr;
840API_IMPL_END
841}
842
843ORT_API_STATUS_IMPL(
844winmla::OperatorGetInputName,
845_In_ const char* const op_type,
846_In_ int64_t opset,
847_In_ const char* const op_domain,
848_In_ size_t index,
849_Out_ const char** const name
850) {
851API_IMPL_BEGIN
852auto schema = GetSchema(op_type, opset, op_domain);
853*name = schema->inputs().at(index).GetName().c_str();
854return nullptr;
855API_IMPL_END
856}
857
858ORT_API_STATUS_IMPL(
859winmla::OperatorGetNumOutputs,
860_In_ const char* const op_type,
861_In_ int64_t opset,
862_In_ const char* const op_domain,
863_Out_ size_t* num_outputs
864) {
865API_IMPL_BEGIN
866auto schema = GetSchema(op_type, opset, op_domain);
867*num_outputs = schema->outputs().size();
868return nullptr;
869API_IMPL_END
870}
871
872ORT_API_STATUS_IMPL(
873winmla::OperatorGetOutputName,
874_In_ const char* const op_type,
875_In_ int64_t opset,
876_In_ const char* const op_domain,
877_In_ size_t index,
878_Out_ const char** const name
879) {
880API_IMPL_BEGIN
881auto schema = GetSchema(op_type, opset, op_domain);
882*name = schema->outputs().at(index).GetName().c_str();
883return nullptr;
884API_IMPL_END
885}
886#include "core/platform/threadpool.h"
887#include "core/platform/env.h"
888
889ORT_API_STATUS_IMPL(
890winmla::CreateThreadPool, _In_ ThreadPoolType type, _In_ OrtThreadPoolOptions* options, _Outptr_ OrtThreadPool** out
891) {
892API_IMPL_BEGIN
893OrtThreadPoolParams params = {};
894params.thread_pool_size = options->thread_pool_size;
895params.auto_set_affinity = options->auto_set_affinity;
896params.allow_spinning = options->allow_spinning;
897params.dynamic_block_base_ = options->dynamic_block_base_;
898params.stack_size = options->stack_size;
899params.name = options->name;
900params.set_denormal_as_zero = options->set_denormal_as_zero;
901
902auto unique_tp = onnxruntime::concurrency::CreateThreadPool(
903&onnxruntime::Env::Default(), params, (onnxruntime::concurrency::ThreadPoolType)type
904);
905*out = reinterpret_cast<OrtThreadPool*>(unique_tp.release());
906return nullptr;
907API_IMPL_END
908}
909
// C API: destroy a thread pool created by winmla::CreateThreadPool.
// The opaque OrtThreadPool* is cast back to the onnxruntime ThreadPool it
// was created as before deletion.
ORT_API(void, winmla::ReleaseThreadPool, OrtThreadPool* ptr) {
  delete reinterpret_cast<onnxruntime::concurrency::ThreadPool*>(ptr);
}
913
// Splices second_model onto first_model in place.
//
// Each of the num_linkages pairs (output_names[i], input_names[i]) links an output of
// the first model to an input of the second model. Every name originating from the
// second model (inputs, outputs, initializers, nodes, node edges) is renamed with
// join_node_prefix to avoid collisions, and an Identity node is inserted per linkage
// to route the first model's output into the prefixed second-model input.
//
// first_model:  modified in place; becomes the joined model.
// second_model: its ModelProto is detached (consumed); the OrtModel should not be
//               used afterwards.
// promote_unlinked_outputs: when true, first-model outputs that are not named in
//               output_names remain outputs of the joined model; when false, all
//               first-model outputs are dropped.
ORT_API_STATUS_IMPL(
  winmla::JoinModels,
  _In_ OrtModel* first_model,
  _In_ OrtModel* second_model,
  _In_ const char* const* output_names,
  _In_ const char* const* input_names,
  size_t num_linkages,
  bool promote_unlinked_outputs,
  _In_ const char* const join_node_prefix
) {
  API_IMPL_BEGIN

  std::string second_model_prefix = join_node_prefix;
  auto first_model_proto = first_model->UseModelProto();
  // Detach (take ownership of) the second model's proto so its repeated fields can be
  // moved element-by-element into the first model without copying.
  auto second_model_proto = second_model->DetachModelProto();

  // Remove old outputs
  if (promote_unlinked_outputs) {
    // Copy the output of the first model
    auto first_outputs = first_model_proto->graph().output();

    // Clear all outputs
    first_model_proto->mutable_graph()->mutable_output()->Clear();

    // Add back output
    // NOTE(review): iterating the saved outputs in reverse while appending means the
    // promoted outputs are re-added in reversed order relative to the original model —
    // confirm this ordering is intentional.
    for (int i = first_outputs.size() - 1; i >= 0; i--) {
      auto& output = first_outputs.at(i);
      auto output_name = output.name();

      auto found_it = std::find_if(output_names, output_names + num_linkages, [output_name](auto& name) {
        return std::strcmp(name, output_name.c_str()) == 0;
      });
      if (found_it == (output_names + num_linkages)) {
        // if output.name() is not found in the linkages, it is unlinked, and it should be promoted
        auto& promoted_output = *first_model_proto->mutable_graph()->add_output();
        promoted_output = std::move(output);
      }
    }
  } else {
    // remove all first model outputs
    first_model_proto->mutable_graph()->mutable_output()->Clear();
  }

  // add all model outputs from the second model
  // (renamed with the prefix, then moved into the first model's graph)
  for (int i = 0; i < second_model_proto->graph().output_size(); i++) {
    auto& other_output = *second_model_proto->mutable_graph()->mutable_output(i);
    *other_output.mutable_name() = second_model_prefix + other_output.name();
    auto& output = *first_model_proto->mutable_graph()->add_output();
    output = std::move(other_output);
  }

  // loop through second model inputs and promote the unlinked ones to the main model inputs
  // Linked inputs are fed by the Identity bridge nodes added below, so only unlinked
  // inputs become inputs of the joined model.
  for (int i = 0; i < second_model_proto->graph().input_size(); i++) {
    auto& other_input = *second_model_proto->mutable_graph()->mutable_input(i);
    auto old_name = other_input.name();
    *other_input.mutable_name() = second_model_prefix + old_name;

    // Linkage matching is done against the UNprefixed name, since input_names refers
    // to the second model's original input names.
    auto found_it = std::find_if(input_names, input_names + num_linkages, [old_name](auto& name) {
      return std::strcmp(name, old_name.c_str()) == 0;
    });
    bool is_linked =
      found_it != (input_names + num_linkages); // figure out if other_input.name() exists in the output_names mapped
    if (!is_linked) {
      auto& input = *first_model_proto->mutable_graph()->add_input();
      input = std::move(other_input);
    }
  }

  // add all initializers
  // (renamed with the prefix so they match the renamed node edges)
  for (int i = 0; i < second_model_proto->graph().initializer_size(); i++) {
    auto& other_initializer = *second_model_proto->mutable_graph()->mutable_initializer(i);
    *other_initializer.mutable_name() = second_model_prefix + other_initializer.name();
    auto& initializer = *first_model_proto->mutable_graph()->add_initializer();
    initializer = std::move(other_initializer);
  }

  // add all nodes
  for (int i = 0; i < second_model_proto->graph().node_size(); i++) {
    auto& other_node = *second_model_proto->mutable_graph()->mutable_node(i);
    // Only rename nodes that actually have a name; node names are optional in ONNX.
    if (0 != strcmp(other_node.name().c_str(), "")) {
      *other_node.mutable_name() = second_model_prefix + other_node.name();
    }
    // Rename every edge so the node keeps referring to the prefixed second-model tensors.
    for (int j = 0; j < other_node.input_size(); j++) {
      *other_node.mutable_input(j) = second_model_prefix + other_node.input(j);
    }
    for (int j = 0; j < other_node.output_size(); j++) {
      *other_node.mutable_output(j) = second_model_prefix + other_node.output(j);
    }
    auto& node = *first_model_proto->mutable_graph()->add_node();
    node = std::move(other_node);
  }

  // WinML+RT API only supports opset 7 and above models.
  // In practice this number is always overwritten by the for loop below which will find the actual opset version.
  int64_t opset = 7;
  // Merge the second model's opset imports into the first model: for each domain both
  // models import, keep the higher version. The default-domain ("") version is captured
  // in `opset` for the Identity nodes added below.
  for (int i = 0; i < second_model_proto->opset_import_size(); i++) {
    auto mutable_opset_import = second_model_proto->mutable_opset_import(i);
    auto domain = mutable_opset_import->has_domain() ? mutable_opset_import->domain() : std::string("");
    auto version = mutable_opset_import->version();

    // does the domain exist in the first model?
    auto found_it = std::find_if(
      first_model_proto->mutable_opset_import()->begin(),
      first_model_proto->mutable_opset_import()->end(),
      [&domain](auto& mutable_opset_import) {
        auto first_model_domain = mutable_opset_import.has_domain() ? mutable_opset_import.domain() : std::string("");
        return 0 == strcmp(first_model_domain.c_str(), domain.c_str());
      }
    );
    if (found_it != first_model_proto->mutable_opset_import()->end()) {
      found_it->set_version(std::max(found_it->version(), version));
      if (0 == strcmp(domain.c_str(), "")) {
        opset = found_it->version();
      }
    }
    // NOTE(review): domains imported only by the second model appear to be dropped
    // (no opset_import entry is added to the first model) — confirm this is intended.
  }

  // add identity ops to rename all of the first model outputs to secondmodel inputs with prefix for each linkage
  for (size_t i = 0; i < num_linkages; i++) {
    auto op_output_name = second_model_prefix + *(input_names + i);
    const char* const op_output_name_const_str = op_output_name.c_str();
    // Node name is derived from the (prefixed) destination input to keep it unique.
    std::string name = "IdentityTo";
    name += second_model_prefix + *(input_names + i);
    ModelAddOperator(
      first_model,
      "Identity",
      name.c_str(),
      opset,
      "",                         // default ONNX domain
      (output_names + i),         // single input: the first model's linked output
      1,
      &op_output_name_const_str,  // single output: the prefixed second-model input
      1,
      nullptr,                    // no attribute names
      nullptr,                    // no attribute values
      0
    );
  }
  // Re-derive cached model metadata (inputs/outputs/etc.) from the mutated proto.
  first_model->RefreshModelInfo();

  return nullptr;
  API_IMPL_END
}
1057