// onnxruntime — winml_adapter_model.cpp
// (web-page scrape header removed; interleaved line numbers are paste artifacts)
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
#pragma once
5
#include "adapter/pch.h"
6

7
#include "winml_adapter_model.h"
8

9
#include "winml_adapter_c_api.h"
10
#include "core/graph/onnx_protobuf.h"
11
#include "core/session/ort_apis.h"
12
#include "winml_adapter_apis.h"
13
#include "core/framework/error_code_helper.h"
14
#include "core/common/common.h"
15

16
#include <io.h>
17
#include <fcntl.h>
18
#include "google/protobuf/io/zero_copy_stream_impl.h"
19
#include "core/framework/onnxruntime_typeinfo.h"
20

21
#include "onnx/defs/schema.h"
22
#include "core/framework/tensor_type_and_shape.h"
23

24
#include "onnx/onnx-ml.pb.h"
25

26
namespace winmla = Windows::AI::MachineLearning::Adapter;
27

28
// Collects the name of every initializer in the model's graph.
// The returned pointers alias strings owned by `model_proto` and are only
// valid while the proto is alive and unmodified.
static std::vector<const char*> GetInitializers(const ONNX_NAMESPACE::ModelProto& model_proto) {
  std::vector<const char*> names;
  const auto& graph = model_proto.graph();
  names.reserve(static_cast<size_t>(graph.initializer_size()));
  for (const auto& initializer : graph.initializer()) {
    names.push_back(initializer.name().c_str());
  }
  return names;
}
37

38
// Returns the graph inputs that must be fed at runtime: inputs that carry
// both a name and a type and are not shadowed by an initializer of the
// same name. Pointers alias `model_proto` and must not outlive it.
static std::vector<const ONNX_NAMESPACE::ValueInfoProto*> GetInputsWithoutInitializers(
  const ONNX_NAMESPACE::ModelProto& model_proto
) {
  const auto initializer_names = GetInitializers(model_proto);

  // True when `input` shares its name with some initializer.
  auto is_initializer = [&](const ONNX_NAMESPACE::ValueInfoProto& input) {
    return std::any_of(std::begin(initializer_names), std::end(initializer_names), [&](const char* candidate) {
      return std::strcmp(candidate, input.name().c_str()) == 0;
    });
  };

  std::vector<const ONNX_NAMESPACE::ValueInfoProto*> runtime_inputs;
  for (const auto& input : model_proto.graph().input()) {
    if (input.has_name() && input.has_type() && !is_initializer(input)) {
      runtime_inputs.push_back(&input);
    }
  }
  return runtime_inputs;
}
60

61
// Returns every graph output that has both a name and a type.
// Pointers alias `model_proto` and must not outlive it.
static std::vector<const ONNX_NAMESPACE::ValueInfoProto*> GetOutputs(const ONNX_NAMESPACE::ModelProto& model_proto) {
  std::vector<const ONNX_NAMESPACE::ValueInfoProto*> named_outputs;
  for (const auto& output : model_proto.graph().output()) {
    if (output.has_name() && output.has_type()) {
      named_outputs.push_back(&output);
    }
  }
  return named_outputs;
}
72

73
class ModelInfo {
74
 public:
75
  ModelInfo(const ONNX_NAMESPACE::ModelProto* model_proto) { Initialize(model_proto); }
76

77
 public:
78
  // model metadata
79
  std::string author_;
80
  std::string name_;
81
  std::string domain_;
82
  std::string description_;
83
  int64_t version_;
84
  std::vector<std::pair<std::string, std::string>> model_metadata_;
85
  std::vector<const ONNX_NAMESPACE::ValueInfoProto*> input_features_;
86
  std::vector<const ONNX_NAMESPACE::ValueInfoProto*> output_features_;
87
  bool requires_float16_support_;
88

89
 private:
90
  void Initialize(const ONNX_NAMESPACE::ModelProto* model_proto) {
91
    for (auto& prop : model_proto->metadata_props()) {
92
      model_metadata_.push_back(std::make_pair(prop.key(), prop.value()));
93
    }
94

95
    input_features_ = GetInputsWithoutInitializers(*model_proto);
96
    output_features_ = ::GetOutputs(*model_proto);
97

98
    auto has_producer_name = model_proto->has_producer_name();
99
    author_ = has_producer_name ? model_proto->producer_name() : "";
100

101
    auto has_domain = model_proto->has_domain();
102
    domain_ = has_domain ? model_proto->domain() : "";
103

104
    auto has_graph = model_proto->has_graph();
105
    auto graph_has_name = model_proto->graph().has_name();
106
    auto is_name_available = has_graph && graph_has_name;
107
    name_ = is_name_available ? model_proto->graph().name() : "";
108

109
    auto has_description = model_proto->has_doc_string();
110
    description_ = has_description ? model_proto->doc_string() : "";
111

112
    auto has_version = model_proto->has_model_version();
113
    version_ = has_version ? model_proto->model_version() : 0;
114
  }
115
};
116

117
// Takes ownership of the parsed ModelProto and eagerly builds the cached
// ModelInfo view from it.
// NOTE(review): relies on model_proto_ being initialized before model_info_
// (member declaration order in the header) so model_proto_.get() is valid
// here — confirm against the class declaration.
OrtModel::OrtModel(std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto)
  : model_proto_(std::move(model_proto)),
    model_info_(std::make_unique<ModelInfo>(model_proto_.get())) {
}
121

