onnxruntime

Форк
0
/
RawApiTestsGpu.cpp 
182 строки · 5.9 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "testPch.h"
5
#include "RawApiTestsGpu.h"
6
#include "RawApiHelpers.h"
7

8
#include <d3d11.h>
9
#include <windows.graphics.directx.direct3d11.interop.h>
10
#include <dxgi.h>
11
#include <dxgi1_6.h>
12
#include <d3d11on12.h>
13
#include <d3d11_3.h>
14

15
namespace ml = Microsoft::AI::MachineLearning;
16

17
// Device configurations that CreateDevice() knows how to build for the GPU
// raw-API tests. Enumerator order (and therefore the numeric values) is kept
// stable.
enum class DeviceType : int {
  CPU,                     // default-constructed learning_model_device
  DirectX,                 // directx_device(directx_device_kind::directx)
  D3D11Device,             // directx_device wrapping a hardware ID3D11Device
  D3D12CommandQueue,       // directx_device wrapping an ID3D12CommandQueue
  DirectXHighPerformance,  // directx_device(directx_device_kind::directx_high_power)
  DirectXMinPower,         // directx_device(directx_device_kind::directx_min_power)
  Last                     // end marker; CreateDevice() treats it like CPU (default branch)
};

ml::learning_model_device CreateDevice(DeviceType deviceType) {
28
  switch (deviceType) {
29
    case DeviceType::CPU:
30
      return ml::learning_model_device();
31
    case DeviceType::DirectX:
32
      return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx);
33
    case DeviceType::DirectXHighPerformance:
34
      return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx_high_power);
35
    case DeviceType::DirectXMinPower:
36
      return ml::gpu::directx_device(ml::gpu::directx_device_kind::directx_min_power);
37
    case DeviceType::D3D11Device: {
38
      Microsoft::WRL::ComPtr<ID3D11Device> d3d11Device;
39
      Microsoft::WRL::ComPtr<ID3D11DeviceContext> d3d11DeviceContext;
40
      D3D_FEATURE_LEVEL d3dFeatureLevel;
41
      auto result = D3D11CreateDevice(
42
        nullptr,
43
        D3D_DRIVER_TYPE::D3D_DRIVER_TYPE_HARDWARE,
44
        nullptr,
45
        0,
46
        nullptr,
47
        0,
48
        D3D11_SDK_VERSION,
49
        d3d11Device.GetAddressOf(),
50
        &d3dFeatureLevel,
51
        d3d11DeviceContext.GetAddressOf()
52
      );
53
      if (FAILED(result)) {
54
        printf("Failed to create d3d11 device");
55
        exit(3);
56
      }
57

58
      Microsoft::WRL::ComPtr<IDXGIDevice> dxgiDevice;
59
      d3d11Device.Get()->QueryInterface<IDXGIDevice>(dxgiDevice.GetAddressOf());
60

61
      Microsoft::WRL::ComPtr<IInspectable> inspectable;
62
      CreateDirect3D11DeviceFromDXGIDevice(dxgiDevice.Get(), inspectable.GetAddressOf());
63

64
      Microsoft::WRL::ComPtr<ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice> direct3dDevice;
65
      inspectable.As(&direct3dDevice);
66

67
      return ml::gpu::directx_device(direct3dDevice.Get());
68
    }
69
    case DeviceType::D3D12CommandQueue: {
70
      Microsoft::WRL::ComPtr<ID3D12Device> d3d12Device;
71
      auto result = D3D12CreateDevice(
72
        nullptr,
73
        D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_12_0,
74
        __uuidof(ID3D12Device),
75
        reinterpret_cast<void**>(d3d12Device.GetAddressOf())
76
      );
77
      if (FAILED(result)) {
78
        printf("Failed to create d3d12 device");
79
        exit(3);
80
      }
81
      Microsoft::WRL::ComPtr<ID3D12CommandQueue> queue;
82
      D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
83
      commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
84
      d3d12Device->CreateCommandQueue(
85
        &commandQueueDesc, __uuidof(ID3D12CommandQueue), reinterpret_cast<void**>(queue.GetAddressOf())
86
      );
87

88
      return ml::gpu::directx_device(queue.Get());
89
    }
