onnxruntime

Форк
0
/
D3DDeviceCache.cpp 
726 строк · 28.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
#include "lib/Api.Image/pch.h"
5
#include "inc/D3DDeviceCache.h"
6
#include <directxmath.h>
7
#include <d3d11on12.h>
8
#include "inc/DeviceHelpers.h"
9
#include "CommonDeviceHelpers.h"
10

11
namespace float32 {
12
#include "shaders\SurfaceToTensor-SurfaceToTensorBGR8.h"
13
#include "shaders\SurfaceToTensor-SurfaceToTensorRGB8.h"
14
#include "shaders\SurfaceToTensor-SurfaceToTensorGRAY8.h"
15
#include "shaders\SurfaceToTensor-SurfaceGRAY8ToTensorBGR8.h"
16
#include "shaders\SurfaceToTensor-SurfaceGRAY8ToTensorGRAY8.h"
17
#include "shaders\TensorToSurface-TensorBGR8ToSurface.h"
18
#include "shaders\TensorToSurface-TensorRGB8ToSurface.h"
19
#include "shaders\TensorToSurface-TensorGRAY8ToSurface.h"
20
#include "shaders\TensorToSurface-TensorBGR8ToSurfaceGRAY8.h"
21
#include "shaders\TensorToSurface-TensorRGB8ToSurfaceGRAY8.h"
22
#include "shaders\TensorToSurface-TensorGRAY8ToSurfaceGRAY8.h"
23
}  // namespace float32
24

25
namespace float16 {
26
#include "shaders\SurfaceToTensor16-SurfaceToTensorBGR8.h"
27
#include "shaders\SurfaceToTensor16-SurfaceToTensorRGB8.h"
28
#include "shaders\SurfaceToTensor16-SurfaceToTensorGRAY8.h"
29
#include "shaders\SurfaceToTensor16-SurfaceGRAY8ToTensorBGR8.h"
30
#include "shaders\SurfaceToTensor16-SurfaceGRAY8ToTensorGRAY8.h"
31
#include "shaders\TensorToSurface16-TensorBGR8ToSurface.h"
32
#include "shaders\TensorToSurface16-TensorRGB8ToSurface.h"
33
#include "shaders\TensorToSurface16-TensorGRAY8ToSurface.h"
34
#include "shaders\TensorToSurface16-TensorBGR8ToSurfaceGRAY8.h"
35
#include "shaders\TensorToSurface16-TensorRGB8ToSurfaceGRAY8.h"
36
#include "shaders\TensorToSurface16-TensorGRAY8ToSurfaceGRAY8.h"
37
}  // namespace float16
38

39
using namespace Microsoft::WRL;
40

41
using namespace _winml;
42

43
// Builds a device cache for the requested device kind.
//
// For the Cpu and Default kinds no GPU objects are created: the adapter LUID is zeroed and the
// constructor returns early. For every other kind the preferred hardware adapter is looked up
// (through DXGI when it is available, otherwise through DXCore when built with ENABLE_DXCORE),
// a D3D12 device is created on it, and the command queue and adapter LUID are initialized.
//
// Throws through the WINML_THROW_* macros when the fence GUID cannot be created, when no
// hardware adapter can be obtained, or when device / command queue creation fails.
D3DDeviceCache::D3DDeviceCache(winml::LearningModelDeviceKind const& deviceKind) {
  WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));

  if (deviceKind == winml::LearningModelDeviceKind::Cpu || deviceKind == winml::LearningModelDeviceKind::Default) {
    // CPU devices do not create any GPU objects.
    device_luid_.HighPart = device_luid_.LowPart = 0;
    return;
  }

  DXGI_GPU_PREFERENCE preference;
  WINML_THROW_IF_FAILED(GetGPUPreference(deviceKind, &preference));

  CommonDeviceHelpers::AdapterEnumerationSupport support;
  WINML_THROW_IF_FAILED(CommonDeviceHelpers::GetAdapterEnumerationSupport(&support));

  const char noHardwareAdaptersAvailableErrStr[] = "No hardware adapters available";
  const char failedToObtainHardwareAdaptersErrStr[] = "Failed to obtain hardware adapters.";
  HRESULT hardwareAdapterSuccessfullyObtained = S_OK;
  if (support.has_dxgi) {
    winrt::com_ptr<IDXGIAdapter1> spAdapter;
    hardwareAdapterSuccessfullyObtained = GetDXGIHardwareAdapterWithPreference(preference, spAdapter.put());
    // "Not found" is reported without telemetry; every other failure goes through the regular path.
    if (hardwareAdapterSuccessfullyObtained == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) {
      WINML_THROW_HR_MSG_NO_TELEMETRY_SENT(hardwareAdapterSuccessfullyObtained, noHardwareAdaptersAvailableErrStr);
    } else {
      WINML_THROW_IF_FAILED_MSG(hardwareAdapterSuccessfullyObtained, failedToObtainHardwareAdaptersErrStr);
    }
    WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));
  }