122
// factory methods for creating an ort model from a path
123
// factory methods for creating an ort model from a path
//
// Opens the model file at `path` and parses it into a ModelProto.
// On success `out` receives the parsed proto and nullptr is returned;
// otherwise an OrtStatus describing the failure is returned.
static OrtStatus* CreateModelProto(const char* path, std::unique_ptr<ONNX_NAMESPACE::ModelProto>& out) {
  int file_descriptor;

  auto path_str = std::string(path);
  auto wide_path = onnxruntime::ToWideString(path_str);

  _set_errno(0);  // clear errno
  _wsopen_s(
    &file_descriptor, wide_path.c_str(), O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE
  );

  errno_t err = 0;
  _get_errno(&err);
  if (err == ENOENT) {
    return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!");
  }

  if (0 > file_descriptor) {
    return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!");
  }

  // The stream owns the descriptor and closes it on destruction, including
  // the early-return parse-failure path below.
  google::protobuf::io::FileInputStream stream(file_descriptor);
  stream.SetCloseOnDelete(true);

  auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>();

  auto parse_succeeded = model_proto->ParseFromZeroCopyStream(&stream);
  if (!parse_succeeded) {
    return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model file!");
  }

  out = std::move(model_proto);

  // BUGFIX: this function's return type is OrtStatus*; success is nullptr
  // (as everywhere else in this file), not the COM S_OK macro.
  return nullptr;
}
158

159
// Creates a model containing no graph content: just an IR version and a
// single opset import pinned to `opset`.
OrtStatus* OrtModel::CreateEmptyModel(int64_t opset, OrtModel** model) {
  auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>();
  model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
  model_proto->add_opset_import()->set_version(opset);
  return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
}
166

167
// Loads and parses the model file at `path`, then wraps it in an OrtModel.
// `len` is unused: the path is NUL-terminated.
OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model) {
  ORT_UNUSED_PARAMETER(len);

  std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto;
  auto status = CreateModelProto(path, model_proto);
  if (status != nullptr) {
    return status;
  }

  return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
}
178

179
// Parses an in-memory serialized model of `len` bytes into an OrtModel.
OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** model) {
  auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>();

  // ParseFromArray takes an int length; the adapter API supplies size_t.
  if (!model_proto->ParseFromArray(data, static_cast<int>(len))) {
    return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model stream!");
  }

  return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model);
}
189

190
// Wraps an owned ModelProto in a heap-allocated OrtModel.
// Uses nothrow new so allocation failure surfaces as a status rather than
// an exception; success is reported as nullptr.
OrtStatus* OrtModel::CreateOrtModelFromProto(
  std::unique_ptr<ONNX_NAMESPACE::ModelProto>&& model_proto, OrtModel** model
) {
  *model = new (std::nothrow) OrtModel(std::move(model_proto));
  if (*model != nullptr) {
    return nullptr;
  }
  return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Engine failed to create a model!");
}
200

201
// Borrows the cached metadata view; the pointer stays owned by this
// OrtModel and remains valid for its lifetime.
const ModelInfo* OrtModel::UseModelInfo() const {
  return model_info_.get();
}
204

205
// Borrows the underlying ModelProto; ownership stays with this OrtModel.
ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const {
  return model_proto_.get();
}
208

209
// Transfers ownership of the ModelProto to the caller, leaving this
// OrtModel's proto pointer null (subsequent UseModelProto() returns null).
std::unique_ptr<ONNX_NAMESPACE::ModelProto> OrtModel::DetachModelProto() {
  return std::move(model_proto_);
}
212

213
void OrtModel::RefreshModelInfo() {
214
  auto new_info = std::make_unique<ModelInfo>(model_proto_.get());
215
  model_info_->author_ = std::move(new_info->author_);
216
  model_info_->description_ = std::move(new_info->description_);
217
  model_info_->domain_ = std::move(new_info->domain_);
218
  model_info_->input_features_ = std::move(new_info->input_features_);
219
  model_info_->model_metadata_ = std::move(new_info->model_metadata_);
220
  model_info_->name_ = std::move(new_info->name_);
221
  model_info_->output_features_ = std::move(new_info->output_features_);
222
  model_info_->requires_float16_support_ = std::move(new_info->requires_float16_support_);
223
  model_info_->version_ = std::move(new_info->version_);
224
}
225

226
// C-API wrapper over OrtModel::CreateOrtModelFromPath.
ORT_API_STATUS_IMPL(
  winmla::CreateModelFromPath, _In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out
) {
  API_IMPL_BEGIN
  return OrtModel::CreateOrtModelFromPath(model_path, size, out);
  API_IMPL_END
}
236

237
// C-API wrapper over OrtModel::CreateOrtModelFromData.
ORT_API_STATUS_IMPL(winmla::CreateModelFromData, _In_opt_ void* data, _In_ size_t size, _Outptr_ OrtModel** out) {
  API_IMPL_BEGIN
  return OrtModel::CreateOrtModelFromData(data, size, out);
  API_IMPL_END
}
245

246
// Deep-copies the source model's proto; the clone shares no mutable state
// with the original.
ORT_API_STATUS_IMPL(winmla::CloneModel, _In_ const OrtModel* in, _Outptr_ OrtModel** out) {
  API_IMPL_BEGIN
  auto proto_copy = std::make_unique<ONNX_NAMESPACE::ModelProto>(*in->UseModelProto());
  return OrtModel::CreateOrtModelFromProto(std::move(proto_copy), out);
  API_IMPL_END
}
255

