onnxruntime

Форк
0
709 строк · 27.8 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
#include "testPch.h"
5
#include <wil/result.h>
6
#include <D3d11_4.h>
7
#include <dxgi1_6.h>
8
#include "filehelpers.h"
9
#include <fstream>
10
#include <MemoryBuffer.h>
11
#include "CustomOperatorProvider.h"
12
#include "CustomOps.h"
13

14
// For custom operator and shape inferencing support
15
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
16
#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h"
17
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
18
#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
19
#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
20
#include "core/graph/constants.h"
21
#include "CustomNullOp.h"
22
#include <wil/wrl.h>
23

24
using namespace winml;
25
using namespace wfc;
26
using namespace wm;
27
using namespace wgi;
28
using namespace ws;
29
using namespace wss;
30

31
// One-time setup for this test class: initializes the WinRT apartment for the test
// thread. For inbox (OS-shipped) builds, routes WinRT activation through the OS
// activation factory instead of the side-loaded one.
static void CustomOpsScenarioTestsClassSetup() {
  winrt::init_apartment();
#ifdef BUILD_INBOX
  winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
}
37

38
// Tests that the execution provider correctly fuses operators together when custom ops are involved.
39
static void CustomOperatorFusion() {
40
  constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx";
41

42
  // This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them
43
  // all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same
44
  // goes for the single Gemm+Sigmoid in the model too.
45
  constexpr const uint32_t c_expectedConvOps = 0;
46
  constexpr const uint32_t c_expectedReluOps = 0;
47
  constexpr const uint32_t c_expectedFusedConvOps = 25;
48
  constexpr const uint32_t c_expectedGemmOps = 0;
49
  constexpr const uint32_t c_expectedSigmoidOps = 0;
50
  constexpr const uint32_t c_expectedFusedGemmOps = 1;
51

52
  // These ops are also part of the model but shouldn't be fused
53
  constexpr const uint32_t c_expectedBatchNormOps = 1;
54
  constexpr const uint32_t c_expectedMaxPoolOps = 3;
55
  constexpr const uint32_t c_expectedConcatOps = 8;
56

57
  struct CallbackOperatorProvider : winrt::implements<
58
                                      CallbackOperatorProvider,
59
                                      winml::ILearningModelOperatorProvider,
60
                                      ILearningModelOperatorProviderNative> {
61
    struct CallCounts {
62
      std::atomic<uint32_t> conv = 0;
63
      std::atomic<uint32_t> relu = 0;
64
      std::atomic<uint32_t> fusedConv = 0;
65
      std::atomic<uint32_t> gemm = 0;
66
      std::atomic<uint32_t> sigmoid = 0;
67
      std::atomic<uint32_t> fusedGemm = 0;
68
      std::atomic<uint32_t> batchNorm = 0;
69
      std::atomic<uint32_t> maxPool = 0;
70
      std::atomic<uint32_t> concat = 0;
71
    };
72

73
    const CallCounts& GetCallCounts() { return m_callCounts; }
74

75
    CallbackOperatorProvider() {
76
      using namespace OperatorHelper;
77

78
      std::wostringstream dll;
79
      dll << BINARY_NAME;
80
      auto winml_dll_name = dll.str();
81

82
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
83
      auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
84
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
85
      auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
86
#endif
87
      using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
88
      auto create_registry =
89
        reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
90
      WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
91

92
#pragma push_macro("REGISTER_KERNEL")
93
#define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \
94
  NullOperatorFactory::RegisterKernel(                                      \
95
    #_name,                                                                 \
96
    (_domain),                                                              \
97
    _opSet::sc_sinceVer_##_name,                                            \
98
    m_registry,                                                             \
99
    winrt::make<NullShapeInferrer<_shapeInferrer>>(),                       \
100
    (_callCount)                                                            \
101
  );
102

103
      REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv);
104
      REGISTER_KERNEL(
105
        Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu
106
      );
107
      REGISTER_KERNEL(DmlFusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv);
108

109
      REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm);
110
      REGISTER_KERNEL(
111
        Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid
112
      );
113
      REGISTER_KERNEL(DmlFusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm);
114

115
      REGISTER_KERNEL(
116
        BatchNormalization,
117
        onnxruntime::kOnnxDomain,
118
        OnnxOperatorSet7,
119
        GetOutputShapeAsInputShapeHelper,
120
        &m_callCounts.batchNorm
121
      );
122
      REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool);