#ifdef ENABLE_DXCORE
  if (support.has_dxgi == false) {
    winrt::com_ptr<IDXCoreAdapter> spAdapter;
    hardwareAdapterSuccessfullyObtained = GetDXCoreHardwareAdapterWithPreference(preference, spAdapter.put());
    if (hardwareAdapterSuccessfullyObtained == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) {
      WINML_THROW_HR_MSG_NO_TELEMETRY_SENT(hardwareAdapterSuccessfullyObtained, noHardwareAdaptersAvailableErrStr);
    } else {
      WINML_THROW_IF_FAILED_MSG(hardwareAdapterSuccessfullyObtained, failedToObtainHardwareAdaptersErrStr);
    }
    WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));
  }
#endif

  // Neither enumeration path ran (for example DXGI is unavailable and the DXCore path was compiled
  // out). Fail with an error here instead of dereferencing a null device below.
  if (device_ == nullptr) {
    WINML_THROW_IF_FAILED_MSG(E_UNEXPECTED, failedToObtainHardwareAdaptersErrStr);
  }

  InitializeCommandQueue(device_.get());

  device_luid_ = device_->GetAdapterLuid();
}
87

88
// Builds a device cache on top of a caller-supplied WinRT Direct3D 11 device.
//
// The WinRT wrapper is kept as-is, the underlying ID3D11Device and its immediate context are
// extracted from it, and a D3D12 device plus command queue are created on the same adapter.
//
// Throws through WINML_THROW_IF_FAILED (or the C++/WinRT `as` conversions) when any of the
// interface queries or the device creation fails.
D3DDeviceCache::D3DDeviceCache(wgdx::Direct3D11::IDirect3DDevice const& device) {
  WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));

  // Use the 11 device to initialize 12.
  winrt_device_ = device;

  // The caller told us which device to run on: crack the interop wrapper to get the DXGI device.
  winrt::com_ptr<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess> dxgiAccess;
  dxgiAccess = device.as<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess>();

  winrt::com_ptr<IDXGIDevice> dxgiDevice;
  WINML_THROW_IF_FAILED(dxgiAccess->GetInterface(IID_PPV_ARGS(dxgiDevice.put())));

  device_11_ = dxgiDevice.as<ID3D11Device>();

  winrt::com_ptr<ID3D11DeviceContext> immediateContext;
  device_11_->GetImmediateContext(immediateContext.put());
  immediateContext.as(device_context11_);

  // The DXGI device obtained above is reused to find the adapter; the original code queried the
  // same interface from the same object a second time, which was redundant.
  winrt::com_ptr<IDXGIAdapter> adapter;
  WINML_THROW_IF_FAILED(dxgiDevice->GetAdapter(adapter.put()));

  WINML_THROW_IF_FAILED(D3D12CreateDevice(adapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put())));

  InitializeCommandQueue(device_.get());

  device_luid_ = device_->GetAdapterLuid();
}
119

120
// Builds a device cache around a caller-supplied D3D12 command queue.
//
// The queue is retained, its owning D3D12 device is fetched from it, and the adapter LUID is
// read from that device. The D3D11 interop objects are not created here.
D3DDeviceCache::D3DDeviceCache(ID3D12CommandQueue* queue) {
  WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_));

  // Hold our own reference to the caller's queue.
  command_queue_.copy_from(queue);

  // Optional interface: the result is intentionally not checked, leaving sharing_contract_ null
  // when the queue does not expose it.
  command_queue_->QueryInterface(IID_PPV_ARGS(sharing_contract_.put()));

  // The device that owns the queue becomes the cache's D3D12 device.
  WINML_THROW_IF_FAILED(queue->GetDevice(IID_PPV_ARGS(device_.put())));
  device_luid_ = device_->GetAdapterLuid();
}
131

132
// Drains in-flight GPU work before the cached D3D objects are released.
//
// Nothing is done for CPU-only instances (device_ was never created) or when the device has
// already been removed. Otherwise D3D11 is synchronized to D3D12 (only if the optional D3D11
// stack was created) and then D3D12 is synchronized to the CPU.
D3DDeviceCache::~D3DDeviceCache() {
  if (device_ == nullptr || device_->GetDeviceRemovedReason() != S_OK) {
    return;
  }

  // The sync helpers report failures by throwing. An exception escaping a destructor terminates
  // the process, so teardown synchronization is treated as best effort.
  try {
    // The dx11 stack is optional and lazily created; only sync it when it exists.
    if (device_context11_ != nullptr) {
      // Sync 11 to 12 first so that the CPU wait below also covers the D3D11 work.
      GPUSyncD3D11ToD3D12();
    }
    SyncD3D12ToCPU();
  } catch (...) {
    // Deliberately swallowed: there is no caller to report to during destruction.
  }
}
144

