1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "testPch.h"
5#include <wil/result.h>
6#include <D3d11_4.h>
7#include <dxgi1_6.h>
8#include "filehelpers.h"
9#include <fstream>
10#include <MemoryBuffer.h>
11#include "CustomOperatorProvider.h"
12#include "CustomOps.h"
13
14// For custom operator and shape inferencing support
15#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
16#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h"
17#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
18#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
19#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
20#include "core/graph/constants.h"
21#include "CustomNullOp.h"
22#include <wil/wrl.h>
23
24using namespace winml;
25using namespace wfc;
26using namespace wm;
27using namespace wgi;
28using namespace ws;
29using namespace wss;
30
// One-time suite setup: initializes a WinRT apartment on the test thread so the WinRT
// projections used throughout these tests can be activated.
static void CustomOpsScenarioTestsClassSetup() {
  winrt::init_apartment();
#ifdef BUILD_INBOX
  // NOTE(review): for inbox builds, activation is routed through the OS factory —
  // presumably so types resolve against the system-provided WinML; confirm.
  winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
}
37
38// Tests that the execution provider correctly fuses operators together when custom ops are involved.
39static void CustomOperatorFusion() {
40constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx";
41
42// This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them
43// all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same
44// goes for the single Gemm+Sigmoid in the model too.
45constexpr const uint32_t c_expectedConvOps = 0;
46constexpr const uint32_t c_expectedReluOps = 0;
47constexpr const uint32_t c_expectedFusedConvOps = 25;
48constexpr const uint32_t c_expectedGemmOps = 0;
49constexpr const uint32_t c_expectedSigmoidOps = 0;
50constexpr const uint32_t c_expectedFusedGemmOps = 1;
51
52// These ops are also part of the model but shouldn't be fused
53constexpr const uint32_t c_expectedBatchNormOps = 1;
54constexpr const uint32_t c_expectedMaxPoolOps = 3;
55constexpr const uint32_t c_expectedConcatOps = 8;
56
57struct CallbackOperatorProvider : winrt::implements<
58CallbackOperatorProvider,
59winml::ILearningModelOperatorProvider,
60ILearningModelOperatorProviderNative> {
61struct CallCounts {
62std::atomic<uint32_t> conv = 0;
63std::atomic<uint32_t> relu = 0;
64std::atomic<uint32_t> fusedConv = 0;
65std::atomic<uint32_t> gemm = 0;
66std::atomic<uint32_t> sigmoid = 0;
67std::atomic<uint32_t> fusedGemm = 0;
68std::atomic<uint32_t> batchNorm = 0;
69std::atomic<uint32_t> maxPool = 0;
70std::atomic<uint32_t> concat = 0;
71};
72
73const CallCounts& GetCallCounts() { return m_callCounts; }
74
75CallbackOperatorProvider() {
76using namespace OperatorHelper;
77
78std::wostringstream dll;
79dll << BINARY_NAME;
80auto winml_dll_name = dll.str();
81
82#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
83auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
84#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
85auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
86#endif
87using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
88auto create_registry =
89reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
90WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
91
92#pragma push_macro("REGISTER_KERNEL")
93#define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \
94NullOperatorFactory::RegisterKernel( \
95#_name, \
96(_domain), \
97_opSet::sc_sinceVer_##_name, \
98m_registry, \
99winrt::make<NullShapeInferrer<_shapeInferrer>>(), \
100(_callCount) \
101);
102
103REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv);
104REGISTER_KERNEL(
105Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu
106);
107REGISTER_KERNEL(DmlFusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv);
108
109REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm);
110REGISTER_KERNEL(
111Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid
112);
113REGISTER_KERNEL(DmlFusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm);
114
115REGISTER_KERNEL(
116BatchNormalization,
117onnxruntime::kOnnxDomain,
118OnnxOperatorSet7,
119GetOutputShapeAsInputShapeHelper,
120&m_callCounts.batchNorm
121);
122REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool);
123REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat);
124
125#pragma pop_macro("REGISTER_KERNEL")
126}
127
128STDMETHOD(GetRegistry)
129(IMLOperatorRegistry** ppOperatorRegistry) {
130if (ppOperatorRegistry == nullptr) {
131return E_POINTER;
132}
133
134m_registry.copy_to(ppOperatorRegistry);
135return S_OK;
136}
137
138private:
139winrt::com_ptr<IMLOperatorRegistry> m_registry;
140CallCounts m_callCounts;
141};
142
143auto customOperatorProvider = winrt::make<CallbackOperatorProvider>();
144auto provider = customOperatorProvider.as<ILearningModelOperatorProvider>();
145
146LearningModelDevice device = nullptr;
147WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX));
148std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename;
149auto model = LearningModel::LoadFromFilePath(fullPath, provider);
150
151auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png");
152
153LearningModelSession session = nullptr;
154WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
155LearningModelBinding modelBinding(session);
156
157modelBinding.Bind(L"data", featureValue);
158auto result = session.Evaluate(modelBinding, L"");
159
160const auto& callCounts = customOperatorProvider.as<CallbackOperatorProvider>()->GetCallCounts();
161
162// Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused)
163WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv);
164WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu);
165WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv);
166WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm);
167WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid);
168WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm);
169WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm);
170WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool);
171WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat);
172}
173
174struct LocalCustomOperatorProvider : winrt::implements<
175LocalCustomOperatorProvider,
176winml::ILearningModelOperatorProvider,
177ILearningModelOperatorProviderNative> {
178LocalCustomOperatorProvider() {
179std::wostringstream dll;
180dll << BINARY_NAME;
181auto winml_dll_name = dll.str();
182
183#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
184auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
185#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
186auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
187#endif
188using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
189auto create_registry =
190reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
191WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
192}
193
194STDMETHOD(GetRegistry)
195(IMLOperatorRegistry** ppOperatorRegistry) {
196if (ppOperatorRegistry == nullptr) {
197return E_POINTER;
198}
199
200m_registry.copy_to(ppOperatorRegistry);
201return S_OK;
202}
203
204IMLOperatorRegistry* GetRegistry() { return m_registry.get(); }
205
206protected:
207winrt::com_ptr<IMLOperatorRegistry> m_registry;
208};
209
210// Checks test attributes set on ABI kernels can be queried with correct values
211void VerifyTestAttributes(const MLOperatorAttributes& attrs) {
212std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString");
213WINML_EXPECT_EQUAL(strAttr, "1");
214
215std::vector<std::string> strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray");
216std::vector<std::string> expected = std::vector<std::string>({"1", "2"});
217for (size_t i = 0; i < expected.size(); ++i) {
218WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]);
219}
220
221WINML_EXPECT_EQUAL(1, attrs.GetAttribute<int64_t>("DefaultedNonRequiredInt"));
222WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute<float>("DefaultedNonRequiredFloat"));
223
224WINML_EXPECT_EQUAL(std::vector<int64_t>({1, 2}), attrs.GetAttributeVector<int64_t>("DefaultedNonRequiredIntArray"));
225WINML_EXPECT_EQUAL(
226std::vector<float>({1.0f, 2.0f}), attrs.GetAttributeVector<float>("DefaultedNonRequiredFloatArray")
227);
228}
229
// Foo kernel which is doing Add and optionally truncates its output
// Template parameters:
//   T                - element type of the input/output tensors.
//   VerifyAttributes - when true, attribute defaults are validated at creation.
//   Truncate         - when true, the kernel shrinks the leading output dimension by one,
//                      exercising the shape-inference path.
template <typename T, bool VerifyAttributes = false, bool Truncate = false>
class FooKernel {
 public:
  FooKernel(const MLOperatorKernelCreationContext& info) {
    if (VerifyAttributes) {
      VerifyTestAttributes(info);
    }

    VerifyShapeInfo(info);
  }