123
      REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat);
124

125
#pragma pop_macro("REGISTER_KERNEL")
126
    }
127

128
    STDMETHOD(GetRegistry)
129
    (IMLOperatorRegistry** ppOperatorRegistry) {
130
      if (ppOperatorRegistry == nullptr) {
131
        return E_POINTER;
132
      }
133

134
      m_registry.copy_to(ppOperatorRegistry);
135
      return S_OK;
136
    }
137

138
   private:
139
    winrt::com_ptr<IMLOperatorRegistry> m_registry;
140
    CallCounts m_callCounts;
141
  };
142

143
  auto customOperatorProvider = winrt::make<CallbackOperatorProvider>();
144
  auto provider = customOperatorProvider.as<ILearningModelOperatorProvider>();
145

146
  LearningModelDevice device = nullptr;
147
  WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX));
148
  std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename;
149
  auto model = LearningModel::LoadFromFilePath(fullPath, provider);
150

151
  auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png");
152

153
  LearningModelSession session = nullptr;
154
  WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
155
  LearningModelBinding modelBinding(session);
156

157
  modelBinding.Bind(L"data", featureValue);
158
  auto result = session.Evaluate(modelBinding, L"");
159

160
  const auto& callCounts = customOperatorProvider.as<CallbackOperatorProvider>()->GetCallCounts();
161

162
  // Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused)
163
  WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv);
164
  WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu);
165
  WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv);
166
  WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm);
167
  WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid);
168
  WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm);
169
  WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm);
170
  WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool);
171
  WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat);
172
}
173

174
// Operator provider backed by the registry exported from the WinML binary itself
// (via its "MLCreateOperatorRegistry" export). Used by tests that register custom
// schemas/kernels directly against that registry.
struct LocalCustomOperatorProvider : winrt::implements<
                                       LocalCustomOperatorProvider,
                                       winml::ILearningModelOperatorProvider,
                                       ILearningModelOperatorProviderNative> {
  LocalCustomOperatorProvider() {
    std::wostringstream dll;
    dll << BINARY_NAME;
    auto winml_dll_name = dll.str();

#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
    HMODULE library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
    HMODULE library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
#endif
    // LoadLibrary*/LoadPackagedLibrary return nullptr on failure; fail the test
    // explicitly rather than handing a null handle to GetProcAddress.
    WINML_EXPECT_TRUE(library != nullptr);

    using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
    auto create_registry =
      reinterpret_cast<create_registry_delegate*>(GetProcAddress(library, "MLCreateOperatorRegistry"));
    // Guard against a missing export before calling through the function pointer.
    WINML_EXPECT_TRUE(create_registry != nullptr);
    WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
  }

  // ABI accessor: hands out an AddRef'd copy of the registry pointer.
  STDMETHOD(GetRegistry)
  (IMLOperatorRegistry** ppOperatorRegistry) {
    if (ppOperatorRegistry == nullptr) {
      return E_POINTER;
    }

    m_registry.copy_to(ppOperatorRegistry);
    return S_OK;
  }

  // Convenience accessor for in-process use; no AddRef, caller must not release.
  IMLOperatorRegistry* GetRegistry() { return m_registry.get(); }

 protected:
  winrt::com_ptr<IMLOperatorRegistry> m_registry;
};
209

210
// Checks test attributes set on ABI kernels can be queried with correct values
211
void VerifyTestAttributes(const MLOperatorAttributes& attrs) {
212
  std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString");
213
  WINML_EXPECT_EQUAL(strAttr, "1");
214

215
  std::vector<std::string> strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray");
216
  std::vector<std::string> expected = std::vector<std::string>({"1", "2"});
217
  for (size_t i = 0; i < expected.size(); ++i) {
218
    WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]);
219
  }
220

221
  WINML_EXPECT_EQUAL(1, attrs.GetAttribute<int64_t>("DefaultedNonRequiredInt"));
222
  WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute<float>("DefaultedNonRequiredFloat"));
223