145
bool D3DDeviceCache::IsFloat16Supported() {
146
  if (device_ != nullptr) {
147
    return CommonDeviceHelpers::IsFloat16Supported(device_.get());
148
  }
149

150
  return true;
151
}
152

153
// Returns the cached D3D11 device (non-owning), creating the D3D11 interop objects first when
// they do not exist yet. The result is whatever device_11_ holds after that attempt.
ID3D11Device* D3DDeviceCache::GetD3D11Device() {
  EnsureD3D11FromD3D12();

  ID3D11Device* const d3d11Device = device_11_.get();
  return d3d11Device;
}
157

158
// Returns the GUID generated with CoCreateGuid when this cache was constructed.
const GUID& D3DDeviceCache::GetFenceGuid() const {
  const GUID& fenceGuid = fence_guid_;
  return fenceGuid;
}
161

162
// Returns the cached D3D11 immediate context (non-owning), creating the D3D11 interop objects
// first when they do not exist yet.
ID3D11DeviceContext4* D3DDeviceCache::GetD3D11DeviceContext() {
  EnsureD3D11FromD3D12();

  ID3D11DeviceContext4* const context = device_context11_.get();
  return context;
}
166

167
// Returns the WinRT Direct3D 11 device wrapper, creating the D3D11 interop objects first when
// they do not exist yet. The wrapper is returned by value.
wgdx::Direct3D11::IDirect3DDevice D3DDeviceCache::GetWinrtDevice() {
  // Lazily builds winrt_device_ from the D3D12 device when needed.
  EnsureD3D11FromD3D12();

  return winrt_device_;
}
171

172
// Creates the cache's direct command queue on `device` (GPU timeout disabled) and then tries to
// obtain the optional sharing-contract interface from it.
// Throws through WINML_THROW_IF_FAILED when the queue cannot be created.
void D3DDeviceCache::InitializeCommandQueue(ID3D12Device1* device) {
  D3D12_COMMAND_QUEUE_DESC desc = {};
  desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
  desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;

  const HRESULT createResult =
    device->CreateCommandQueue(&desc, winrt::guid_of<ID3D12CommandQueue>(), command_queue_.put_void());
  WINML_THROW_IF_FAILED(createResult);

  // Best effort: when the queue does not expose the sharing contract the member stays null and
  // the failure is intentionally ignored.
  command_queue_->QueryInterface(IID_PPV_ARGS(sharing_contract_.put()));
}
183

184
// this initializes the following variables, making them from the dx12 device
185
//      device_11_
186
//      device_context11_
187
//      winrt_device_
188
void D3DDeviceCache::EnsureD3D11FromD3D12() {
189
  // do we even have a device?  (CPU will use the cache but not have a device) .
190
  if (device_ == nullptr)
191
    return;
192

193
  // are we already initialized
194
  if (winrt_device_ != nullptr)
195
    return;
196

197
  CWinMLAutoLock lock(&lock_);
198

199
  // check with the lock held, are we already initialized
200
  if (winrt_device_ != nullptr)
201
    return;
202

203
  winrt::com_ptr<::IInspectable> spInspectable;
204
  winrt::com_ptr<IDXGIDevice> spDXGIDevice;
205

206
  // call our SEH version (for delay loading)
207
  WINML_THROW_IF_FAILED(CreateD3D11On12Device(device_.get(), device_11_.put()));
208
  winrt::com_ptr<ID3D11DeviceContext> spContext;
209
  device_11_->GetImmediateContext(spContext.put());
210
  spContext.as(device_context11_);
211

212
  WINML_THROW_IF_FAILED(device_11_->QueryInterface(IID_PPV_ARGS(spDXGIDevice.put())));
213
  // Convert to Winrt wrapper. This doesn't actually make a new device.
214
  WINML_THROW_IF_FAILED(CreateDirect3D11DeviceFromDXGIDevice(spDXGIDevice.get(), spInspectable.put()));
215
  WINML_THROW_IF_FAILED(spInspectable->QueryInterface(
216
    winrt::guid_of<wgdx::Direct3D11::IDirect3DDevice>(), reinterpret_cast<void**>(winrt::put_abi(winrt_device_))
217
  ));
218
}
219

220
void D3DDeviceCache::EnsureD3D12Fence() {
221
  // are we already initialized?
222
  if (d3d12_fence_ != nullptr)
223
    return;
224

225
  CWinMLAutoLock lock(&lock_);
226

227
  // with the lock held, are we already initialized?
228
  if (d3d12_fence_ != nullptr)
229
    return;
230

231
  WINML_THROW_IF_FAILED(device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(d3d12_fence_.put())));
232
}
233