  // Asserts that a tensor shape description is available at creation only for the
  // Truncate variant, and correctly reported as unavailable otherwise.
  void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) {
    if (!Truncate) {
      winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
      WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false);
      WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()));
    } else {
      winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
      WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true);
      WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK);
    }
  }

  // Element-wise adds the two input tensors into the output (despite whatever op name
  // the kernel is registered under — tests use this to prove the custom kernel ran).
  void Compute(const MLOperatorKernelContext& context) const {
    const auto X = context.GetInputTensor(0);
    const auto W = context.GetInputTensor(1);

    auto xData = X.GetData<T>();
    auto wData = W.GetData<T>();

    auto shape = X.GetShape();

    // This is used to test shape inference
    if (Truncate) {
      shape[0] -= 1;
    }

    // For the non-truncating variant, fetching the output without supplying a shape is
    // expected to fail; for the truncating variant the pre-inferred output is retrievable.
    if (!Truncate) {
      winrt::com_ptr<IMLOperatorTensor> tensor;
      WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put()));
    } else {
      MLOperatorTensor tensor = context.GetOutputTensor(0);
    }

    auto Y = context.GetOutputTensor(0, shape);
    auto yData = Y.GetData<T>();

    // Total element count of the (possibly truncated) output shape.
    size_t size = 1;
    for (size_t i = 0; i < shape.size(); i++) {
      size *= shape[i];
    }

    for (size_t i = 0; i < size; i++) {
      yData[i] = xData[i] + wData[i];
    }
  }
};
288
289template <bool VerifyTestAttributes = false>
290void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
291HRESULT hr = MLOperatorKernel<FooKernel<float, VerifyTestAttributes>>::CreateInstance(*kernelInfo, opKernel);
292THROW_IF_FAILED(hr);
293}
294
295void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
296HRESULT hr = MLOperatorKernel<FooKernel<float, true, true>>::CreateInstance(*kernelInfo, opKernel);
297THROW_IF_FAILED(hr);
298}
299
// Test using a foo kernel which is doing Add, but register it as "Mul".
static void CustomKernelWithBuiltInSchema() {
  // Create the registry
  auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
  IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();

  // Register the kernel
  MLOperatorEdgeDescription floatTensorType = {
    MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)
  };

  // Single constraint "T" = tensor(float).
  MLOperatorEdgeTypeConstrant constraint = {"T", &floatTensorType, 1};

  // Kernel registered against the built-in "Mul" schema (opset 7, CPU). Dynamic input
  // shapes are allowed, so the kernel receives no shape description at creation.
  MLOperatorKernelDescription kernelDesc = {
    "",
    "Mul",
    7,
    MLOperatorExecutionType::Cpu,
    &constraint,
    1,
    nullptr,
    0,
    MLOperatorKernelOptions::AllowDynamicInputShapes
  };

  Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
    wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<false>);
  WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));

  // Prepare inputs
  std::vector<int64_t> dimsX = {3, 2};
  std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // Prepare expected inputs and outputs
  std::vector<int64_t> expectedDimsY = {3, 2};

  // The expected value should be Add's result.
  // (The Foo kernel adds even though it was registered as "Mul" — seeing x+x here proves
  // the custom kernel, not the built-in one, executed.)
  std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // Create the model and sessions
  std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx";
  LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);

  LearningModelSession session(model);
  LearningModelBinding bindings(session);

  // Bind inputs and outputs
  TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
  bindings.Bind(winrt::hstring(L"X"), inputTensor);

  auto outputValue = TensorFloat::Create();
  WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));

  // Evaluate the model
  winrt::hstring correlationId;
  WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));

  // Check the result shape
  WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
  for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
    WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
  }

  // Check the results
  auto buffer = outputValue.GetAsVectorView();
  WINML_EXPECT_TRUE(buffer != nullptr);
  WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));

  // Release the model before operatorProvider goes out of scope
  model = nullptr;
}
371
372// Similar to MLOperatorShapeInferrer, but using an std::function
373class MLOperatorShapeInferrerFromFunc
374: public Microsoft::WRL::
375RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMLOperatorShapeInferrer> {
376public:
377MLOperatorShapeInferrerFromFunc(std::function<void(IMLOperatorShapeInferenceContext*)> shapeInferenceFn)
378: m_func(shapeInferenceFn) {}
379
380HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try {
381m_func(context);
382return S_OK;
383}
384CATCH_RETURN();
385
386private:
387std::function<void(IMLOperatorShapeInferenceContext*)> m_func;
388};
389
// Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods
static void CustomKernelWithCustomSchema() {
  // Test cases: each row toggles one aspect of how the "Foo" schema/kernel pair is
  // declared, so the matrix covers type labels vs. explicit edge descriptions, schema-
  // vs. kernel-provided inference, and schema- vs. kernel-provided attribute defaults.
  struct {
    // Whether the Foo kernel should truncate its output
    bool truncateOutput;

    // Whether a type label is used in the schema, versus a type description
    bool useTypeLabel;

    // Whether the schema provides a type inference function, and uses an output type
    // of Int32 instead of Float32
    bool useTypeInference;

    // Whether a shape inference method is provided in the schema
    bool useShapeInferenceInSchema;

    // Whether a shape inference method is provided in the kernel
    bool useShapeInferenceInKernel;

    // Whether attribute defaults are provided in the schema, instead of the kernel
    bool attributeDefaultsInSchema;
  } testCases[] = {
    {false, true, false, false, false, false},
    {false, false, false, false, false, false},
    {false, true, true, false, false, true},
    {true, false, false, false, true, false},
    {true, true, true, true, true, true},
  };

  for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) {
    // Create the registry
    auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
    IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();

    // Create input and output parameters
    MLOperatorSchemaEdgeDescription inputParam = {};
    inputParam.options = MLOperatorParameterOptions::Single;

    if (!testCases[caseIndex].useTypeLabel) {
      // Type inference requires type labels, so these options cannot be combined.
      assert(!testCases[caseIndex].useTypeInference);

      MLOperatorEdgeDescription edgeDesc = {};
      edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
      edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

      inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription;
      inputParam.edgeDescription = edgeDesc;
    } else {
      inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label;
      inputParam.typeLabel = "T1";
    }

    MLOperatorSchemaEdgeDescription outputParam = inputParam;

    // Type inference should set this to tensor(float) even though T2 is not matched
    // on an input label
    if (testCases[caseIndex].useTypeInference) {
      if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) {
        outputParam.typeLabel = "T2";
      } else {
        outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32;
      }
    }

    MLOperatorSchemaEdgeDescription inputs[] = {inputParam, inputParam};

    MLOperatorEdgeDescription edgeTypes[6] = {
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt32)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt64)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int32)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int64)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)},
      {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Double)}
    };

    // Type constraints. Only the first is used unless type inference is provided and
    // the kernel emits a different output type as "T2"
    MLOperatorEdgeTypeConstrant constraints[] = {
      {"T1", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))},
      {"T2", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))}
    };

    // Test attributes
    MLOperatorAttribute attributes[] = {
      {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, false},
      {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, false},
      {"DefaultedNonRequiredString", MLOperatorAttributeType::String, false},
      {"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, false},
      {"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, false},
      {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},

      {"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
    };

    // Defaults. These are queried back during kernel creation, type and shape inference
    // and tested against the same values
    MLOperatorAttributeNameValue defaultAttributes[] = {
      {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, 1},
      {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, 1},
      {"DefaultedNonRequiredString", MLOperatorAttributeType::String, 1},
      {"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, 2},
      {"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, 2},
      {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2},
    };

    // Payload arrays for the defaults above; the counts in defaultAttributes (1 or 2)
    // describe how many elements of each array are used.
    int64_t defaultInts[] = {1, 2};
    float defaultFloats[] = {1.0f, 2.0f};
    const char* defaultStrings[] = {"1", "2"};
    defaultAttributes[0].ints = defaultInts;
    defaultAttributes[1].floats = defaultFloats;
    defaultAttributes[2].strings = defaultStrings;
    defaultAttributes[3].ints = defaultInts;
    defaultAttributes[4].floats = defaultFloats;
    defaultAttributes[5].strings = defaultStrings;

    // Schema definition
    MLOperatorSchemaDescription schemaDesc = {};
    schemaDesc.name = "Foo";
    schemaDesc.operatorSetVersionAtLastChange = 7;
    schemaDesc.inputs = inputs;
    schemaDesc.inputCount = 2;
    schemaDesc.outputs = &outputParam;
    schemaDesc.outputCount = 1;
    schemaDesc.typeConstraints = constraints;
    schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0;
    schemaDesc.attributes = attributes;
    schemaDesc.attributeCount = static_cast<uint32_t>(std::size(attributes));

    if (testCases[caseIndex].attributeDefaultsInSchema) {
      schemaDesc.defaultAttributes = defaultAttributes;
      schemaDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
    }

    Microsoft::WRL::ComPtr<MLOperatorTypeInferrer> typeInferrer;
    Microsoft::WRL::ComPtr<MLOperatorShapeInferrerFromFunc> shapeInferrer;

    // Type inference function
    if (testCases[caseIndex].useTypeInference) {
      typeInferrer = wil::MakeOrThrow<MLOperatorTypeInferrer>([](IMLOperatorTypeInferenceContext* ctx) -> void {
        VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx));

        MLOperatorEdgeDescription edgeDesc = {};
        edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
        edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

        MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc);
      });
    }

    // Store the shape inference context with a reference following the call to InferOutputShapes.
    // This will be called after loading the model as an isolated test for how ABI context objects
    // are "closed."
    Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContext> shapeInferenceContext;

    // Shape inference is tested by truncating the output size
    bool truncateOutput = testCases[caseIndex].truncateOutput;
    if (truncateOutput) {
      shapeInferrer = wil::MakeOrThrow<MLOperatorShapeInferrerFromFunc>(
        [&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void {
          VerifyTestAttributes(MLShapeInferenceContext(ctx));
          MLShapeInferenceContext(ctx).SetOutputTensorShape(0, {2, 2});
          shapeInferenceContext = ctx;
        }
      );
    }

    // Register the schema
    MLOperatorSetId opsetId = {"", 7};
    MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc;
    WINML_EXPECT_EQUAL(
      S_OK,
      registry->RegisterOperatorSetSchema(
        &opsetId,
        1,
        &opSchemaDescs,
        1,
        typeInferrer.Get(),
        testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
      )
    );

    {
      // Register a future version of the schema in the same domain, while setting its
      // input count to zero to ensure it is not being used.
      auto futureSchemaDesc = schemaDesc;
      futureSchemaDesc.inputCount = 0;

      MLOperatorSetId id = {"", 9};
      MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc;
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorSetSchema(
          &id,
          7,
          &schemaDescs,
          1,
          typeInferrer.Get(),
          testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
        )
      );
    }
    {
      // Register in another (unused) domain to the custom registry
      auto otherSchemaDesc = schemaDesc;
      otherSchemaDesc.inputCount = 0;

      MLOperatorSetId id = {"otherDomain", 7};
      MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc;
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorSetSchema(
          &id,
          1,
          &schemaDescs,
          1,
          typeInferrer.Get(),
          testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
        )
      );
    }
    // Register the Foo kernel
    MLOperatorEdgeDescription floatTensorEdgeDesc = {};
    floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor;
    floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float;

    MLOperatorEdgeTypeConstrant kernelConstraint = {"T1", &floatTensorEdgeDesc, 1};

    MLOperatorKernelDescription kernelDesc = {
      "", "Foo", 7, MLOperatorExecutionType::Cpu, &kernelConstraint, testCases[caseIndex].useTypeLabel ? 1u : 0u
    };

    if (!testCases[caseIndex].attributeDefaultsInSchema) {
      kernelDesc.defaultAttributes = defaultAttributes;
      kernelDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
    }

    if (!truncateOutput) {
      kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes;
      Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
        wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<true>);

      WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
    } else {
      Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory =
        wil::MakeOrThrow<MLOperatorKernelFactory>(CreateTruncatedABIFooKernel);
      WINML_EXPECT_EQUAL(
        S_OK,
        registry->RegisterOperatorKernel(
          &kernelDesc, factory.Get(), testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr
        )
      );
    }

    // Prepare inputs
    std::vector<int64_t> dimsX = {3, 2};
    std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

    // Prepare expected inputs and outputs
    std::vector<int64_t> expectedDimsY = {truncateOutput ? 2 : 3, 2};
    // now the expected value should be Add's result.
    std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
    if (truncateOutput) {
      // The leading dimension is truncated, and the second dimension has two elements over that dim
      expectedValuesY.resize(expectedValuesY.size() - 2);
    }

    // Load the model and sessions
    std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx");
    LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
    LearningModelSession session(model);

    // Bind input and outputs
    LearningModelBinding bindings(session);

    TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
    bindings.Bind(winrt::hstring(L"X"), inputTensor);

    auto outputValue = TensorFloat::Create();
    WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));

    // Evaluate the model
    winrt::hstring correlationId;
    WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));

    // Verify the result shape
    WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
    for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
      WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
    }

    // Verify the result values
    auto buffer = outputValue.GetAsVectorView();
    WINML_EXPECT_TRUE(buffer != nullptr);
    WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));

    // Release the model before operatorProvider goes out of scope
    model = nullptr;

    if (shapeInferenceContext) {
      // Check that the shape inference context is closed and safely fails
      MLOperatorEdgeDescription edgeDesc;
      WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc));
    }
  }
}
696
697const CustomOpsTestsApi& getapi() {
698static CustomOpsTestsApi api = {
699CustomOpsScenarioTestsClassSetup, CustomOperatorFusion, CustomKernelWithBuiltInSchema, CustomKernelWithCustomSchema
700};
701
702if (SkipGpuTests()) {
703api.CustomOperatorFusion = SkipTest;
704}
705if (RuntimeParameterExists(L"noVideoFrameTests")) {
706api.CustomOperatorFusion = SkipTest;
707}
708return api;
709}
710