256
// Serializes the model to `file_name`.
// BUGFIX: the Status returned by FileOpenWr was previously ignored, and the
// file descriptor leaked on the serialize-failure early return. The stream
// now owns the descriptor (SetCloseOnDelete) so every exit path closes it.
ORT_API_STATUS_IMPL(winmla::SaveModel, _In_ const OrtModel* in, _In_ const wchar_t* const file_name, _In_ size_t len) {
  API_IMPL_BEGIN
  ORT_UNUSED_PARAMETER(len);  // file_name is NUL-terminated
  int fd;
  std::wstring file_path = file_name;
  Status status = onnxruntime::Env::Default().FileOpenWr(file_path, fd);
  if (!status.IsOK() || fd < 0) {
    return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "File not found!");
  }

  auto model_proto = in->UseModelProto();
  google::protobuf::io::FileOutputStream output(fd);
  output.SetCloseOnDelete(true);  // close fd on destruction, even on error
  const bool success = model_proto->SerializeToZeroCopyStream(&output) && output.Flush();
  if (!success) {
    return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Failed to serialize model!");
  }
  return nullptr;
  API_IMPL_END
}
275

276
// Returns a non-owning pointer to the model's author string and its length.
// The string stays valid for the lifetime of the OrtModel.
ORT_API_STATUS_IMPL(
  winmla::ModelGetAuthor, _In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len
) {
  API_IMPL_BEGIN
  const auto& author_str = model->UseModelInfo()->author_;
  *author = author_str.c_str();
  *len = author_str.size();
  return nullptr;
  API_IMPL_END
}
285

286
// Returns a non-owning pointer to the model's (graph) name and its length.
ORT_API_STATUS_IMPL(
  winmla::ModelGetName, _In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len
) {
  API_IMPL_BEGIN
  const auto& name_str = model->UseModelInfo()->name_;
  *name = name_str.c_str();
  *len = name_str.size();
  return nullptr;
  API_IMPL_END
}
295

296
// Writes `name` into the proto's graph-level name field. The cached
// ModelInfo is not refreshed here.
ORT_API_STATUS_IMPL(winmla::ModelSetName, _In_ const OrtModel* model, _In_ const char* const name) {
  API_IMPL_BEGIN
  model->UseModelProto()->mutable_graph()->set_name(name);
  return nullptr;
  API_IMPL_END
}
304

305
// Returns a non-owning pointer to the model's domain string and its length.
ORT_API_STATUS_IMPL(
  winmla::ModelGetDomain, _In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len
) {
  API_IMPL_BEGIN
  const auto& domain_str = model->UseModelInfo()->domain_;
  *domain = domain_str.c_str();
  *len = domain_str.size();
  return nullptr;
  API_IMPL_END
}
314

315
// Returns a non-owning pointer to the model's description and its length.
ORT_API_STATUS_IMPL(
  winmla::ModelGetDescription, _In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len
) {
  API_IMPL_BEGIN
  const auto& description_str = model->UseModelInfo()->description_;
  *description = description_str.c_str();
  *len = description_str.size();
  return nullptr;
  API_IMPL_END
}
324

325
// Reports the model version (0 when the proto carried none).
ORT_API_STATUS_IMPL(winmla::ModelGetVersion, _In_ const OrtModel* model, _Out_ int64_t* version) {
  API_IMPL_BEGIN
  const auto* info = model->UseModelInfo();
  *version = info->version_;
  return nullptr;
  API_IMPL_END
}
331

332
// Reports the number of key/value metadata entries on the model.
ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, _In_ const OrtModel* model, _Out_ size_t* count) {
  API_IMPL_BEGIN
  const auto* info = model->UseModelInfo();
  *count = info->model_metadata_.size();
  return nullptr;
  API_IMPL_END
}
338

339
// Returns non-owning pointers to the key and value of metadata entry
// `count` (an index). No bounds check is performed here — callers are
// expected to stay below ModelGetMetadataCount.
ORT_API_STATUS_IMPL(
  winmla::ModelGetMetadata,
  _In_ const OrtModel* model,
  _In_ size_t count,
  _Out_ const char** const key,
  _Out_ size_t* key_len,
  _Out_ const char** const value,
  _Out_ size_t* value_len
) {
  API_IMPL_BEGIN
  const auto& entry = model->UseModelInfo()->model_metadata_[count];
  *key = entry.first.c_str();
  *key_len = entry.first.size();
  *value = entry.second.c_str();
  *value_len = entry.second.size();
  return nullptr;
  API_IMPL_END
}
356

357
// Reports the number of runtime (non-initializer) inputs.
ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, _In_ const OrtModel* model, _Out_ size_t* count) {
  API_IMPL_BEGIN
  const auto* info = model->UseModelInfo();
  *count = info->input_features_.size();
  return nullptr;
  API_IMPL_END
}
363

364
// Reports the number of named, typed outputs.
ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, _In_ const OrtModel* model, _Out_ size_t* count) {
  API_IMPL_BEGIN
  const auto* info = model->UseModelInfo();
  *count = info->output_features_.size();
  return nullptr;
  API_IMPL_END
}
370

371
// Returns a non-owning pointer to the name of input `index` and its length.
// No bounds check is performed here.
ORT_API_STATUS_IMPL(
  winmla::ModelGetInputName,
  _In_ const OrtModel* model,
  _In_ size_t index,
  _Out_ const char** input_name,
  _Out_ size_t* count
) {
  API_IMPL_BEGIN
  const auto& feature_name = model->UseModelInfo()->input_features_[index]->name();
  *input_name = feature_name.c_str();
  *count = feature_name.size();
  return nullptr;
  API_IMPL_END
}
384

385
// Returns a non-owning pointer to the name of output `index` and its
// length. No bounds check is performed here.
ORT_API_STATUS_IMPL(
  winmla::ModelGetOutputName,
  _In_ const OrtModel* model,
  _In_ size_t index,
  _Out_ const char** output_name,
  _Out_ size_t* count
) {
  API_IMPL_BEGIN
  const auto& feature_name = model->UseModelInfo()->output_features_[index]->name();
  *output_name = feature_name.c_str();
  *count = feature_name.size();
  return nullptr;
  API_IMPL_END
}
398