234
// this initializes the following variables, so that we can share dx12 with dx11
235
//      d3d11_fence_
236
//      d3d12_fence_
237
void D3DDeviceCache::EnsureSharedFences() {
238
  // are we already initialized?
239
  if (d3d11_fence_ != nullptr)
240
    return;
241

242
  CWinMLAutoLock lock(&lock_);
243

244
  // with the lock held, are we already initialized?
245
  if (d3d11_fence_ != nullptr)
246
    return;
247

248
  EnsureD3D12Fence();
249

250
  // ensure the d11 stack is alive, the 11 stack doesn't exist on WCOSHeadless yet, so be resilient
251
  EnsureD3D11FromD3D12();
252

253
  winrt::com_ptr<ID3D12DeviceChild> spD3D12DeviceChild;
254
  d3d12_fence_.as(spD3D12DeviceChild);
255
  HANDLE hSharedFence;
256
  WINML_THROW_IF_FAILED(device_->CreateSharedHandle(spD3D12DeviceChild.get(), NULL, GENERIC_ALL, nullptr, &hSharedFence)
257
  );
258

259
  winrt::com_ptr<ID3D11Device5> spD3D11Device5;
260
  device_11_.as(spD3D11Device5);
261
  wil::unique_handle safe(hSharedFence);
262
  WINML_THROW_IF_FAILED(spD3D11Device5->OpenSharedFence(safe.get(), IID_PPV_ARGS(d3d11_fence_.put())));
263
}
264

265
void D3DDeviceCache::GPUSyncD3D11ToD3D12() {
266
  EnsureSharedFences();
267

268
  UINT64 currentFence = fence_value_++;
269
  WINML_THROW_IF_FAILED(device_context11_->Signal(d3d11_fence_.get(), currentFence));
270

271
  WINML_THROW_IF_FAILED(command_queue_->Wait(d3d12_fence_.get(), currentFence));
272

273
  if (sharing_contract_ != nullptr) {
274
    sharing_contract_->SharedFenceSignal(d3d12_fence_.get(), currentFence);
275
  }
276
}
277

278
void D3DDeviceCache::GPUSyncD3D12ToD3D11() {
279
  EnsureSharedFences();
280

281
  UINT64 currentFence = fence_value_++;
282
  WINML_THROW_IF_FAILED(command_queue_->Signal(d3d12_fence_.get(), currentFence));
283

284
  WINML_THROW_IF_FAILED(device_context11_->Wait(d3d11_fence_.get(), currentFence));
285
}
286

287
void D3DDeviceCache::SyncD3D12ToCPU() {
288
  UINT64 currentFence = QueueFenceToD3D12();
289
  WaitForFenceValue(currentFence);
290
}
291

292
// Queues a signal of the D3D12 fence on the command queue and returns the value it will signal.
// Throws through WINML_THROW_IF_FAILED when the signal cannot be queued.
UINT64 D3DDeviceCache::QueueFenceToD3D12() {
  EnsureD3D12Fence();

  // Take the current fence value for this signal and advance the counter.
  const UINT64 signalValue = fence_value_++;
  WINML_THROW_IF_FAILED(command_queue_->Signal(d3d12_fence_.get(), signalValue));
  return signalValue;
}
300

301
// Blocks the calling thread, with no timeout, until the D3D12 fence reaches `fenceValue`.
// Throws when the event cannot be created, when the completion event cannot be registered, or
// when the wait returns anything other than WAIT_OBJECT_0 (reported as E_UNEXPECTED).
void D3DDeviceCache::WaitForFenceValue(UINT64 fenceValue) {
  EnsureD3D12Fence();

  // Auto-reset, initially non-signaled, unnamed event owned by this call.
  wil::unique_handle completionEvent(CreateEvent(nullptr, FALSE, FALSE, nullptr));
  THROW_LAST_ERROR_IF(!completionEvent);

  WINML_THROW_IF_FAILED(d3d12_fence_->SetEventOnCompletion(fenceValue, completionEvent.get()));

  const DWORD waitResult = WaitForSingleObject(completionEvent.get(), INFINITE);
  if (waitResult == WAIT_OBJECT_0) {
    return;
  }

  WINML_THROW_IF_FAILED(E_UNEXPECTED);
}
314

