onnxruntime
182 строки · 5.9 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "testPch.h"
5#include "RawApiTestsGpu.h"
6#include "RawApiHelpers.h"
7
8#include <d3d11.h>
9#include <windows.graphics.directx.direct3d11.interop.h>
10#include <dxgi.h>
11#include <dxgi1_6.h>
12#include <d3d11on12.h>
13#include <d3d11_3.h>
14
15namespace ml = Microsoft::AI::MachineLearning;
16
// Selects which kind of learning_model_device CreateDevice() constructs.
enum class DeviceType : int {
  CPU,                     // default-constructed learning_model_device
  DirectX,                 // directx_device of kind directx
  D3D11Device,             // directx_device wrapping a WinRT IDirect3DDevice built from a new D3D11 device
  D3D12CommandQueue,       // directx_device wrapping a new D3D12 direct command queue
  DirectXHighPerformance,  // directx_device of kind directx_high_power
  DirectXMinPower,         // directx_device of kind directx_min_power
  Last                     // end marker; not handled explicitly by CreateDevice (falls to the default CPU device)
};
26
27ml::learning_model_device CreateDevice(DeviceType deviceType) {
28switch (deviceType) {
29case DeviceType::CPU:
30return ml::learning_model_device();
31case DeviceType::DirectX:
32return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx);
33case DeviceType::DirectXHighPerformance:
34return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx_high_power);
35case DeviceType::DirectXMinPower:
36return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx_min_power);
37case DeviceType::D3D11Device: {
38Microsoft::WRL::ComPtr<ID3D11Device> d3d11Device;
39Microsoft::WRL::ComPtr<ID3D11DeviceContext> d3d11DeviceContext;
40D3D_FEATURE_LEVEL d3dFeatureLevel;
41auto result = D3D11CreateDevice(
42nullptr,
43D3D_DRIVER_TYPE::D3D_DRIVER_TYPE_HARDWARE,
44nullptr,
450,
46nullptr,
470,
48D3D11_SDK_VERSION,
49d3d11Device.GetAddressOf(),
50&d3dFeatureLevel,
51d3d11DeviceContext.GetAddressOf()
52);
53if (FAILED(result)) {
54printf("Failed to create d3d11 device");
55exit(3);
56}
57
58Microsoft::WRL::ComPtr<IDXGIDevice> dxgiDevice;
59d3d11Device.Get()->QueryInterface<IDXGIDevice>(dxgiDevice.GetAddressOf());
60
61Microsoft::WRL::ComPtr<IInspectable> inspectable;
62CreateDirect3D11DeviceFromDXGIDevice(dxgiDevice.Get(), inspectable.GetAddressOf());
63
64Microsoft::WRL::ComPtr<ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice> direct3dDevice;
65inspectable.As(&direct3dDevice);
66
67return ml::gpu::directx_device(direct3dDevice.Get());
68}
69case DeviceType::D3D12CommandQueue: {
70Microsoft::WRL::ComPtr<ID3D12Device> d3d12Device;
71auto result = D3D12CreateDevice(
72nullptr,
73D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_12_0,
74__uuidof(ID3D12Device),
75reinterpret_cast<void**>(d3d12Device.GetAddressOf())
76);
77if (FAILED(result)) {
78printf("Failed to create d3d12 device");
79exit(3);
80}
81Microsoft::WRL::ComPtr<ID3D12CommandQueue> queue;
82D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
83commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
84d3d12Device->CreateCommandQueue(
85&commandQueueDesc, __uuidof(ID3D12CommandQueue), reinterpret_cast<void**>(queue.GetAddressOf())
86);
87
88return ml::gpu::directx_device(queue.Get());
89}
90default:
91return ml::learning_model_device();
92}
93}
94
95static void RawApiTestsGpuApiTestsClassSetup() {
96WINML_EXPECT_HRESULT_SUCCEEDED(RoInitialize(RO_INIT_TYPE::RO_INIT_MULTITHREADED));
97}
98
99static void CreateDirectXDevice() {
100WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectX));
101}
102
103static void CreateD3D11DeviceDevice() {
104WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::D3D11Device));
105}
106
107static void CreateD3D12CommandQueueDevice() {
108WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::D3D12CommandQueue));
109}
110
111static void CreateDirectXHighPerformanceDevice() {
112WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectXHighPerformance));
113}
114
115static void CreateDirectXMinPowerDevice() {
116WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectXMinPower));
117}
118
119static void Evaluate() {
120std::wstring model_path = L"model.onnx";
121std::unique_ptr<ml::learning_model> model = nullptr;
122WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
123
124std::unique_ptr<ml::learning_model_device> device = nullptr;
125WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
126
127RunOnDevice(*model.get(), *device.get(), InputStrategy::CopyInputs);
128
129WINML_EXPECT_NO_THROW(model.reset());
130}
131
132static void EvaluateNoInputCopy() {
133std::wstring model_path = L"model.onnx";
134std::unique_ptr<ml::learning_model> model = nullptr;
135WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
136
137std::unique_ptr<ml::learning_model_device> device = nullptr;
138WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
139
140RunOnDevice(*model.get(), *device.get(), InputStrategy::BindAsReference);
141
142WINML_EXPECT_NO_THROW(model.reset());
143}
144
145static void EvaluateManyBuffers() {
146std::wstring model_path = L"model.onnx";
147std::unique_ptr<ml::learning_model> model = nullptr;
148WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
149
150std::unique_ptr<ml::learning_model_device> device = nullptr;
151WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
152
153RunOnDevice(*model.get(), *device.get(), InputStrategy::BindWithMultipleReferences);
154
155WINML_EXPECT_NO_THROW(model.reset());
156}
157
158const RawApiTestsGpuApi& getapi() {
159static RawApiTestsGpuApi api = {
160RawApiTestsGpuApiTestsClassSetup,
161CreateDirectXDevice,
162CreateD3D11DeviceDevice,
163CreateD3D12CommandQueueDevice,
164CreateDirectXHighPerformanceDevice,
165CreateDirectXMinPowerDevice,
166Evaluate,
167EvaluateNoInputCopy,
168EvaluateManyBuffers
169};
170
171if (SkipGpuTests()) {
172api.CreateDirectXDevice = SkipTest;
173api.CreateD3D11DeviceDevice = SkipTest;
174api.CreateD3D12CommandQueueDevice = SkipTest;
175api.CreateDirectXHighPerformanceDevice = SkipTest;
176api.CreateDirectXMinPowerDevice = SkipTest;
177api.Evaluate = SkipTest;
178api.EvaluateNoInputCopy = SkipTest;
179api.EvaluateManyBuffers = SkipTest;
180}
181return api;
182}
183