224
  WINML_EXPECT_EQUAL(std::vector<int64_t>({1, 2}), attrs.GetAttributeVector<int64_t>("DefaultedNonRequiredIntArray"));
225
  WINML_EXPECT_EQUAL(
226
    std::vector<float>({1.0f, 2.0f}), attrs.GetAttributeVector<float>("DefaultedNonRequiredFloatArray")
227
  );
228
}
229

230
// Foo kernel which is doing Add and optionally truncates its output.
//   T                - element type of the input/output tensors
//   VerifyAttributes - when true, validate attribute defaults at creation time
//   Truncate         - when true, shrink the output's leading dim by one (exercises shape inference)
template <typename T, bool VerifyAttributes = false, bool Truncate = false>
class FooKernel {
 public:
  FooKernel(const MLOperatorKernelCreationContext& info) {
    if (VerifyAttributes) {
      VerifyTestAttributes(info);
    }

    VerifyShapeInfo(info);
  }

  // Checks availability of the tensor shape description at kernel creation:
  // the non-truncating registration uses AllowDynamicInputShapes, so no shape
  // description should exist; the truncating registration relies on shape
  // inference, so one must exist.
  void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) {
    if (!Truncate) {
      winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
      WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false);
      WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()));
    } else {
      winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
      WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true);
      WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK);
    }
  }

  // Element-wise Add of the two inputs into the output. Also asserts the ABI
  // behavior of GetOutputTensor with/without an inferred shape before producing Y.
  void Compute(const MLOperatorKernelContext& context) const {
    const auto X = context.GetInputTensor(0);
    const auto W = context.GetInputTensor(1);

    auto xData = X.GetData<T>();
    auto wData = W.GetData<T>();

    auto shape = X.GetShape();

    // This is used to test shape inference: drop one element from the leading dim.
    if (Truncate) {
      shape[0] -= 1;
    }

    // Without an inferred shape the no-argument GetOutputTensor must fail; with
    // one it must succeed. The shape-providing overload below works either way.
    if (!Truncate) {
      winrt::com_ptr<IMLOperatorTensor> tensor;
      WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put()));
    } else {
      MLOperatorTensor tensor = context.GetOutputTensor(0);
    }

    auto Y = context.GetOutputTensor(0, shape);
    auto yData = Y.GetData<T>();

    // Total element count over the (possibly truncated) output shape.
    size_t size = 1;
    for (size_t i = 0; i < shape.size(); i++) {
      size *= shape[i];
    }

    for (size_t i = 0; i < size; i++) {
      yData[i] = xData[i] + wData[i];
    }
  }
};
288

289
template <bool VerifyTestAttributes = false>
290
void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
291
  HRESULT hr = MLOperatorKernel<FooKernel<float, VerifyTestAttributes>>::CreateInstance(*kernelInfo, opKernel);
292
  THROW_IF_FAILED(hr);
293
}
294

295
// Kernel-factory callback producing the attribute-verifying, output-truncating
// variant of the Foo kernel (used by the shape-inference test cases).
void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
  using AbiKernel = MLOperatorKernel<FooKernel<float, true, true>>;
  const HRESULT createResult = AbiKernel::CreateInstance(*kernelInfo, opKernel);
  THROW_IF_FAILED(createResult);
}
299