315
// Returns the cached root signature used for tensorization (non-owning), creating it on first use.
//
// Layout: four 32-bit root constants at register 0, one descriptor table with a single SRV and
// one descriptor table with a single UAV. Creation is not guarded by a lock: the freshly created
// object is published with an interlocked compare-exchange, and a thread that loses the race
// simply releases its own copy.
ID3D12RootSignature* D3DDeviceCache::GetTensorizeRootSignature() {
  if (tensorize_root_signature_ != nullptr) {
    return tensorize_root_signature_.get();
  }

  // Ask for 1.1, the highest version this code supports. If CheckFeatureSupport succeeds, the
  // HighestVersion returned will not be greater than this; if it fails, fall back to 1.0.
  D3D12_FEATURE_DATA_ROOT_SIGNATURE versionInfo = {};
  versionInfo.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_1;
  if (FAILED(device_->CheckFeatureSupport(D3D12_FEATURE_ROOT_SIGNATURE, &versionInfo, sizeof(versionInfo)))) {
    versionInfo.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_0;
  }

  // Compute root signature.
  CD3DX12_DESCRIPTOR_RANGE1 descriptorRanges[2] = {};
  descriptorRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE);
  descriptorRanges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE);

  CD3DX12_ROOT_PARAMETER1 parameters[3] = {};
  parameters[0].InitAsConstants(4, 0);
  parameters[1].InitAsDescriptorTable(1, &descriptorRanges[0], D3D12_SHADER_VISIBILITY_ALL);
  parameters[2].InitAsDescriptorTable(1, &descriptorRanges[1], D3D12_SHADER_VISIBILITY_ALL);

  CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC signatureDesc;
  signatureDesc.Init_1_1(_countof(parameters), parameters, 0, nullptr);

  winrt::com_ptr<ID3DBlob> serialized;
  winrt::com_ptr<ID3DBlob> serializationError;
  WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(
    &signatureDesc, versionInfo.HighestVersion, serialized.put(), serializationError.put()
  ));

  winrt::com_ptr<ID3D12RootSignature> candidate;
  WINML_THROW_IF_FAILED(device_->CreateRootSignature(
    0, serialized->GetBufferPointer(), serialized->GetBufferSize(), IID_PPV_ARGS(candidate.put())
  ));
  candidate->SetName(L"Tensorize Rootsignature");

  // Publish only if nothing has been cached yet.
  PVOID previous = InterlockedCompareExchangePointer(tensorize_root_signature_.put_void(), candidate.get(), nullptr);
  if (previous == nullptr) {
    // This thread won the race: the cache now owns the reference, so give it up locally.
    candidate.detach();
  }

  return tensorize_root_signature_.get();
}
361

362
// Returns the cached root signature used for detensorization (non-owning), creating it on first
// use.
//
// Layout: four 32-bit root constants at register 0, one descriptor table with a single SRV and
// one descriptor table with a single UAV; the signature is created with the
// ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT flag. Creation is not guarded by a lock: the freshly created
// object is published with an interlocked compare-exchange, and a thread that loses the race
// simply releases its own copy.
ID3D12RootSignature* D3DDeviceCache::GetDetensorizeRootSignature() {
  if (detensorize_root_signature_ != nullptr) {
    return detensorize_root_signature_.get();
  }

  // Ask for 1.1, the highest version this code supports. If CheckFeatureSupport succeeds, the
  // HighestVersion returned will not be greater than this; if it fails, fall back to 1.0.
  D3D12_FEATURE_DATA_ROOT_SIGNATURE versionInfo = {};
  versionInfo.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_1;
  if (FAILED(device_->CheckFeatureSupport(D3D12_FEATURE_ROOT_SIGNATURE, &versionInfo, sizeof(versionInfo)))) {
    versionInfo.HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_0;
  }

  // Compute root signature.
  CD3DX12_DESCRIPTOR_RANGE1 descriptorRanges[2] = {};
  descriptorRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE);
  descriptorRanges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE);

  CD3DX12_ROOT_PARAMETER1 parameters[3] = {};
  parameters[0].InitAsConstants(4, 0);
  parameters[1].InitAsDescriptorTable(1, &descriptorRanges[0], D3D12_SHADER_VISIBILITY_ALL);
  parameters[2].InitAsDescriptorTable(1, &descriptorRanges[1], D3D12_SHADER_VISIBILITY_ALL);

  CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC signatureDesc;
  signatureDesc.Init_1_1(
    _countof(parameters), parameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT
  );

  winrt::com_ptr<ID3DBlob> serialized;
  winrt::com_ptr<ID3DBlob> serializationError;
  WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(
    &signatureDesc, versionInfo.HighestVersion, serialized.put(), serializationError.put()
  ));

  winrt::com_ptr<ID3D12RootSignature> candidate;
  WINML_THROW_IF_FAILED(device_->CreateRootSignature(
    0, serialized->GetBufferPointer(), serialized->GetBufferSize(), IID_PPV_ARGS(candidate.put())
  ));
  candidate->SetName(L"Detensorize Rootsignature");

  // Publish only if nothing has been cached yet.
  PVOID previous =
    InterlockedCompareExchangePointer(detensorize_root_signature_.put_void(), candidate.get(), nullptr);
  if (previous == nullptr) {
    // This thread won the race: the cache now owns the reference, so give it up locally.
    candidate.detach();
  }

  return detensorize_root_signature_.get();
}
414