90
    default:
91
      return ml::learning_model_device();
92
  }
93
}
94

95
static void RawApiTestsGpuApiTestsClassSetup() {
96
  WINML_EXPECT_HRESULT_SUCCEEDED(RoInitialize(RO_INIT_TYPE::RO_INIT_MULTITHREADED));
97
}
98

99
static void CreateDirectXDevice() {
100
  WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectX));
101
}
102

103
static void CreateD3D11DeviceDevice() {
104
  WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::D3D11Device));
105
}
106

107
static void CreateD3D12CommandQueueDevice() {
108
  WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::D3D12CommandQueue));
109
}
110

111
static void CreateDirectXHighPerformanceDevice() {
112
  WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectXHighPerformance));
113
}
114

115
static void CreateDirectXMinPowerDevice() {
116
  WINML_EXPECT_NO_THROW(CreateDevice(DeviceType::DirectXMinPower));
117
}
118

119
static void Evaluate() {
120
  std::wstring model_path = L"model.onnx";
121
  std::unique_ptr<ml::learning_model> model = nullptr;
122
  WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
123

124
  std::unique_ptr<ml::learning_model_device> device = nullptr;
125
  WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
126

127
  RunOnDevice(*model.get(), *device.get(), InputStrategy::CopyInputs);
128

129
  WINML_EXPECT_NO_THROW(model.reset());
130
}
131

132
static void EvaluateNoInputCopy() {
133
  std::wstring model_path = L"model.onnx";
134
  std::unique_ptr<ml::learning_model> model = nullptr;
135
  WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
136

137
  std::unique_ptr<ml::learning_model_device> device = nullptr;
138
  WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
139

140
  RunOnDevice(*model.get(), *device.get(), InputStrategy::BindAsReference);
141

142
  WINML_EXPECT_NO_THROW(model.reset());
143
}
144

145
static void EvaluateManyBuffers() {
146
  std::wstring model_path = L"model.onnx";
147
  std::unique_ptr<ml::learning_model> model = nullptr;
148
  WINML_EXPECT_NO_THROW(model = std::make_unique<ml::learning_model>(model_path.c_str(), model_path.size()));
149

150
  std::unique_ptr<ml::learning_model_device> device = nullptr;
151
  WINML_EXPECT_NO_THROW(device = std::make_unique<ml::learning_model_device>(CreateDevice(DeviceType::DirectX)));
152

153
  RunOnDevice(*model.get(), *device.get(), InputStrategy::BindWithMultipleReferences);
154

155
  WINML_EXPECT_NO_THROW(model.reset());
156
}
157

158
// Returns the GPU raw-API test table. The table is a function-local static
// built on the first call. Every call consults SkipGpuTests(); when it returns
// true, every entry except the class-setup hook is pointed at SkipTest.
const RawApiTestsGpuApi& getapi() {
  static RawApiTestsGpuApi api = {
    RawApiTestsGpuApiTestsClassSetup,
    CreateDirectXDevice,
    CreateD3D11DeviceDevice,
    CreateD3D12CommandQueueDevice,
    CreateDirectXHighPerformanceDevice,
    CreateDirectXMinPowerDevice,
    Evaluate,
    EvaluateNoInputCopy,
    EvaluateManyBuffers
  };

  if (!SkipGpuTests()) {
    return api;
  }

  // The class-setup hook is intentionally left in place; all test entries are stubbed.
  api.CreateDirectXDevice = SkipTest;
  api.CreateD3D11DeviceDevice = SkipTest;
  api.CreateD3D12CommandQueueDevice = SkipTest;
  api.CreateDirectXHighPerformanceDevice = SkipTest;
  api.CreateDirectXMinPowerDevice = SkipTest;
  api.Evaluate = SkipTest;
  api.EvaluateNoInputCopy = SkipTest;
  api.EvaluateManyBuffers = SkipTest;
  return api;
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.