399
// Returns a non-owning pointer to the doc string of input `index` and its
// length. No bounds check is performed here.
ORT_API_STATUS_IMPL(
  winmla::ModelGetInputDescription,
  _In_ const OrtModel* model,
  _In_ size_t index,
  _Out_ const char** input_description,
  _Out_ size_t* count
) {
  API_IMPL_BEGIN
  const auto& doc = model->UseModelInfo()->input_features_[index]->doc_string();
  *input_description = doc.c_str();
  *count = doc.size();
  return nullptr;
  API_IMPL_END
}
412

413
// Returns a non-owning pointer to the doc string of output `index` and its
// length. No bounds check is performed here.
ORT_API_STATUS_IMPL(
  winmla::ModelGetOutputDescription,
  _In_ const OrtModel* model,
  _In_ size_t index,
  _Out_ const char** output_description,
  _Out_ size_t* count
) {
  API_IMPL_BEGIN
  const auto& doc = model->UseModelInfo()->output_features_[index]->doc_string();
  *output_description = doc.c_str();
  *count = doc.size();
  return nullptr;
  API_IMPL_END
}
426

427
// Builds an OrtTypeInfo for input `index`. The caller owns the returned
// object and must release it through the ORT API.
ORT_API_STATUS_IMPL(
  winmla::ModelGetInputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
) {
  API_IMPL_BEGIN
  *type_info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->input_features_[index]->type()).release();
  return nullptr;
  API_IMPL_END
}
436

437
// Builds an OrtTypeInfo for output `index`. The caller owns the returned
// object and must release it through the ORT API.
ORT_API_STATUS_IMPL(
  winmla::ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
) {
  API_IMPL_BEGIN
  *type_info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->output_features_[index]->type()).release();
  return nullptr;
  API_IMPL_END
}
446

447
// Verifies the model is free of float16 tensors anywhere one could be
// introduced (runtime inputs, Cast-to-fp16 nodes, initializers, outputs).
// Returns ORT_INVALID_GRAPH on the first offending tensor, nullptr when
// the model contains no fp16.
ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, _In_ const OrtModel* model) {
  API_IMPL_BEGIN
  auto model_info = model->UseModelInfo();
  auto model_proto = model->UseModelProto();
  auto& graph = model_proto->graph();

  // The model will not contain fp16 operations if:
  // 1. The model has no fp16 inputs
  // 2. The model has no fp16 initializers
  // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator
  // 4. The model does not have any fp16 outputs

  // 1. Ensure that the model has no fp16 inputs
  for (auto input : model_info->input_features_) {
    auto& type = input->type();
    if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
      auto& tensor_type = type.tensor_type();
      // BUGFIX: check FLOAT16 — the type this function guards against —
      // not BFLOAT16, matching the Cast/initializer checks below.
      if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
        std::stringstream error_message;
        error_message << "The model contains a 16-bit input (" << input->name()
                      << "), but the current device does not support 16-bit float.";
        return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
      }
    }
  }

  // 3. Ensure that the model does not create any fp16 intermediary
  //    tensors via the Cast (to float16) operator
  for (int i = 0; i < graph.node_size(); i++) {
    const auto& node = graph.node(i);  // const ref: avoid copying each NodeProto
    if (node.op_type() == "Cast" && node.domain().empty()) {
      for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) {
        const auto& attribute = node.attribute(attribIndex);
        if (attribute.name() == "to") {
          if (attribute.i() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) {
            std::stringstream error_message;
            error_message << "The model contains a 16-bit input (" << node.name().c_str()
                          << "), but the current device does not support 16-bit float.";
            return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
          }
        }
      }
    }
  }

  // 2. Ensure that the model has no fp16 initializers
  for (int i = 0; i < graph.initializer_size(); i++) {
    const auto& initializer = graph.initializer(i);
    if (initializer.data_type() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) {
      std::stringstream error_message;
      error_message << "The model contains a 16-bit input (" << initializer.name().c_str()
                    << "), but the current device does not support 16-bit float.";
      return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
    }
  }

  // 4. Ensure that the model does not have any fp16 outputs
  for (auto output : model_info->output_features_) {
    auto& type = output->type();
    if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
      auto& tensor_type = type.tensor_type();
      // BUGFIX: FLOAT16, not BFLOAT16 (see the input check above).
      if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
        std::stringstream error_message;
        error_message << "The model contains a 16-bit input (" << output->name()
                      << "), but the current device does not support 16-bit float.";
        return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str());
      }
    }
  }
  return nullptr;
  API_IMPL_END
}
519

520
// C-API wrapper: creates an empty model pinned to `opset`.
ORT_API_STATUS_IMPL(winmla::CreateModel, _In_ int64_t opset, _Outptr_ OrtModel** out) {
  API_IMPL_BEGIN
  return OrtModel::CreateEmptyModel(opset, out);
  API_IMPL_END
}
525

526
// Maps the public ONNXTensorElementDataType enum onto the corresponding
// ONNX TensorProto wire type. Any element type not listed in the table
// maps to UNDEFINED.
static ONNX_NAMESPACE::TensorProto_DataType ONNXTensorElementDataTypeToTensorProto_DataType(
  ONNXTensorElementDataType type
) {
  struct Mapping {
    ONNXTensorElementDataType element_type;
    ONNX_NAMESPACE::TensorProto_DataType proto_type;
  };
  static const Mapping kMappings[] = {
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_FLOAT},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, ONNX_NAMESPACE::TensorProto_DataType_UINT8},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, ONNX_NAMESPACE::TensorProto_DataType_INT8},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, ONNX_NAMESPACE::TensorProto_DataType_UINT16},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, ONNX_NAMESPACE::TensorProto_DataType_INT16},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, ONNX_NAMESPACE::TensorProto_DataType_INT32},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, ONNX_NAMESPACE::TensorProto_DataType_INT64},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, ONNX_NAMESPACE::TensorProto_DataType_STRING},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, ONNX_NAMESPACE::TensorProto_DataType_BOOL},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, ONNX_NAMESPACE::TensorProto_DataType_DOUBLE},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, ONNX_NAMESPACE::TensorProto_DataType_UINT32},
    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, ONNX_NAMESPACE::TensorProto_DataType_UINT64},
  };
  for (const auto& mapping : kMappings) {
    if (mapping.element_type == type) {
      return mapping.proto_type;
    }
  }
  return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
}
560