415
// Returns the cached compute pipeline state for the given (type, formatFrom, formatTo, operation)
// combination (non-owning), creating it on first use.
//
// Creation is not guarded by a lock: the new PSO is published with an interlocked
// compare-exchange, and a thread that loses the race simply releases its own copy.
ID3D12PipelineState* D3DDeviceCache::GetCachedPipelineState(
  PipelineStateCacheType type,
  PipelineStateCacheFormat formatFrom,
  PipelineStateCacheFormat formatTo,
  PipelineStateCacheOperation operation
) {
  // Resolve the cache slot once; all four enum values are used as array indices.
  auto& slot = cached_pipeline_state[static_cast<int>(type)][static_cast<int>(formatFrom)][static_cast<int>(formatTo)]
                                    [static_cast<int>(operation)];

  if (slot != nullptr) {
    return slot.get();
  }

  // The Create* helpers return an owning raw pointer, hence attach() rather than copy_from().
  winrt::com_ptr<ID3D12PipelineState> candidate;
  candidate.attach(
    operation == PipelineStateCacheOperation::kTensorize
      ? CreateTensorizePipelineState(type, formatFrom, formatTo)
      : CreateDetensorizePipelineState(type, formatFrom, formatTo)
  );

  PVOID previous = InterlockedCompareExchangePointer(slot.put_void(), candidate.get(), nullptr);
  if (previous == nullptr) {
    // This thread won the race and just cached the PSO; the slot now owns the reference.
    candidate.detach();
  }

  return slot.get();
}
446

447
// Creates the compute pipeline state that converts a surface in `formatFrom` into a tensor in
// `formatTo`, using the float32 or float16 shader set selected by `type`.
//
// Returns an owning raw pointer (the caller takes over the reference). Unsupported format values
// trip an assert; if no shader was selected the bytecode stays null and PSO creation is left to
// report the failure. Throws through WINML_THROW_IF_FAILED when PSO creation fails.
ID3D12PipelineState* D3DDeviceCache::CreateTensorizePipelineState(
  PipelineStateCacheType type, PipelineStateCacheFormat formatFrom, PipelineStateCacheFormat formatTo
) {
  static_assert(
    static_cast<unsigned int>(PipelineStateCacheFormat::kCount) == 3,
    "PipelineStateCacheFormat changed, update D3DDeviceCache::CreateTensorizePipelineState()"
  );

  const BYTE* shaderBytecode = nullptr;
  size_t shaderBytecodeSize = 0;

  // Records a precompiled shader blob; taking the array by reference keeps sizeof() meaningful.
  auto selectShader = [&shaderBytecode, &shaderBytecodeSize](const auto& blob) {
    shaderBytecode = blob;
    shaderBytecodeSize = sizeof(blob);
  };

  const bool isFloat32 = (type == PipelineStateCacheType::kFloat32);
  const bool isFloat16 = (type == PipelineStateCacheType::kFloat16);

  switch (formatFrom) {
    case PipelineStateCacheFormat::kBGR8:
    case PipelineStateCacheFormat::kRGB8:
      // BGR and RGB surfaces share shaders; only the destination tensor format matters.
      if (isFloat32) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
            selectShader(float32::g_csSurfaceToTensorBGR8);
            break;
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float32::g_csSurfaceToTensorRGB8);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float32::g_csSurfaceToTensorGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      } else if (isFloat16) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
            selectShader(float16::g_csSurfaceToTensorBGR8);
            break;
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float16::g_csSurfaceToTensorRGB8);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float16::g_csSurfaceToTensorGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      }
      break;

    case PipelineStateCacheFormat::kGRAY8:
      if (isFloat32) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            // GRAY -> RGB is the same shader as GRAY -> BGR
            selectShader(float32::g_csSurfaceGRAY8ToTensorBGR8);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float32::g_csSurfaceGRAY8ToTensorGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      } else if (isFloat16) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            // GRAY -> RGB is the same shader as GRAY -> BGR
            selectShader(float16::g_csSurfaceGRAY8ToTensorBGR8);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float16::g_csSurfaceGRAY8ToTensorGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      }
      break;

    default:
      assert(false);
      break;
  }

  D3D12_COMPUTE_PIPELINE_STATE_DESC psoDesc = {};
  psoDesc.pRootSignature = GetTensorizeRootSignature();
  psoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, shaderBytecodeSize);

  winrt::com_ptr<ID3D12PipelineState> createdState;
  WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&psoDesc, IID_PPV_ARGS(createdState.put())));

  // Hand the reference over to the caller.
  return createdState.detach();
}
528