300
// Test using a foo kernel which is doing Add, but register it as "Mul".
// Because the kernel overrides the built-in schema's "Mul", observing Add's result
// proves the custom kernel (not the built-in one) was actually executed.
static void CustomKernelWithBuiltInSchema() {
  // Create the registry
  auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
  IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();

  // Register the kernel: CPU execution, float tensors only under constraint "T"
  MLOperatorEdgeDescription floatTensorType = {
    MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)
  };

  MLOperatorEdgeTypeConstrant constraint = {"T", &floatTensorType, 1};

  // Registered against built-in "Mul" (opset 7), with dynamic input shapes allowed
  // since no shape inference is supplied.
  MLOperatorKernelDescription kernelDesc = {
    "",
    "Mul",
    7,
    MLOperatorExecutionType::Cpu,
    &constraint,
    1,
    nullptr,
    0,
    MLOperatorKernelOptions::AllowDynamicInputShapes
  };

  Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
    wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<false>);
  WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));

  // Prepare inputs
  std::vector<int64_t> dimsX = {3, 2};
  std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // Prepare expected inputs and outputs
  std::vector<int64_t> expectedDimsY = {3, 2};

  // The expected value should be Add's result.
  std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // Create the model and sessions
  std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx";
  LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);

  LearningModelSession session(model);
  LearningModelBinding bindings(session);

  // Bind inputs and outputs. The model multiplies (well, "Mul"s) X by itself,
  // so X is bound as the single input.
  TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
  bindings.Bind(winrt::hstring(L"X"), inputTensor);

  auto outputValue = TensorFloat::Create();
  WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));

  // Evaluate the model
  winrt::hstring correlationId;
  WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));

  // Check the result shape
  WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
  for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
    WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
  }

  // Check the results: values must match Add, proving the custom kernel ran
  auto buffer = outputValue.GetAsVectorView();
  WINML_EXPECT_TRUE(buffer != nullptr);
  WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));

  // Release the model before operatorProvider goes out of scope
  model = nullptr;
}
371

372
// Similar to MLOperatorShapeInferrer, but using an std::function
373
class MLOperatorShapeInferrerFromFunc
374
  : public Microsoft::WRL::
375
      RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMLOperatorShapeInferrer> {
376
 public:
377
  MLOperatorShapeInferrerFromFunc(std::function<void(IMLOperatorShapeInferenceContext*)> shapeInferenceFn)
378
    : m_func(shapeInferenceFn) {}
379

380
  HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try {
381
    m_func(context);
382
    return S_OK;
383
  }
384
  CATCH_RETURN();
385

386
 private:
387
  std::function<void(IMLOperatorShapeInferenceContext*)> m_func;
388
};
389

