onnxruntime
726 строк · 28.6 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "lib/Api.Image/pch.h"
5#include "inc/D3DDeviceCache.h"
6#include <directxmath.h>
7#include <d3d11on12.h>
8#include "inc/DeviceHelpers.h"
9#include "CommonDeviceHelpers.h"
10
11namespace float32 {
12#include "shaders\SurfaceToTensor-SurfaceToTensorBGR8.h"
13#include "shaders\SurfaceToTensor-SurfaceToTensorRGB8.h"
14#include "shaders\SurfaceToTensor-SurfaceToTensorGRAY8.h"
15#include "shaders\SurfaceToTensor-SurfaceGRAY8ToTensorBGR8.h"
16#include "shaders\SurfaceToTensor-SurfaceGRAY8ToTensorGRAY8.h"
17#include "shaders\TensorToSurface-TensorBGR8ToSurface.h"
18#include "shaders\TensorToSurface-TensorRGB8ToSurface.h"
19#include "shaders\TensorToSurface-TensorGRAY8ToSurface.h"
20#include "shaders\TensorToSurface-TensorBGR8ToSurfaceGRAY8.h"
21#include "shaders\TensorToSurface-TensorRGB8ToSurfaceGRAY8.h"
22#include "shaders\TensorToSurface-TensorGRAY8ToSurfaceGRAY8.h"
23} // namespace float32
24
25namespace float16 {
26#include "shaders\SurfaceToTensor16-SurfaceToTensorBGR8.h"
27#include "shaders\SurfaceToTensor16-SurfaceToTensorRGB8.h"
28#include "shaders\SurfaceToTensor16-SurfaceToTensorGRAY8.h"
29#include "shaders\SurfaceToTensor16-SurfaceGRAY8ToTensorBGR8.h"
30#include "shaders\SurfaceToTensor16-SurfaceGRAY8ToTensorGRAY8.h"
31#include "shaders\TensorToSurface16-TensorBGR8ToSurface.h"
32#include "shaders\TensorToSurface16-TensorRGB8ToSurface.h"
33#include "shaders\TensorToSurface16-TensorGRAY8ToSurface.h"
34#include "shaders\TensorToSurface16-TensorBGR8ToSurfaceGRAY8.h"
35#include "shaders\TensorToSurface16-TensorRGB8ToSurfaceGRAY8.h"
36#include "shaders\TensorToSurface16-TensorGRAY8ToSurfaceGRAY8.h"
37} // namespace float16
38
39using namespace Microsoft::WRL;
40
41using namespace _winml;
42
43D3DDeviceCache::D3DDeviceCache(winml::LearningModelDeviceKind const& deviceKind) {
44WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));
45
46if (deviceKind == winml::LearningModelDeviceKind::Cpu || deviceKind == winml::LearningModelDeviceKind::Default) {
47// CPU device don't make any GPU devices
48device_luid_.HighPart = device_luid_.LowPart = 0;
49return;
50}
51
52DXGI_GPU_PREFERENCE preference;
53WINML_THROW_IF_FAILED(GetGPUPreference(deviceKind, &preference));
54
55CommonDeviceHelpers::AdapterEnumerationSupport support;
56WINML_THROW_IF_FAILED(CommonDeviceHelpers::GetAdapterEnumerationSupport(&support));
57
58const char noHardwareAdaptersAvailableErrStr[] = "No hardware adapters available";
59const char failedToObtainHardwareAdaptersErrStr[] = "Failed to obtain hardware adapters.";
60HRESULT hardwareAdapterSuccessfullyObtained = S_OK;
61if (support.has_dxgi) {
62winrt::com_ptr<IDXGIAdapter1> spAdapter;
63hardwareAdapterSuccessfullyObtained = GetDXGIHardwareAdapterWithPreference(preference, spAdapter.put());
64if (hardwareAdapterSuccessfullyObtained == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) {
65WINML_THROW_HR_MSG_NO_TELEMETRY_SENT(hardwareAdapterSuccessfullyObtained, noHardwareAdaptersAvailableErrStr);
66} else {
67WINML_THROW_IF_FAILED_MSG(hardwareAdapterSuccessfullyObtained, failedToObtainHardwareAdaptersErrStr);
68}
69WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));
70}
71#ifdef ENABLE_DXCORE
72if (support.has_dxgi == false) {
73winrt::com_ptr<IDXCoreAdapter> spAdapter;
74hardwareAdapterSuccessfullyObtained = GetDXCoreHardwareAdapterWithPreference(preference, spAdapter.put());
75if (hardwareAdapterSuccessfullyObtained == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) {
76WINML_THROW_HR_MSG_NO_TELEMETRY_SENT(hardwareAdapterSuccessfullyObtained, noHardwareAdaptersAvailableErrStr);
77} else {
78WINML_THROW_IF_FAILED_MSG(hardwareAdapterSuccessfullyObtained, failedToObtainHardwareAdaptersErrStr);
79}
80WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));
81}
82#endif
83InitializeCommandQueue(device_.get());
84
85device_luid_ = device_->GetAdapterLuid();
86}
87
88D3DDeviceCache::D3DDeviceCache(wgdx::Direct3D11::IDirect3DDevice const& device) {
89WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));
90
91// Use the 11 device to initialize 12
92winrt_device_ = device;
93
94// they told us which device to run on, crack the interop wrapper to get the dxgi device
95winrt::com_ptr<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess> dxgi;
96dxgi = device.as<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess>();
97
98winrt::com_ptr<IDXGIDevice> dxgiDevice;
99WINML_THROW_IF_FAILED(dxgi->GetInterface(IID_PPV_ARGS(dxgiDevice.put())));
100
101device_11_ = dxgiDevice.as<ID3D11Device>();
102
103winrt::com_ptr<ID3D11DeviceContext> spContext;
104device_11_->GetImmediateContext(spContext.put());
105spContext.as(device_context11_);
106
107winrt::com_ptr<IDXGIDevice> pDXGIDevice;
108WINML_THROW_IF_FAILED(dxgi->GetInterface(IID_PPV_ARGS(pDXGIDevice.put())));
109
110winrt::com_ptr<IDXGIAdapter> adapter;
111WINML_THROW_IF_FAILED(pDXGIDevice->GetAdapter(adapter.put()));
112
113WINML_THROW_IF_FAILED(D3D12CreateDevice(adapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));
114
115InitializeCommandQueue(device_.get());
116
117device_luid_ = device_->GetAdapterLuid();
118}
119
120D3DDeviceCache::D3DDeviceCache(ID3D12CommandQueue* queue) {
121WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));
122
123// Use the command queue to initialize all of the needed D3D11 interop
124command_queue_.copy_from(queue);
125command_queue_->QueryInterface(IID_PPV_ARGS(sharing_contract_.put()));
126
127WINML_THROW_IF_FAILED(queue->GetDevice(IID_PPV_ARGS(device_.put())));
128
129device_luid_ = device_->GetAdapterLuid();
130}
131
132D3DDeviceCache::~D3DDeviceCache() {
133// If this is a CPU instance device_ will not have been created.
134// Ensure the device is still valid before doing work.
135if (device_ != nullptr && (device_->GetDeviceRemovedReason() == S_OK)) {
136// dx11 stack is optional, and we lazy load it when available
137if (device_context11_ != nullptr) {
138// Sync 11 to 12 then Sync 12 to the CPU. This ensures that all inflight work is done before we delete the d3d objects.
139GPUSyncD3D11ToD3D12();
140}
141SyncD3D12ToCPU();
142}
143}
144
145bool D3DDeviceCache::IsFloat16Supported() {
146if (device_ != nullptr) {
147return CommonDeviceHelpers::IsFloat16Supported(device_.get());
148}
149
150return true;
151}
152
153ID3D11Device* D3DDeviceCache::GetD3D11Device() {
154EnsureD3D11FromD3D12();
155return device_11_.get();
156}
157
158const GUID& D3DDeviceCache::GetFenceGuid() const {
159return fence_guid_;
160}
161
162ID3D11DeviceContext4* D3DDeviceCache::GetD3D11DeviceContext() {
163EnsureD3D11FromD3D12();
164return device_context11_.get();
165}
166
167wgdx::Direct3D11::IDirect3DDevice D3DDeviceCache::GetWinrtDevice() {
168EnsureD3D11FromD3D12();
169return winrt_device_;
170}
171
172void D3DDeviceCache::InitializeCommandQueue(ID3D12Device1* device) {
173D3D12_COMMAND_QUEUE_DESC queueDesc = {};
174queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
175queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
176WINML_THROW_IF_FAILED(
177device->CreateCommandQueue(&queueDesc, winrt::guid_of<ID3D12CommandQueue>(), command_queue_.put_void())
178);
179
180// If possible get the sharing context. If not leave nullptr;
181command_queue_->QueryInterface(IID_PPV_ARGS(sharing_contract_.put()));
182}
183
184// this initializes the following variables, making them from the dx12 device
185// device_11_
186// device_context11_
187// winrt_device_
188void D3DDeviceCache::EnsureD3D11FromD3D12() {
189// do we even have a device? (CPU will use the cache but not have a device) .
190if (device_ == nullptr)
191return;
192
193// are we already initialized
194if (winrt_device_ != nullptr)
195return;
196
197CWinMLAutoLock lock(&lock_);
198
199// check with the lock held, are we already initialized
200if (winrt_device_ != nullptr)
201return;
202
203winrt::com_ptr<::IInspectable> spInspectable;
204winrt::com_ptr<IDXGIDevice> spDXGIDevice;
205
206// call our SEH version (for delay loading)
207WINML_THROW_IF_FAILED(CreateD3D11On12Device(device_.get(), device_11_.put()));
208winrt::com_ptr<ID3D11DeviceContext> spContext;
209device_11_->GetImmediateContext(spContext.put());
210spContext.as(device_context11_);
211
212WINML_THROW_IF_FAILED(device_11_->QueryInterface(IID_PPV_ARGS(spDXGIDevice.put())));
213// Convert to Winrt wrapper. This doesn't actually make a new device.
214WINML_THROW_IF_FAILED(CreateDirect3D11DeviceFromDXGIDevice(spDXGIDevice.get(), spInspectable.put()));
215WINML_THROW_IF_FAILED(spInspectable->QueryInterface(
216winrt::guid_of<wgdx::Direct3D11::IDirect3DDevice>(), reinterpret_cast<void**>(winrt::put_abi(winrt_device_))
217));
218}
219
220void D3DDeviceCache::EnsureD3D12Fence() {
221// are we already initialized?
222if (d3d12_fence_ != nullptr)
223return;
224
225CWinMLAutoLock lock(&lock_);
226
227// with the lock held, are we already initialized?
228if (d3d12_fence_ != nullptr)
229return;
230
231WINML_THROW_IF_FAILED(device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(d3d12_fence_.put())));
232}
233
234// this initializes the following variables, so that we can share dx12 with dx11
235// d3d11_fence_
236// d3d12_fence_
237void D3DDeviceCache::EnsureSharedFences() {
238// are we already initialized?
239if (d3d11_fence_ != nullptr)
240return;
241
242CWinMLAutoLock lock(&lock_);
243
244// with the lock held, are we already initialized?
245if (d3d11_fence_ != nullptr)
246return;
247
248EnsureD3D12Fence();
249
250// ensure the d11 stack is alive, the 11 stack doesn't exist on WCOSHeadless yet, so be resilient
251EnsureD3D11FromD3D12();
252
253winrt::com_ptr<ID3D12DeviceChild> spD3D12DeviceChild;
254d3d12_fence_.as(spD3D12DeviceChild);
255HANDLE hSharedFence;
256WINML_THROW_IF_FAILED(device_->CreateSharedHandle(spD3D12DeviceChild.get(), NULL, GENERIC_ALL, nullptr, &hSharedFence)
257);
258
259winrt::com_ptr<ID3D11Device5> spD3D11Device5;
260device_11_.as(spD3D11Device5);
261wil::unique_handle safe(hSharedFence);
262WINML_THROW_IF_FAILED(spD3D11Device5->OpenSharedFence(safe.get(), IID_PPV_ARGS(d3d11_fence_.put())));
263}
264
265void D3DDeviceCache::GPUSyncD3D11ToD3D12() {
266EnsureSharedFences();
267
268UINT64 currentFence = fence_value_++;
269WINML_THROW_IF_FAILED(device_context11_->Signal(d3d11_fence_.get(), currentFence));
270
271WINML_THROW_IF_FAILED(command_queue_->Wait(d3d12_fence_.get(), currentFence));
272
273if (sharing_contract_ != nullptr) {
274sharing_contract_->SharedFenceSignal(d3d12_fence_.get(), currentFence);
275}
276}
277
278void D3DDeviceCache::GPUSyncD3D12ToD3D11() {
279EnsureSharedFences();
280
281UINT64 currentFence = fence_value_++;
282WINML_THROW_IF_FAILED(command_queue_->Signal(d3d12_fence_.get(), currentFence));
283
284WINML_THROW_IF_FAILED(device_context11_->Wait(d3d11_fence_.get(), currentFence));
285}
286
287void D3DDeviceCache::SyncD3D12ToCPU() {
288UINT64 currentFence = QueueFenceToD3D12();
289WaitForFenceValue(currentFence);
290}
291
292UINT64 D3DDeviceCache::QueueFenceToD3D12() {
293EnsureD3D12Fence();
294
295UINT64 currentFence = fence_value_++;
296WINML_THROW_IF_FAILED(command_queue_->Signal(d3d12_fence_.get(), currentFence));
297
298return currentFence;
299}
300
301void D3DDeviceCache::WaitForFenceValue(UINT64 fenceValue) {
302EnsureD3D12Fence();
303
304wil::unique_handle event(CreateEvent(nullptr, FALSE, FALSE, nullptr));
305THROW_LAST_ERROR_IF(!event);
306
307WINML_THROW_IF_FAILED(d3d12_fence_->SetEventOnCompletion(fenceValue, event.get()));
308
309DWORD retVal = WaitForSingleObject(event.get(), INFINITE);
310if (retVal != WAIT_OBJECT_0) {
311WINML_THROW_IF_FAILED(E_UNEXPECTED);
312}
313}
314
315ID3D12RootSignature* D3DDeviceCache::GetTensorizeRootSignature() {
316if (tensorize_root_signature_ == nullptr) {
317winrt::com_ptr<ID3D12RootSignature> newRootSignature;
318D3D12_FEATURE_DATA_ROOT_SIGNATURE featureData = {};
319
320// This is the highest version the sample supports. If CheckFeatureSupport succeeds, the HighestVersion returned will not be greater than this.
321featureData.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_1;
322
323if (FAILED(device_->CheckFeatureSupport(D3D12_FEATURE_ROOT_SIGNATURE, &featureData, sizeof(featureData)))) {
324featureData.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_0;
325}
326
327// Compute root signature.
328{
329CD3DX12_DESCRIPTOR_RANGE1 ranges[2] = {};
330ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE);
331ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE);
332
333CD3DX12_ROOT_PARAMETER1 rootParameters[3] = {};
334rootParameters[0].InitAsConstants(4, 0);
335rootParameters[1].InitAsDescriptorTable(1, &ranges[0], D3D12_SHADER_VISIBILITY_ALL);
336rootParameters[2].InitAsDescriptorTable(1, &ranges[1], D3D12_SHADER_VISIBILITY_ALL);
337
338CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc;
339computeRootSignatureDesc.Init_1_1(_countof(rootParameters), rootParameters, 0, nullptr);
340
341winrt::com_ptr<ID3DBlob> signature;
342winrt::com_ptr<ID3DBlob> error;
343WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(
344&computeRootSignatureDesc, featureData.HighestVersion, signature.put(), error.put()
345));
346WINML_THROW_IF_FAILED(device_->CreateRootSignature(
3470, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(newRootSignature.put())
348));
349newRootSignature->SetName(L"Tensorize Rootsignature");
350}
351
352if (InterlockedCompareExchangePointer(tensorize_root_signature_.put_void(), newRootSignature.get(), nullptr) ==
353nullptr) {
354// This thread won the race and just cached the PSO
355newRootSignature.detach();
356}
357}
358
359return tensorize_root_signature_.get();
360}
361
362ID3D12RootSignature* D3DDeviceCache::GetDetensorizeRootSignature() {
363if (detensorize_root_signature_ == nullptr) {
364winrt::com_ptr<ID3D12RootSignature> newRootSignature;
365D3D12_FEATURE_DATA_ROOT_SIGNATURE featureData = {};
366
367// This is the highest version the sample supports. If CheckFeatureSupport succeeds, the HighestVersion returned will not be greater than this.
368featureData.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_1;
369
370if (FAILED(device_->CheckFeatureSupport(D3D12_FEATURE_ROOT_SIGNATURE, &featureData, sizeof(featureData)))) {
371featureData.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_0;
372}
373
374// Compute root signature.
375{
376CD3DX12_DESCRIPTOR_RANGE1 ranges[2] = {};
377ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE);
378ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE);
379
380CD3DX12_ROOT_PARAMETER1 rootParameters[3] = {};
381rootParameters[0].InitAsConstants(4, 0);
382rootParameters[1].InitAsDescriptorTable(1, &ranges[0], D3D12_SHADER_VISIBILITY_ALL);
383rootParameters[2].InitAsDescriptorTable(1, &ranges[1], D3D12_SHADER_VISIBILITY_ALL);
384
385CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC rootSignatureDesc;
386rootSignatureDesc.Init_1_1(
387_countof(rootParameters),
388rootParameters,
3890,
390nullptr,
391D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT
392);
393
394winrt::com_ptr<ID3DBlob> signature;
395winrt::com_ptr<ID3DBlob> error;
396WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(
397&rootSignatureDesc, featureData.HighestVersion, signature.put(), error.put()
398));
399WINML_THROW_IF_FAILED(device_->CreateRootSignature(
4000, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(newRootSignature.put())
401));
402newRootSignature->SetName(L"Detensorize Rootsignature");
403}
404
405if (InterlockedCompareExchangePointer(detensorize_root_signature_.put_void(), newRootSignature.get(), nullptr) ==
406nullptr) {
407// This thread won the race and just cached the PSO
408newRootSignature.detach();
409}
410}
411
412return detensorize_root_signature_.get();
413}
414
415ID3D12PipelineState* D3DDeviceCache::GetCachedPipelineState(
416PipelineStateCacheType type,
417PipelineStateCacheFormat formatFrom,
418PipelineStateCacheFormat formatTo,
419PipelineStateCacheOperation operation
420) {
421if (cached_pipeline_state[static_cast<int>(type)][static_cast<int>(formatFrom)][static_cast<int>(formatTo)]
422[static_cast<int>(operation)] == nullptr) {
423winrt::com_ptr<ID3D12PipelineState> newPSO;
424if (operation == PipelineStateCacheOperation::kTensorize) {
425newPSO.attach(CreateTensorizePipelineState(type, formatFrom, formatTo));
426} else {
427newPSO.attach(CreateDetensorizePipelineState(type, formatFrom, formatTo));
428}
429
430if (InterlockedCompareExchangePointer(
431cached_pipeline_state[static_cast<int>(type)][static_cast<int>(formatFrom)][static_cast<int>(formatTo)]
432[static_cast<int>(operation)]
433.put_void(),
434newPSO.get(),
435nullptr
436) == nullptr) {
437// This thread won the race and just cached the PSO
438newPSO.detach();
439}
440}
441
442return cached_pipeline_state[static_cast<int>(type)][static_cast<int>(formatFrom)][static_cast<int>(formatTo)]
443[static_cast<int>(operation)]
444.get();
445}
446
447ID3D12PipelineState* D3DDeviceCache::CreateTensorizePipelineState(
448PipelineStateCacheType type, PipelineStateCacheFormat formatFrom, PipelineStateCacheFormat formatTo
449) {
450static_assert(
451static_cast<unsigned int>(PipelineStateCacheFormat::kCount) == 3,
452"PipelineStateCacheFormat changed, update D3DDeviceCache::CreateTensorizePipelineState()"
453);
454
455const BYTE* shaderBytecode = nullptr;
456uint64_t shaderBytecodeSize = 0;
457
458switch (formatFrom) {
459case PipelineStateCacheFormat::kBGR8:
460case PipelineStateCacheFormat::kRGB8:
461if (type == PipelineStateCacheType::kFloat32) {
462if (formatTo == PipelineStateCacheFormat::kBGR8) {
463shaderBytecode = float32::g_csSurfaceToTensorBGR8;
464shaderBytecodeSize = sizeof(float32::g_csSurfaceToTensorBGR8);
465} else if (formatTo == PipelineStateCacheFormat::kRGB8) {
466shaderBytecode = float32::g_csSurfaceToTensorRGB8;
467shaderBytecodeSize = sizeof(float32::g_csSurfaceToTensorRGB8);
468} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
469shaderBytecode = float32::g_csSurfaceToTensorGRAY8;
470shaderBytecodeSize = sizeof(float32::g_csSurfaceToTensorGRAY8);
471} else {
472assert(false);
473}
474} else if (type == PipelineStateCacheType::kFloat16) {
475if (formatTo == PipelineStateCacheFormat::kBGR8) {
476shaderBytecode = float16::g_csSurfaceToTensorBGR8;
477shaderBytecodeSize = sizeof(float16::g_csSurfaceToTensorBGR8);
478} else if (formatTo == PipelineStateCacheFormat::kRGB8) {
479shaderBytecode = float16::g_csSurfaceToTensorRGB8;
480shaderBytecodeSize = sizeof(float16::g_csSurfaceToTensorRGB8);
481} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
482shaderBytecode = float16::g_csSurfaceToTensorGRAY8;
483shaderBytecodeSize = sizeof(float16::g_csSurfaceToTensorGRAY8);
484} else {
485assert(false);
486}
487}
488break;
489case PipelineStateCacheFormat::kGRAY8:
490if (type == PipelineStateCacheType::kFloat32) {
491if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
492// GRAY -> RGB is the same shader as GRAY -> BGR
493shaderBytecode = float32::g_csSurfaceGRAY8ToTensorBGR8;
494shaderBytecodeSize = sizeof(float32::g_csSurfaceGRAY8ToTensorBGR8);
495} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
496shaderBytecode = float32::g_csSurfaceGRAY8ToTensorGRAY8;
497shaderBytecodeSize = sizeof(float32::g_csSurfaceGRAY8ToTensorGRAY8);
498} else {
499assert(false);
500}
501} else if (type == PipelineStateCacheType::kFloat16) {
502if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
503// GRAY -> RGB is the same shader as GRAY -> BGR
504shaderBytecode = float16::g_csSurfaceGRAY8ToTensorBGR8;
505shaderBytecodeSize = sizeof(float16::g_csSurfaceGRAY8ToTensorBGR8);
506} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
507shaderBytecode = float16::g_csSurfaceGRAY8ToTensorGRAY8;
508shaderBytecodeSize = sizeof(float16::g_csSurfaceGRAY8ToTensorGRAY8);
509} else {
510assert(false);
511}
512}
513break;
514default:
515assert(false);
516break;
517}
518
519D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {};
520computePsoDesc.pRootSignature = GetTensorizeRootSignature();
521computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, static_cast<size_t>(shaderBytecodeSize));
522
523winrt::com_ptr<ID3D12PipelineState> pipelineState;
524WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&computePsoDesc, IID_PPV_ARGS(pipelineState.put())));
525
526return pipelineState.detach();
527}
528
529ID3D12PipelineState* D3DDeviceCache::CreateDetensorizePipelineState(
530PipelineStateCacheType type, PipelineStateCacheFormat formatFrom, PipelineStateCacheFormat formatTo
531) {
532static_assert(
533static_cast<unsigned int>(PipelineStateCacheFormat::kCount) == 3,
534"PipelineStateCacheFormat changed, update D3DDeviceCache::CreateDetensorizePipelineState()"
535);
536
537const BYTE* shaderBytecode = nullptr;
538uint64_t shaderBytecodeSize = 0;
539
540switch (formatFrom) {
541case PipelineStateCacheFormat::kBGR8:
542if (type == PipelineStateCacheType::kFloat32) {
543if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
544shaderBytecode = float32::g_csTensorBGR8ToSurface;
545shaderBytecodeSize = sizeof(float32::g_csTensorBGR8ToSurface);
546} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
547shaderBytecode = float32::g_csTensorBGR8ToSurfaceGRAY8;
548shaderBytecodeSize = sizeof(float32::g_csTensorBGR8ToSurfaceGRAY8);
549} else {
550assert(false);
551}
552} else if (type == PipelineStateCacheType::kFloat16) {
553if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
554shaderBytecode = float16::g_csTensorBGR8ToSurface;
555shaderBytecodeSize = sizeof(float16::g_csTensorBGR8ToSurface);
556} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
557shaderBytecode = float16::g_csTensorBGR8ToSurfaceGRAY8;
558shaderBytecodeSize = sizeof(float16::g_csTensorBGR8ToSurfaceGRAY8);
559} else {
560assert(false);
561}
562}
563break;
564case PipelineStateCacheFormat::kRGB8:
565if (type == PipelineStateCacheType::kFloat32) {
566if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
567shaderBytecode = float32::g_csTensorRGB8ToSurface;
568shaderBytecodeSize = sizeof(float32::g_csTensorRGB8ToSurface);
569} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
570shaderBytecode = float32::g_csTensorRGB8ToSurfaceGRAY8;
571shaderBytecodeSize = sizeof(float32::g_csTensorRGB8ToSurfaceGRAY8);
572} else {
573assert(false);
574}
575} else if (type == PipelineStateCacheType::kFloat16) {
576if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
577shaderBytecode = float16::g_csTensorRGB8ToSurface;
578shaderBytecodeSize = sizeof(float16::g_csTensorRGB8ToSurface);
579} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
580shaderBytecode = float16::g_csTensorRGB8ToSurfaceGRAY8;
581shaderBytecodeSize = sizeof(float16::g_csTensorRGB8ToSurfaceGRAY8);
582} else {
583assert(false);
584}
585}
586break;
587case PipelineStateCacheFormat::kGRAY8:
588if (type == PipelineStateCacheType::kFloat32) {
589if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
590// GRAY -> RGB is the same shader as GRAY -> BGR
591shaderBytecode = float32::g_csTensorGRAY8ToSurface;
592shaderBytecodeSize = sizeof(float32::g_csTensorGRAY8ToSurface);
593} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
594shaderBytecode = float32::g_csTensorGRAY8ToSurfaceGRAY8;
595shaderBytecodeSize = sizeof(float32::g_csTensorGRAY8ToSurfaceGRAY8);
596} else {
597assert(false);
598}
599} else if (type == PipelineStateCacheType::kFloat16) {
600if (formatTo == PipelineStateCacheFormat::kBGR8 || formatTo == PipelineStateCacheFormat::kRGB8) {
601// GRAY -> RGB is the same shader as GRAY -> BGR
602shaderBytecode = float16::g_csTensorGRAY8ToSurface;
603shaderBytecodeSize = sizeof(float16::g_csTensorGRAY8ToSurface);
604} else if (formatTo == PipelineStateCacheFormat::kGRAY8) {
605shaderBytecode = float16::g_csTensorGRAY8ToSurfaceGRAY8;
606shaderBytecodeSize = sizeof(float16::g_csTensorGRAY8ToSurfaceGRAY8);
607} else {
608assert(false);
609}
610}
611break;
612default:
613assert(false);
614break;
615}
616
617D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {};
618computePsoDesc.pRootSignature = GetDetensorizeRootSignature();
619computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, static_cast<size_t>(shaderBytecodeSize));
620
621winrt::com_ptr<ID3D12PipelineState> pipelineState;
622WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&computePsoDesc, IID_PPV_ARGS(pipelineState.put())));
623
624return pipelineState.detach();
625}
626
627ID3D12Resource* D3DDeviceCache::GetDetensorizeVertexBuffer(_Out_ UINT* vertexBufferSize) {
628if (detensorize_vertex_buffer_ == nullptr) {
629winrt::com_ptr<ID3D12Resource> newResource;
630// Create the vertex buffer.
631// 2 triangles for full screen
632DirectX::XMFLOAT3 triangleVertices[] = {
633{-1.0f, 1.0f, 0.0f},
634{ 1.0f, 1.0f, 0.0f},
635{-1.0f, -1.0f, 0.0f},
636{ 1.0f, -1.0f, 0.0f},
637};
638
639assert(sc_vertexBufferSize == sizeof(triangleVertices));
640
641CD3DX12_HEAP_PROPERTIES heapProp(D3D12_HEAP_TYPE_UPLOAD);
642D3D12_RESOURCE_DESC resourceDiscription = CD3DX12_RESOURCE_DESC::Buffer(sc_vertexBufferSize);
643WINML_THROW_IF_FAILED(device_->CreateCommittedResource(
644&heapProp,
645D3D12_HEAP_FLAG_NONE,
646&resourceDiscription,
647D3D12_RESOURCE_STATE_GENERIC_READ,
648nullptr,
649IID_PPV_ARGS(newResource.put())
650));
651
652// Copy the triangle data to the vertex buffer.
653UINT8* pVertexDataBegin;
654CD3DX12_RANGE readRange(0, 0); // We do not intend to read from this resource on the CPU.
655WINML_THROW_IF_FAILED(newResource->Map(0, &readRange, reinterpret_cast<void**>(&pVertexDataBegin)));
656memcpy(pVertexDataBegin, triangleVertices, sizeof(triangleVertices));
657newResource->Unmap(0, nullptr);
658
659if (InterlockedCompareExchangePointer(detensorize_vertex_buffer_.put_void(), newResource.get(), nullptr) ==
660nullptr) {
661// This thread won the race and just cached the PSO
662newResource.detach();
663}
664}
665
666*vertexBufferSize = sc_vertexBufferSize;
667return detensorize_vertex_buffer_.get();
668}
669
670HANDLE D3DDeviceCache::GetConverterFenceHandle() {
671// Lazily create the fence since we may never need to use it
672if (!converter_fence_) {
673WINML_THROW_IF_FAILED(device_->CreateFence(
6740, D3D12_FENCE_FLAG_SHARED | D3D12_FENCE_FLAG_SHARED_CROSS_ADAPTER, IID_PPV_ARGS(converter_fence_.put())
675));
676
677HANDLE hSharedFence;
678WINML_THROW_IF_FAILED(
679device_->CreateSharedHandle(converter_fence_.get(), nullptr, GENERIC_ALL, nullptr, &hSharedFence)
680);
681
682converter_fence_handle_ = wil::unique_handle(hSharedFence);
683}
684
685return converter_fence_handle_.get();
686}
687
688void D3DDeviceCache::SyncConverterToD3D11Device(_In_ ID3D11Fence* pD3D11Fence) {
689assert(command_queue_ != nullptr);
690assert(pD3D11Fence != nullptr);
691
692ComPtr<ID3D11Device> spD3D11Device;
693pD3D11Fence->GetDevice(&spD3D11Device);
694
695ComPtr<ID3D11DeviceContext> spD3D11DeviceContext;
696spD3D11Device->GetImmediateContext(&spD3D11DeviceContext);
697
698ComPtr<ID3D11DeviceContext4> spD3D11DeviceContext4;
699WINML_THROW_IF_FAILED(spD3D11DeviceContext->QueryInterface(IID_PPV_ARGS(&spD3D11DeviceContext4)));
700
701UINT64 newfenceValue = converter_fence_value_++;
702WINML_THROW_IF_FAILED(command_queue_->Signal(converter_fence_.get(), newfenceValue));
703WINML_THROW_IF_FAILED(spD3D11DeviceContext4->Wait(pD3D11Fence, newfenceValue));
704}
705
706void D3DDeviceCache::SyncD3D11DeviceToConverter(_In_ ID3D11Fence* pD3D11Fence) {
707assert(command_queue_ != nullptr);
708assert(pD3D11Fence != nullptr);
709
710ComPtr<ID3D11Device> spD3D11Device;
711pD3D11Fence->GetDevice(&spD3D11Device);
712
713ComPtr<ID3D11DeviceContext> spD3D11DeviceContext;
714spD3D11Device->GetImmediateContext(&spD3D11DeviceContext);
715
716ComPtr<ID3D11DeviceContext4> spD3D11DeviceContext4;
717WINML_THROW_IF_FAILED(spD3D11DeviceContext->QueryInterface(IID_PPV_ARGS(&spD3D11DeviceContext4)));
718
719UINT64 newfenceValue = converter_fence_value_++;
720WINML_THROW_IF_FAILED(spD3D11DeviceContext4->Signal(pD3D11Fence, newfenceValue));
721WINML_THROW_IF_FAILED(command_queue_->Wait(converter_fence_.get(), newfenceValue));
722}
723
724bool D3DDeviceCache::SharedHandleInitialized() {
725return d3d11_fence_ != nullptr;
726}
727