529
// Creates the compute pipeline state that converts a tensor in `formatFrom` into a surface in
// `formatTo`, using the float32 or float16 shader set selected by `type`.
//
// Returns an owning raw pointer (the caller takes over the reference). Unsupported format values
// trip an assert; if no shader was selected the bytecode stays null and PSO creation is left to
// report the failure. Throws through WINML_THROW_IF_FAILED when PSO creation fails.
ID3D12PipelineState* D3DDeviceCache::CreateDetensorizePipelineState(
  PipelineStateCacheType type, PipelineStateCacheFormat formatFrom, PipelineStateCacheFormat formatTo
) {
  static_assert(
    static_cast<unsigned int>(PipelineStateCacheFormat::kCount) == 3,
    "PipelineStateCacheFormat changed, update D3DDeviceCache::CreateDetensorizePipelineState()"
  );

  const BYTE* shaderBytecode = nullptr;
  size_t shaderBytecodeSize = 0;

  // Records a precompiled shader blob; taking the array by reference keeps sizeof() meaningful.
  auto selectShader = [&shaderBytecode, &shaderBytecodeSize](const auto& blob) {
    shaderBytecode = blob;
    shaderBytecodeSize = sizeof(blob);
  };

  const bool isFloat32 = (type == PipelineStateCacheType::kFloat32);
  const bool isFloat16 = (type == PipelineStateCacheType::kFloat16);

  // For every source tensor format, BGR8 and RGB8 surfaces share one shader and GRAY8 surfaces
  // use a dedicated one.
  switch (formatFrom) {
    case PipelineStateCacheFormat::kBGR8:
      if (isFloat32) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float32::g_csTensorBGR8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float32::g_csTensorBGR8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      } else if (isFloat16) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float16::g_csTensorBGR8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float16::g_csTensorBGR8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      }
      break;

    case PipelineStateCacheFormat::kRGB8:
      if (isFloat32) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float32::g_csTensorRGB8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float32::g_csTensorRGB8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      } else if (isFloat16) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            selectShader(float16::g_csTensorRGB8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float16::g_csTensorRGB8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      }
      break;

    case PipelineStateCacheFormat::kGRAY8:
      if (isFloat32) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            // GRAY -> RGB is the same shader as GRAY -> BGR
            selectShader(float32::g_csTensorGRAY8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float32::g_csTensorGRAY8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      } else if (isFloat16) {
        switch (formatTo) {
          case PipelineStateCacheFormat::kBGR8:
          case PipelineStateCacheFormat::kRGB8:
            // GRAY -> RGB is the same shader as GRAY -> BGR
            selectShader(float16::g_csTensorGRAY8ToSurface);
            break;
          case PipelineStateCacheFormat::kGRAY8:
            selectShader(float16::g_csTensorGRAY8ToSurfaceGRAY8);
            break;
          default:
            assert(false);
            break;
        }
      }
      break;

    default:
      assert(false);
      break;
  }

  D3D12_COMPUTE_PIPELINE_STATE_DESC psoDesc = {};
  psoDesc.pRootSignature = GetDetensorizeRootSignature();
  psoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, shaderBytecodeSize);

  winrt::com_ptr<ID3D12PipelineState> createdState;
  WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&psoDesc, IID_PPV_ARGS(createdState.put())));

  // Hand the reference over to the caller.
  return createdState.detach();
}
626

627
// Returns the cached vertex buffer used for detensorization (non-owning), creating it on first
// use, and writes its size in bytes to `vertexBufferSize`.
//
// The buffer lives in an upload heap and holds the four corner vertices of a full-screen quad.
// Creation is not guarded by a lock: the new resource is published with an interlocked
// compare-exchange, and a thread that loses the race simply releases its own copy.
// `vertexBufferSize` must not be null. Throws through WINML_THROW_IF_FAILED on failure.
ID3D12Resource* D3DDeviceCache::GetDetensorizeVertexBuffer(_Out_ UINT* vertexBufferSize) {
  if (detensorize_vertex_buffer_ == nullptr) {
    // Full-screen quad corners (described in the original as 2 triangles for full screen).
    DirectX::XMFLOAT3 quadVertices[] = {
      {-1.0f,  1.0f, 0.0f},
      { 1.0f,  1.0f, 0.0f},
      {-1.0f, -1.0f, 0.0f},
      { 1.0f, -1.0f, 0.0f},
    };

    assert(sc_vertexBufferSize == sizeof(quadVertices));

    // Create the vertex buffer in an upload heap so the CPU can fill it directly.
    CD3DX12_HEAP_PROPERTIES uploadHeap(D3D12_HEAP_TYPE_UPLOAD);
    D3D12_RESOURCE_DESC bufferDesc = CD3DX12_RESOURCE_DESC::Buffer(sc_vertexBufferSize);

    winrt::com_ptr<ID3D12Resource> candidate;
    WINML_THROW_IF_FAILED(device_->CreateCommittedResource(
      &uploadHeap,
      D3D12_HEAP_FLAG_NONE,
      &bufferDesc,
      D3D12_RESOURCE_STATE_GENERIC_READ,
      nullptr,
      IID_PPV_ARGS(candidate.put())
    ));

    // Copy the vertex data into the buffer.
    CD3DX12_RANGE noCpuRead(0, 0);  // We do not intend to read from this resource on the CPU.
    void* mappedData = nullptr;
    WINML_THROW_IF_FAILED(candidate->Map(0, &noCpuRead, &mappedData));
    memcpy(mappedData, quadVertices, sizeof(quadVertices));
    candidate->Unmap(0, nullptr);

    PVOID previous = InterlockedCompareExchangePointer(detensorize_vertex_buffer_.put_void(), candidate.get(), nullptr);
    if (previous == nullptr) {
      // This thread won the race: the cache now owns the reference, so give it up locally.
      candidate.detach();
    }
  }

  *vertexBufferSize = sc_vertexBufferSize;
  return detensorize_vertex_buffer_.get();
}
669