390
// Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods
static void CustomKernelWithCustomSchema() {
  // Test cases
  struct {
    // Whether the Foo kernel should truncate its output
    bool truncateOutput;

    // Whether a type label is used in the schema, versus a type description
    bool useTypeLabel;

    // Whether the schema provides a type inference function, and uses an output type
    // of Int32 instead of Float32
    bool useTypeInference;

    // Whether a shape inference method is provided in the schema
    bool useShapeInferenceInSchema;

    // Whether a shape inference method is provided in the kernel
    bool useShapeInferenceInKernel;

    // Whether attribute defaults are provided in the schema, instead of the kernel
    bool attributeDefaultsInSchema;
  } testCases[] = {
    {false,  true, false, false, false, false},
    {false, false, false, false, false, false},
    {false,  true,  true, false, false,  true},
    { true, false, false, false,  true, false},
    { true,  true,  true,  true,  true,  true},
  };

  for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) {
    // Create the registry. A fresh provider per case keeps registrations isolated.
    auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
    IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();

    // Create input and output parameters
    MLOperatorSchemaEdgeDescription inputParam = {};
    inputParam.options = MLOperatorParameterOptions::Single;

    if (!testCases[caseIndex].useTypeLabel) {
      // Type inference only makes sense with labeled types; the case table honors this.
      assert(!testCases[caseIndex].useTypeInference);

      MLOperatorEdgeDescription edgeDesc = {};
      edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
      edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

      inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription;
      inputParam.edgeDescription = edgeDesc;
    } else {
      inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label;
      inputParam.typeLabel = "T1";
    }

    MLOperatorSchemaEdgeDescription outputParam = inputParam;

    // Type inference should set this to tensor(float) even though T2 is not matched
    // on an input label
    if (testCases[caseIndex].useTypeInference) {
      if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) {
        outputParam.typeLabel = "T2";
      } else {
        outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32;
      }
    }

    MLOperatorSchemaEdgeDescription inputs[] = {inputParam, inputParam};

    MLOperatorEdgeDescription edgeTypes[6] = {
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt32)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt64)},
      {MLOperatorEdgeType::Tensor,  static_cast<uint64_t>(MLOperatorTensorDataType::Int32)},
      {MLOperatorEdgeType::Tensor,  static_cast<uint64_t>(MLOperatorTensorDataType::Int64)},
      {MLOperatorEdgeType::Tensor,  static_cast<uint64_t>(MLOperatorTensorDataType::Float)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Double)}
    };

    // Type constraints.  Only the first is used unless type inference is provided and
    // the kernel emits a different output type as "T2"
    MLOperatorEdgeTypeConstrant constraints[] = {
      {"T1", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))},
      {"T2", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))}
    };

    // Test attributes
    MLOperatorAttribute attributes[] = {
      {           "DefaultedNonRequiredInt",         MLOperatorAttributeType::Int, false},
      {         "DefaultedNonRequiredFloat",       MLOperatorAttributeType::Float, false},
      {        "DefaultedNonRequiredString",      MLOperatorAttributeType::String, false},
      {      "DefaultedNonRequiredIntArray",    MLOperatorAttributeType::IntArray, false},
      {    "DefaultedNonRequiredFloatArray",  MLOperatorAttributeType::FloatArray, false},
      {   "DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},

      {"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
    };

    // Defaults.  These are queried back during kernel creation, type and shape inference
    // and tested against the same values
    MLOperatorAttributeNameValue defaultAttributes[] = {
      {        "DefaultedNonRequiredInt",         MLOperatorAttributeType::Int, 1},
      {      "DefaultedNonRequiredFloat",       MLOperatorAttributeType::Float, 1},
      {     "DefaultedNonRequiredString",      MLOperatorAttributeType::String, 1},
      {   "DefaultedNonRequiredIntArray",    MLOperatorAttributeType::IntArray, 2},
      { "DefaultedNonRequiredFloatArray",  MLOperatorAttributeType::FloatArray, 2},
      {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2},
    };

    // The value arrays must outlive the registration calls below; they live on the stack
    // of this loop iteration, matching defaultAttributes' lifetime.
    int64_t defaultInts[] = {1, 2};
    float defaultFloats[] = {1.0f, 2.0f};
    const char* defaultStrings[] = {"1", "2"};
    defaultAttributes[0].ints = defaultInts;
    defaultAttributes[1].floats = defaultFloats;
    defaultAttributes[2].strings = defaultStrings;
    defaultAttributes[3].ints = defaultInts;
    defaultAttributes[4].floats = defaultFloats;
    defaultAttributes[5].strings = defaultStrings;

    // Schema definition
    MLOperatorSchemaDescription schemaDesc = {};
    schemaDesc.name = "Foo";
    schemaDesc.operatorSetVersionAtLastChange = 7;
    schemaDesc.inputs = inputs;
    schemaDesc.inputCount = 2;
    schemaDesc.outputs = &outputParam;
    schemaDesc.outputCount = 1;
    schemaDesc.typeConstraints = constraints;
    schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0;
    schemaDesc.attributes = attributes;
    schemaDesc.attributeCount = static_cast<uint32_t>(std::size(attributes));

    if (testCases[caseIndex].attributeDefaultsInSchema) {
      schemaDesc.defaultAttributes = defaultAttributes;
      schemaDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
    }

    Microsoft::WRL::ComPtr<MLOperatorTypeInferrer> typeInferrer;
    Microsoft::WRL::ComPtr<MLOperatorShapeInferrerFromFunc> shapeInferrer;

    // Type inference function: forces the output edge to tensor(float) regardless
    // of the "T2" label, and re-validates attribute defaults from this context.
    if (testCases[caseIndex].useTypeInference) {
      typeInferrer = wil::MakeOrThrow<MLOperatorTypeInferrer>([](IMLOperatorTypeInferenceContext* ctx) -> void {
        VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx));

        MLOperatorEdgeDescription edgeDesc = {};
        edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
        edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

        MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc);
      });
    }

    // Store the shape inference context with a reference following the call to InferOutputShapes.
    // This will be called after loading the model as an isolated test for how ABI context objects
    // are "closed."
    Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContext> shapeInferenceContext;

    // Shape inference is tested by truncating the output size
    bool truncateOutput = testCases[caseIndex].truncateOutput;
    if (truncateOutput) {
      shapeInferrer = wil::MakeOrThrow<MLOperatorShapeInferrerFromFunc>(
        [&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void {
          VerifyTestAttributes(MLShapeInferenceContext(ctx));
          MLShapeInferenceContext(ctx).SetOutputTensorShape(0, {2, 2});
          shapeInferenceContext = ctx;
        }
      );
    }

    // Register the schema
    MLOperatorSetId opsetId = {"", 7};
    MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc;
    WINML_EXPECT_EQUAL(
      S_OK,
      registry->RegisterOperatorSetSchema(
        &opsetId,
        1,
        &opSchemaDescs,
        1,
        typeInferrer.Get(),
        testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
      )
    );

    {
      // Register a future version of the schema in the same domain, while setting its
      // input count to zero to ensure it is not being used.
      auto futureSchemaDesc = schemaDesc;
      futureSchemaDesc.inputCount = 0;

      MLOperatorSetId id = {"", 9};
      MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc;
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorSetSchema(
          &id,
          7,
          &schemaDescs,
          1,
          typeInferrer.Get(),
          testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
        )
      );
    }
    {
      // Register in another (unused) domain to the custom registry
      auto otherSchemaDesc = schemaDesc;
      otherSchemaDesc.inputCount = 0;

      MLOperatorSetId id = {"otherDomain", 7};
      MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc;
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorSetSchema(
          &id,
          1,
          &schemaDescs,
          1,
          typeInferrer.Get(),
          testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
        )
      );
    }
    // Register the Foo kernel
    MLOperatorEdgeDescription floatTensorEdgeDesc = {};
    floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor;
    floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

    MLOperatorEdgeTypeConstrant kernelConstraint = {"T1", &floatTensorEdgeDesc, 1};

    MLOperatorKernelDescription kernelDesc = {
      "", "Foo", 7, MLOperatorExecutionType::Cpu, &kernelConstraint, testCases[caseIndex].useTypeLabel ? 1u : 0u
    };

    if (!testCases[caseIndex].attributeDefaultsInSchema) {
      kernelDesc.defaultAttributes = defaultAttributes;
      kernelDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
    }

    if (!truncateOutput) {
      // No shape inference in this configuration, so the kernel must accept dynamic shapes.
      kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes;
      Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
        wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<true>);

      WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
    } else {
      Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
        wil::MakeOrThrow<MLOperatorKernelFactory>(CreateTruncatedABIFooKernel);
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorKernel(
          &kernelDesc, factory.Get(), testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr
        )
      );
    }

    // Prepare inputs
    std::vector<int64_t> dimsX = {3, 2};
    std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

    // Prepare expected inputs and outputs
    std::vector<int64_t> expectedDimsY = {truncateOutput ? 2 : 3, 2};
    // now the expected value should be Add's result.
    std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
    if (truncateOutput) {
      // The leading dimension is truncated, and the second dimension has two elements over that dim
      expectedValuesY.resize(expectedValuesY.size() - 2);
    }

    // Load the model and sessions
    std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx");
    LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
    LearningModelSession session(model);

    // Bind input and outputs
    LearningModelBinding bindings(session);

    TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
    bindings.Bind(winrt::hstring(L"X"), inputTensor);

    auto outputValue = TensorFloat::Create();
    WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));

    // Evaluate the model
    winrt::hstring correlationId;
    WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));

    // Verify the result shape
    WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
    for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
      WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
    }

    // Verify the result values
    auto buffer = outputValue.GetAsVectorView();
    WINML_EXPECT_TRUE(buffer != nullptr);
    WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));

    // Release the model before operatorProvider goes out of scope
    model = nullptr;

    if (shapeInferenceContext) {
      // Check that the shape inference context is closed and safely fails
      MLOperatorEdgeDescription edgeDesc;
      WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc));
    }
  }
}
696

697
// Returns the test-API table for this suite, patching in skips for
// environments that cannot run the GPU-dependent fusion scenario.
const CustomOpsTestsApi& getapi() {
  static CustomOpsTestsApi api = {
    CustomOpsScenarioTestsClassSetup, CustomOperatorFusion, CustomKernelWithBuiltInSchema, CustomKernelWithCustomSchema
  };

  // CustomOperatorFusion needs a DirectX device and image/video-frame support;
  // either disqualifying condition replaces it with SkipTest.
  const bool skipFusion = SkipGpuTests() || RuntimeParameterExists(L"noVideoFrameTests");
  if (skipFusion) {
    api.CustomOperatorFusion = SkipTest;
  }
  return api;
}
710

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.