561
static void CreateTypeProto_Tensor(
562
  ONNX_NAMESPACE::TypeProto_Tensor* mutable_tensor_type,
563
  const char* const name,
564
  const int64_t* shape,
565
  size_t shape_len,
566
  ONNX_NAMESPACE::TensorProto_DataType data_type
567
) {
568
  mutable_tensor_type->set_elem_type(data_type);
569

570
  size_t dim_param = 0;
571
  for (size_t i = 0; i < shape_len; i++) {
572
    if (shape[i] == -1) {
573
      std::ostringstream str;
574
      str << name << dim_param++;
575
      mutable_tensor_type->mutable_shape()->add_dim()->set_dim_param(str.str().c_str(), 1);
576
    } else {
577
      mutable_tensor_type->mutable_shape()->add_dim()->set_dim_value(shape[i]);
578
    }
579
  }
580

581
  if (shape_len > 0) {
582
    mutable_tensor_type->mutable_shape()->mutable_dim(0)->set_denotation("DATA_BATCH");
583
  }
584
}
585

586
// Appends a graph input named `input_name`. For tensor types the element
// type and shape are propagated into the proto; rank-0 shapes pass a null
// dims pointer.
ORT_API_STATUS_IMPL(
  winmla::ModelAddInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info
) {
  API_IMPL_BEGIN
  auto& graph = *model->UseModelProto()->mutable_graph();
  auto& input = *graph.add_input();
  input.set_name(input_name);

  if (info->type == ONNXType::ONNX_TYPE_TENSOR) {
    const auto num_dims = info->data->shape.NumDimensions();
    const int64_t* dims = (num_dims == 0) ? nullptr : &info->data->shape[0];
    CreateTypeProto_Tensor(
      input.mutable_type()->mutable_tensor_type(),
      input_name,
      dims,
      num_dims,
      ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
    );
  }
  return nullptr;
  API_IMPL_END
}
608

609
// Adds a constant input: stored as a graph initializer (not a runtime-fed
// graph input), with the tensor payload copied verbatim into the proto.
ORT_API_STATUS_IMPL(
  winmla::ModelAddConstantInput,
  _In_ OrtModel* model,
  _In_ const char* const input_name,
  _In_ OrtTypeInfo* info,
  _In_ OrtValue* value
) {
  API_IMPL_BEGIN
  auto& graph = *model->UseModelProto()->mutable_graph();
  auto& initializer = *graph.add_initializer();
  initializer.set_name(input_name);

  const auto& shape = info->data->shape;
  for (size_t dim = 0; dim < shape.NumDimensions(); dim++) {
    initializer.add_dims(shape[dim]);
  }

  initializer.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type));
  auto tensor = value->GetMutable<onnxruntime::Tensor>();
  initializer.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes());

  return nullptr;
  API_IMPL_END
}
634

635
// Appends a graph output named `output_name`, propagating tensor element
// type and shape.
// BUGFIX: guard the rank-0 case before taking &shape[0] — indexing an
// empty shape is undefined behavior. This mirrors ModelAddInput, which
// already passed nullptr for zero-dimensional shapes.
ORT_API_STATUS_IMPL(
  winmla::ModelAddOutput, _In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info
) {
  API_IMPL_BEGIN
  auto model_proto = model->UseModelProto();
  ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
  ONNX_NAMESPACE::ValueInfoProto& output = *graph.add_output();
  output.set_name(output_name);

  if (info->type == ONNXType::ONNX_TYPE_TENSOR) {
    auto num_dims = info->data->shape.NumDimensions();
    CreateTypeProto_Tensor(
      output.mutable_type()->mutable_tensor_type(),
      output_name,
      (num_dims == 0) ? nullptr : &info->data->shape[0],
      num_dims,
      ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
    );
  }
  return nullptr;
  API_IMPL_END
}
656

657
// Looks up the ONNX operator schema for (op_type, opset, domain).
// A null op_domain selects the default ONNX domain. May return nullptr
// when no schema is registered for the requested op/version.
static const onnx::OpSchema* GetSchema(const char* const op_type, int64_t opset, const char* const op_domain) {
  const std::string domain = (op_domain != nullptr) ? op_domain : onnx::ONNX_DOMAIN;
  return ONNX_NAMESPACE::OpSchemaRegistry::Instance()->GetSchema(op_type, static_cast<int>(opset), domain);
}
666