670
// Returns the shared handle of the converter fence, creating the fence and its handle on first
// use. The returned handle stays owned by this cache; callers must not close it.
// Throws through WINML_THROW_IF_FAILED when the fence or its shared handle cannot be created.
HANDLE D3DDeviceCache::GetConverterFenceHandle() {
  // Lazily create the fence since we may never need to use it
  if (!converter_fence_) {
    // Build into a local first. Storing straight into converter_fence_ would leave the fence set
    // with no handle if CreateSharedHandle failed, and every later call would then skip this
    // block and return a null handle.
    decltype(converter_fence_) newFence;
    WINML_THROW_IF_FAILED(device_->CreateFence(
      0, D3D12_FENCE_FLAG_SHARED | D3D12_FENCE_FLAG_SHARED_CROSS_ADAPTER, IID_PPV_ARGS(newFence.put())
    ));

    HANDLE hSharedFence = nullptr;
    WINML_THROW_IF_FAILED(
      device_->CreateSharedHandle(newFence.get(), nullptr, GENERIC_ALL, nullptr, &hSharedFence)
    );

    // Commit both members only after both creations succeeded; the fence is published last
    // because it is the member the lazy-init check tests.
    converter_fence_handle_ = wil::unique_handle(hSharedFence);
    converter_fence_ = newFence;
  }

  return converter_fence_handle_.get();
}
687

688
// Makes the D3D11 device that owns `pD3D11Fence` wait for the converter's D3D12 queue: the queue
// signals the converter fence with a new value and the D3D11 immediate context waits for that
// value on `pD3D11Fence`.
// NOTE(review): presumably `pD3D11Fence` was opened from the converter fence's shared handle, so
// both refer to the same underlying fence - confirm against the callers.
void D3DDeviceCache::SyncConverterToD3D11Device(_In_ ID3D11Fence* pD3D11Fence) {
  assert(command_queue_ != nullptr);
  assert(pD3D11Fence != nullptr);

  // Walk from the fence to its device's immediate context.
  ComPtr<ID3D11Device> fenceDevice;
  pD3D11Fence->GetDevice(&fenceDevice);

  ComPtr<ID3D11DeviceContext> immediateContext;
  fenceDevice->GetImmediateContext(&immediateContext);

  // Fence waits require the ID3D11DeviceContext4 interface.
  ComPtr<ID3D11DeviceContext4> immediateContext4;
  WINML_THROW_IF_FAILED(immediateContext->QueryInterface(IID_PPV_ARGS(&immediateContext4)));

  // Take the current converter fence value for this sync point and advance the counter.
  const UINT64 syncPoint = converter_fence_value_++;
  WINML_THROW_IF_FAILED(command_queue_->Signal(converter_fence_.get(), syncPoint));
  WINML_THROW_IF_FAILED(immediateContext4->Wait(pD3D11Fence, syncPoint));
}
705

706
// Makes the converter's D3D12 queue wait for the D3D11 device that owns `pD3D11Fence`: the D3D11
// immediate context signals `pD3D11Fence` with a new value and the queue waits for that value on
// the converter fence.
// NOTE(review): presumably `pD3D11Fence` was opened from the converter fence's shared handle, so
// both refer to the same underlying fence - confirm against the callers.
void D3DDeviceCache::SyncD3D11DeviceToConverter(_In_ ID3D11Fence* pD3D11Fence) {
  assert(command_queue_ != nullptr);
  assert(pD3D11Fence != nullptr);

  // Walk from the fence to its device's immediate context.
  ComPtr<ID3D11Device> fenceDevice;
  pD3D11Fence->GetDevice(&fenceDevice);

  ComPtr<ID3D11DeviceContext> immediateContext;
  fenceDevice->GetImmediateContext(&immediateContext);

  // Fence signals require the ID3D11DeviceContext4 interface.
  ComPtr<ID3D11DeviceContext4> immediateContext4;
  WINML_THROW_IF_FAILED(immediateContext->QueryInterface(IID_PPV_ARGS(&immediateContext4)));

  // Take the current converter fence value for this sync point and advance the counter.
  const UINT64 syncPoint = converter_fence_value_++;
  WINML_THROW_IF_FAILED(immediateContext4->Signal(pD3D11Fence, syncPoint));
  WINML_THROW_IF_FAILED(command_queue_->Wait(converter_fence_.get(), syncPoint));
}
723

724
bool D3DDeviceCache::SharedHandleInitialized() {
725
  return d3d11_fence_ != nullptr;
726
}
727

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.