667
// Appends a node to the graph for operator `op_type` and wires up its
// attributes, inputs, and outputs. Attribute values arrive as OrtValue
// tensors and are converted according to the schema-declared attribute
// type; attribute types without a case below are silently skipped (only
// the name/type are recorded), matching the original behavior.
ORT_API_STATUS_IMPL(
  winmla::ModelAddOperator,
  _In_ OrtModel* model,
  _In_ const char* const op_type,
  _In_ const char* const op_name,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _In_ const char* const* input_names,
  _In_ size_t num_inputs,
  _In_ const char* const* output_names,
  _In_ size_t num_outputs,
  _In_ const char* const* attribute_names,
  _In_ OrtValue** attribute_values,
  _In_ size_t num_attributes
) {
  API_IMPL_BEGIN
  auto model_proto = model->UseModelProto();
  ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph();
  onnx::NodeProto& node = *graph.add_node();
  node.set_op_type(op_type);
  node.set_name(op_name);
  node.set_domain(op_domain);

  // BUGFIX: GetSchema may return nullptr for an unregistered op; report it
  // instead of dereferencing a null schema below.
  auto schema = GetSchema(op_type, opset, op_domain);
  if (schema == nullptr) {
    return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Failed to find schema for operator!");
  }
  // const ref: avoid copying the schema's entire attribute map.
  const auto& all_attributes = schema->attributes();

  for (size_t i = 0; i < num_attributes; i++) {
    auto tensor = attribute_values[i]->GetMutable<onnxruntime::Tensor>();

    auto attr = node.add_attribute();
    attr->set_name(attribute_names[i]);
    auto& schema_attribute_definition = all_attributes.at(attribute_names[i]);
    attr->set_type(schema_attribute_definition.type);

    switch (schema_attribute_definition.type) {
      case onnx::AttributeProto_AttributeType_INT: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single int64 value!");
        }
        auto raw_data = tensor->DataRaw();
        attr->set_i(*reinterpret_cast<const int64_t*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_FLOAT: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single float value!");
        }
        auto raw_data = tensor->DataRaw();
        attr->set_f(*reinterpret_cast<const float*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_STRING: {
        if (tensor->Shape().Size() != 1) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single string value!");
        }
        auto raw_data = tensor->DataRaw();
        attr->set_s(*reinterpret_cast<const std::string*>(raw_data));
        break;
      }
      case onnx::AttributeProto_AttributeType_INTS: {
        auto raw_data = tensor->DataRaw();
        // int64_t counter: Shape().Size() is 64-bit.
        for (int64_t j = 0; j < tensor->Shape().Size(); j++) {
          attr->add_ints(*(reinterpret_cast<const int64_t*>(raw_data) + j));
        }
        break;
      }
      case onnx::AttributeProto_AttributeType_FLOATS: {
        auto raw_data = tensor->DataRaw();
        for (int64_t j = 0; j < tensor->Shape().Size(); j++) {
          attr->add_floats(*(reinterpret_cast<const float*>(raw_data) + j));
        }
        break;
      }
      case onnx::AttributeProto_AttributeType_TENSOR: {
        auto tensor_proto = attr->add_tensors();
        auto prim_type = tensor->DataType()->AsPrimitiveDataType();
        if (prim_type == nullptr) {
          return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Undefined tensor type!");
        }
        tensor_proto->set_data_type(prim_type->GetDataType());
        tensor_proto->set_raw_data(tensor->DataRaw(), tensor->SizeInBytes());
        break;
      }
      default:
        // Unsupported attribute kinds are left with name/type only.
        break;
    }
  }

  // A null input name marks an omitted optional input (empty slot).
  for (size_t i = 0; i < num_inputs; i++) {
    auto name = input_names[i];
    if (name != nullptr) {
      node.add_input(name);
    } else {
      node.add_input();
    }
  }

  // A null output name marks an output the caller does not consume.
  for (size_t i = 0; i < num_outputs; i++) {
    auto name = output_names[i];
    if (name != nullptr) {
      node.add_output(name);
    } else {
      node.add_output("unused");
    }
  }
  return nullptr;
  API_IMPL_END
}
773

774
// Reports the opset version imported for `domain`, or -1 when no import
// matches.
ORT_API_STATUS_IMPL(
  winmla::ModelGetOpsetVersion, _In_ OrtModel* model, _In_ const char* const domain, _Out_ int32_t* version
) {
  API_IMPL_BEGIN
  *version = -1;

  auto model_proto = model->UseModelProto();
  const int num_imports = static_cast<int>(model_proto->opset_import_size());
  for (int i = 0; i < num_imports; i++) {
    const auto& opset = model_proto->opset_import(i);
    const auto& opset_domain = opset.domain();
    // Case-insensitive comparison over the stored domain's length; note an
    // empty stored domain (the default ONNX domain) matches any query.
    if (_strnicmp(domain, opset_domain.c_str(), opset_domain.size()) == 0) {
      *version = static_cast<int32_t>(opset.version());
      break;
    }
  }

  return nullptr;
  API_IMPL_END
}
794

795
// Destroys an OrtModel previously created by one of the factory APIs
// above. Safe on nullptr (delete of null is a no-op).
ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) {
  delete ptr;
}
798

799
#include "core/framework/onnxruntime_typeinfo.h"
800
#include "core/framework/tensor_type_and_shape.h"
801

802
// Builds a tensor OrtTypeInfo from raw dimension values and an element
// type. The caller owns the returned object.
ORT_API_STATUS_IMPL(
  winmla::CreateTensorTypeInfo,
  _In_ const int64_t* dim_values,
  size_t dim_count,
  ONNXTensorElementDataType type,
  _Out_ OrtTypeInfo** ort_type_info
) {
  API_IMPL_BEGIN
  onnxruntime::TensorShape shape(dim_values, dim_count);
  auto shape_and_type = OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr);
  *ort_type_info = OrtTypeInfo::MakePtr(ONNX_TYPE_TENSOR, std::move(shape_and_type)).release();
  return nullptr;
  API_IMPL_END
}
816

817
// Stub: reports success without producing a type info.
// NOTE(review): *type_info is never assigned, so the _Out_ parameter is
// left unset — callers must not read it. Confirm this stub is intentional.
ORT_API_STATUS_IMPL(winmla::CreateSequenceTypeInfo, _Out_ OrtTypeInfo** type_info) {
  API_IMPL_BEGIN
  return nullptr;
  API_IMPL_END
}
822

823
// Stub: reports success without producing a type info.
// NOTE(review): *type_info is never assigned, so the _Out_ parameter is
// left unset — callers must not read it. Confirm this stub is intentional.
ORT_API_STATUS_IMPL(winmla::CreateMapTypeInfo, _Out_ OrtTypeInfo** type_info) {
  API_IMPL_BEGIN
  return nullptr;
  API_IMPL_END
}
828

829
// Reports the number of formal inputs in the operator's schema.
// NOTE(review): assumes the schema lookup succeeds; GetSchema can return
// nullptr for unknown ops — confirm callers only query registered ops.
ORT_API_STATUS_IMPL(
  winmla::OperatorGetNumInputs,
  _In_ const char* const op_type,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _Out_ size_t* num_inputs
) {
  API_IMPL_BEGIN
  const auto* schema = GetSchema(op_type, opset, op_domain);
  *num_inputs = schema->inputs().size();
  return nullptr;
  API_IMPL_END
}
842

843
// Returns a non-owning pointer to the schema name of formal input `index`.
// NOTE(review): assumes the schema lookup succeeds; GetSchema can return
// nullptr for unknown ops — confirm callers only query registered ops.
ORT_API_STATUS_IMPL(
  winmla::OperatorGetInputName,
  _In_ const char* const op_type,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _In_ size_t index,
  _Out_ const char** const name
) {
  API_IMPL_BEGIN
  const auto* schema = GetSchema(op_type, opset, op_domain);
  *name = schema->inputs().at(index).GetName().c_str();
  return nullptr;
  API_IMPL_END
}
857

858
// Reports the number of formal outputs in the operator's schema.
// NOTE(review): assumes the schema lookup succeeds; GetSchema can return
// nullptr for unknown ops — confirm callers only query registered ops.
ORT_API_STATUS_IMPL(
  winmla::OperatorGetNumOutputs,
  _In_ const char* const op_type,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _Out_ size_t* num_outputs
) {
  API_IMPL_BEGIN
  const auto* schema = GetSchema(op_type, opset, op_domain);
  *num_outputs = schema->outputs().size();
  return nullptr;
  API_IMPL_END
}
871

872
// Returns a non-owning pointer to the schema name of formal output `index`.
// NOTE(review): assumes the schema lookup succeeds; GetSchema can return
// nullptr for unknown ops — confirm callers only query registered ops.
ORT_API_STATUS_IMPL(
  winmla::OperatorGetOutputName,
  _In_ const char* const op_type,
  _In_ int64_t opset,
  _In_ const char* const op_domain,
  _In_ size_t index,
  _Out_ const char** const name
) {
  API_IMPL_BEGIN
  const auto* schema = GetSchema(op_type, opset, op_domain);
  *name = schema->outputs().at(index).GetName().c_str();
  return nullptr;
  API_IMPL_END
}
886
#include "core/platform/threadpool.h"
887
#include "core/platform/env.h"
888

889
ORT_API_STATUS_IMPL(
  winmla::CreateThreadPool, _In_ ThreadPoolType type, _In_ OrtThreadPoolOptions* options, _Outptr_ OrtThreadPool** out
) {
  API_IMPL_BEGIN
  // Translate the WinML-facing options struct into onnxruntime's thread pool
  // parameter struct, field by field.
  OrtThreadPoolParams tp_params = {};
  tp_params.name = options->name;
  tp_params.thread_pool_size = options->thread_pool_size;
  tp_params.stack_size = options->stack_size;
  tp_params.allow_spinning = options->allow_spinning;
  tp_params.auto_set_affinity = options->auto_set_affinity;
  tp_params.set_denormal_as_zero = options->set_denormal_as_zero;
  tp_params.dynamic_block_base_ = options->dynamic_block_base_;

  // Build the pool and transfer ownership to the caller through the opaque
  // OrtThreadPool handle (released later via winmla::ReleaseThreadPool).
  auto pool = onnxruntime::concurrency::CreateThreadPool(
    &onnxruntime::Env::Default(), tp_params, static_cast<onnxruntime::concurrency::ThreadPoolType>(type)
  );
  *out = reinterpret_cast<OrtThreadPool*>(pool.release());
  return nullptr;
  API_IMPL_END
}
909

910
ORT_API(void, winmla::ReleaseThreadPool, OrtThreadPool* ptr) {
  // The opaque OrtThreadPool handle wraps an onnxruntime ThreadPool (see the
  // reinterpret_cast in winmla::CreateThreadPool); undo the cast and delete.
  delete reinterpret_cast<onnxruntime::concurrency::ThreadPool*>(ptr);
}
913

914
ORT_API_STATUS_IMPL(
  winmla::JoinModels,
  _In_ OrtModel* first_model,
  _In_ OrtModel* second_model,
  _In_ const char* const* output_names,
  _In_ const char* const* input_names,
  size_t num_linkages,
  bool promote_unlinked_outputs,
  _In_ const char* const join_node_prefix
) {
  API_IMPL_BEGIN
  // Fuses second_model into first_model in place. Linkage i connects the
  // first-model output output_names[i] to the second-model input
  // input_names[i] (via an Identity node added at the end). Every name
  // imported from the second model is prefixed with join_node_prefix to avoid
  // collisions with first-model names.

  std::string second_model_prefix = join_node_prefix;
  auto first_model_proto = first_model->UseModelProto();
  auto second_model_proto = second_model->DetachModelProto();

  // Remove old outputs
  if (promote_unlinked_outputs) {
    // Copy the output of the first model
    auto first_outputs = first_model_proto->graph().output();

    // Clear all outputs
    first_model_proto->mutable_graph()->mutable_output()->Clear();

    // Add back output
    // NOTE(review): this walks the saved outputs in reverse index order, so
    // promoted outputs are re-added in reverse relative to the original model
    // — confirm output ordering is not significant to callers.
    for (int i = first_outputs.size() - 1; i >= 0; i--) {
      auto& output = first_outputs.at(i);
      auto output_name = output.name();

      auto found_it = std::find_if(output_names, output_names + num_linkages, [output_name](auto& name) {
        return std::strcmp(name, output_name.c_str()) == 0;
      });
      if (found_it == (output_names + num_linkages)) {
        // if output.name() is not found in the linkages, it is unlinked, and it should be promoted
        auto& promoted_output = *first_model_proto->mutable_graph()->add_output();
        promoted_output = std::move(output);
      }
    }
  } else {
    // remove all first model outputs
    first_model_proto->mutable_graph()->mutable_output()->Clear();
  }

  // add all model outputs from the second model (renamed with the prefix)
  for (int i = 0; i < second_model_proto->graph().output_size(); i++) {
    auto& other_output = *second_model_proto->mutable_graph()->mutable_output(i);
    *other_output.mutable_name() = second_model_prefix + other_output.name();
    auto& output = *first_model_proto->mutable_graph()->add_output();
    output = std::move(other_output);
  }

  // loop through second model inputs and promote the unlinked ones to the main model inputs
  for (int i = 0; i < second_model_proto->graph().input_size(); i++) {
    auto& other_input = *second_model_proto->mutable_graph()->mutable_input(i);
    auto old_name = other_input.name();
    *other_input.mutable_name() = second_model_prefix + old_name;

    auto found_it = std::find_if(input_names, input_names + num_linkages, [old_name](auto& name) {
      return std::strcmp(name, old_name.c_str()) == 0;
    });
    bool is_linked =
      found_it != (input_names + num_linkages);  // true when this input's original name appears in input_names
    if (!is_linked) {
      // Unlinked second-model inputs become inputs of the joined model.
      auto& input = *first_model_proto->mutable_graph()->add_input();
      input = std::move(other_input);
    }
  }

  // add all initializers (renamed with the prefix)
  for (int i = 0; i < second_model_proto->graph().initializer_size(); i++) {
    auto& other_initializer = *second_model_proto->mutable_graph()->mutable_initializer(i);
    *other_initializer.mutable_name() = second_model_prefix + other_initializer.name();
    auto& initializer = *first_model_proto->mutable_graph()->add_initializer();
    initializer = std::move(other_initializer);
  }

  // add all nodes, prefixing the node name (when non-empty) and every input
  // and output edge name so they keep referring to the renamed tensors above
  for (int i = 0; i < second_model_proto->graph().node_size(); i++) {
    auto& other_node = *second_model_proto->mutable_graph()->mutable_node(i);
    if (0 != strcmp(other_node.name().c_str(), "")) {
      *other_node.mutable_name() = second_model_prefix + other_node.name();
    }
    for (int j = 0; j < other_node.input_size(); j++) {
      *other_node.mutable_input(j) = second_model_prefix + other_node.input(j);
    }
    for (int j = 0; j < other_node.output_size(); j++) {
      *other_node.mutable_output(j) = second_model_prefix + other_node.output(j);
    }
    auto& node = *first_model_proto->mutable_graph()->add_node();
    node = std::move(other_node);
  }

  // WinML+RT API only supports opset 7 and above models.
  // In practice this number is always overwritten by the for loop below which will find the actual opset version.
  int64_t opset = 7;
  // Merge opset imports: for each domain present in both models, keep the
  // higher version in the first model; track the default ("") domain's final
  // version to stamp onto the Identity nodes added below.
  for (int i = 0; i < second_model_proto->opset_import_size(); i++) {
    auto mutable_opset_import = second_model_proto->mutable_opset_import(i);
    auto domain = mutable_opset_import->has_domain() ? mutable_opset_import->domain() : std::string("");
    auto version = mutable_opset_import->version();

    // does the domain exist in the first model?
    auto found_it = std::find_if(
      first_model_proto->mutable_opset_import()->begin(),
      first_model_proto->mutable_opset_import()->end(),
      [&domain](auto& mutable_opset_import) {
        auto first_model_domain = mutable_opset_import.has_domain() ? mutable_opset_import.domain() : std::string("");
        return 0 == strcmp(first_model_domain.c_str(), domain.c_str());
      }
    );
    if (found_it != first_model_proto->mutable_opset_import()->end()) {
      found_it->set_version(std::max(found_it->version(), version));
      if (0 == strcmp(domain.c_str(), "")) {
        opset = found_it->version();
      }
    }
  }

  // add identity ops to rename all of the first model outputs to secondmodel inputs with prefix for each linkage
  for (size_t i = 0; i < num_linkages; i++) {
    auto op_output_name = second_model_prefix + *(input_names + i);
    const char* const op_output_name_const_str = op_output_name.c_str();
    std::string name = "IdentityTo";
    name += second_model_prefix + *(input_names + i);
    // NOTE(review): the status returned by ModelAddOperator is discarded — a
    // failure to add a linkage node would go unnoticed; confirm intentional.
    ModelAddOperator(
      first_model,
      "Identity",
      name.c_str(),
      opset,
      "",
      (output_names + i),
      1,
      &op_output_name_const_str,
      1,
      nullptr,
      nullptr,
      0
    );
  }
  // Rebuild first_model's cached metadata from the mutated proto.
  first_model->RefreshModelInfo();

  return nullptr;
  API_IMPL_END
}
1057
