// DirectStorage
// 687 lines · 20.9 KB (file-viewer banner captured when this file was extracted)
//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//

#define NOMINMAX

#include "CustomDecompression.h"

#include <dstorage.h>
#include <dxgi1_4.h>
#include <winrt/base.h>
#include <winrt/windows.applicationmodel.datatransfer.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

25using winrt::check_hresult;26using winrt::com_ptr;27
28void SetClipboardText(std::wstring const& str);29
30struct handle_closer31{
32void operator()(HANDLE h) noexcept33{34assert(h != INVALID_HANDLE_VALUE);35if (h)36{37CloseHandle(h);38}39}40};41using ScopedHandle = std::unique_ptr<void, handle_closer>;42
// Prints a one-line description of the tool followed by its command-line usage to stdout.
void ShowHelpText()
{
    std::cout << "Compresses a file, saves it to disk, and then loads & decompresses using DirectStorage." << std::endl
              << std::endl;
    std::cout << "USAGE: GpuDecompressionBenchmark <path> [chunk size in MiB]" << std::endl << std::endl;
    std::cout << " Default chunk size is 16." << std::endl;
}

// Describes one chunk of a (possibly compressed) file.
// Members default to zero so a default-constructed instance never holds indeterminate values;
// the type is still an aggregate, so `{offset, compressedSize, uncompressedSize}` keeps working.
struct ChunkMetadata
{
    uint32_t Offset = 0;           // Byte offset of the chunk's data within the file that is read back.
    uint32_t CompressedSize = 0;   // Number of bytes the chunk occupies in that file.
    uint32_t UncompressedSize = 0; // Number of bytes the chunk expands to.
};

58struct Metadata59{
60uint32_t UncompressedSize;61uint32_t CompressedSize;62uint32_t LargestCompressedChunkSize;63std::vector<ChunkMetadata> Chunks;64};65
66Metadata GenerateUncompressedMetadata(wchar_t const* filename, uint32_t chunkSizeBytes)67{
68ScopedHandle inHandle(69CreateFile(filename, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr));70winrt::check_bool(inHandle.get());71
72DWORD size = GetFileSize(inHandle.get(), nullptr);73
74Metadata metadata;75metadata.UncompressedSize = size;76metadata.CompressedSize = size;77metadata.LargestCompressedChunkSize = chunkSizeBytes;78
79uint32_t offset = 0;80
81while (offset < size)82{83uint32_t chunkSize = std::min<uint32_t>(size - offset, chunkSizeBytes);84
85metadata.Chunks.push_back({offset, chunkSize, chunkSize});86offset += chunkSize;87}88
89return metadata;90}
91
92com_ptr<IDStorageCompressionCodec> GetCodec(DSTORAGE_COMPRESSION_FORMAT format)93{
94com_ptr<IDStorageCompressionCodec> codec;95switch (format)96{97case DSTORAGE_COMPRESSION_FORMAT_GDEFLATE:98check_hresult(DStorageCreateCompressionCodec(format, 0, IID_PPV_ARGS(codec.put())));99break;100
101#if USE_ZLIB102case DSTORAGE_CUSTOM_COMPRESSION_0:103codec = winrt::make<ZLibCodec>();104break;105#endif106
107default:108std::terminate();109}110
111return codec;112}
113
114Metadata Compress(115DSTORAGE_COMPRESSION_FORMAT format,116const wchar_t* originalFilename,117const wchar_t* compressedFilename,118uint32_t chunkSizeBytes)119{
120ScopedHandle inHandle(CreateFile(121originalFilename,122GENERIC_READ,123FILE_SHARE_READ,124nullptr,125OPEN_EXISTING,126FILE_ATTRIBUTE_NORMAL,127nullptr));128winrt::check_bool(inHandle.get());129
130DWORD size = GetFileSize(inHandle.get(), nullptr);131
132ScopedHandle inMapping(CreateFileMapping(inHandle.get(), nullptr, PAGE_READONLY, 0, 0, nullptr));133winrt::check_bool(inMapping.get());134
135uint8_t* srcData = reinterpret_cast<uint8_t*>(MapViewOfFile(inMapping.get(), FILE_MAP_READ, 0, 0, size));136winrt::check_bool(srcData);137
138ScopedHandle outHandle(CreateFile(139compressedFilename,140GENERIC_WRITE,141FILE_SHARE_WRITE,142nullptr,143CREATE_ALWAYS,144FILE_ATTRIBUTE_NORMAL,145nullptr));146winrt::check_bool(outHandle.get());147
148uint32_t numChunks = (size + chunkSizeBytes - 1) / chunkSizeBytes;149
150std::wcout << "Compressing " << originalFilename << " to " << compressedFilename << " in " << numChunks << "x"151<< chunkSizeBytes / 1024 / 1024 << " MiB chunks" << std::endl;152
153using Chunk = std::vector<uint8_t>;154
155std::vector<Chunk> chunks;156std::vector<uint32_t> chunkOffsets;157
158chunks.resize(numChunks);159for (uint32_t i = 0; i < numChunks; ++i)160{161uint32_t thisChunkOffset = i * chunkSizeBytes;162chunkOffsets.push_back(thisChunkOffset);163}164
165std::atomic<size_t> nextChunk = 0;166
167std::vector<std::thread> threads;168threads.reserve(std::thread::hardware_concurrency());169
170for (unsigned int i = 0; i < std::thread::hardware_concurrency(); ++i)171{172threads.emplace_back(173[&]()174{175// Each thread needs its own instance of the codec176com_ptr<IDStorageCompressionCodec> codec = GetCodec(format);177
178while (true)179{180size_t chunkIndex = nextChunk.fetch_add(1);181if (chunkIndex >= numChunks)182return;183
184size_t thisChunkOffset = chunkIndex * chunkSizeBytes;185size_t thisChunkSize = std::min<size_t>(size - thisChunkOffset, chunkSizeBytes);186
187Chunk chunk(codec->CompressBufferBound(thisChunkSize));188
189uint8_t* uncompressedStart = srcData + thisChunkOffset;190
191size_t compressedSize = 0;192check_hresult(codec->CompressBuffer(193uncompressedStart,194thisChunkSize,195DSTORAGE_COMPRESSION_BEST_RATIO,196chunk.data(),197chunk.size(),198&compressedSize));199chunk.resize(compressedSize);200
201chunks[chunkIndex] = std::move(chunk);202}203});204}205
206size_t lastNextChunk = std::numeric_limits<size_t>::max();207
208do209{210Sleep(250);211if (nextChunk != lastNextChunk)212{213lastNextChunk = nextChunk;214std::cout << " " << std::min<size_t>(numChunks, lastNextChunk + 1) << " / " << numChunks << " \r";215std::cout.flush();216}217} while (lastNextChunk < numChunks);218
219for (auto& thread : threads)220{221thread.join();222}223
224uint32_t totalCompressedSize = 0;225uint32_t offset = 0;226
227Metadata metadata;228metadata.UncompressedSize = size;229metadata.LargestCompressedChunkSize = 0;230
231for (uint32_t i = 0; i < numChunks; ++i)232{233winrt::check_bool(234WriteFile(outHandle.get(), chunks[i].data(), static_cast<DWORD>(chunks[i].size()), nullptr, nullptr));235
236uint32_t thisChunkOffset = i * chunkSizeBytes;237uint32_t thisChunkSize = std::min<uint32_t>(size - thisChunkOffset, chunkSizeBytes);238
239ChunkMetadata chunkMetadata{};240chunkMetadata.Offset = offset;241chunkMetadata.CompressedSize = static_cast<uint32_t>(chunks[i].size());242chunkMetadata.UncompressedSize = thisChunkSize;243metadata.Chunks.push_back(chunkMetadata);244
245totalCompressedSize += chunkMetadata.CompressedSize;246offset += chunkMetadata.CompressedSize;247
248metadata.LargestCompressedChunkSize =249std::max(metadata.LargestCompressedChunkSize, chunkMetadata.CompressedSize);250}251
252outHandle.reset();253
254metadata.CompressedSize = totalCompressedSize;255
256std::cout << "Total: " << size << " --> " << totalCompressedSize << " bytes (" << totalCompressedSize * 100.0 / size257<< "%) " << std::endl;258
259return metadata;260}
261
262static uint64_t GetProcessCycleTime()263{
264ULONG64 cycleTime;265
266winrt::check_bool(QueryProcessCycleTime(GetCurrentProcess(), &cycleTime));267
268return cycleTime;269}
270
// Outcome of one benchmark configuration, averaged over its runs.
// Members default to zero; the type remains an aggregate so `{bandwidth, cycles}` still works.
struct TestResult
{
    double Bandwidth = 0.0;     // Mean bandwidth in GB/s (10^9 bytes per second).
    uint64_t ProcessCycles = 0; // Mean process cycle time consumed per run.
};

277TestResult RunTest(278IDStorageFactory* factory,279uint32_t stagingSizeMiB,280wchar_t const* sourceFilename,281DSTORAGE_COMPRESSION_FORMAT compressionFormat,282Metadata const& metadata,283int numRuns)284{
285com_ptr<IDStorageFile> file;286
287HRESULT hr = factory->OpenFile(sourceFilename, IID_PPV_ARGS(file.put()));288if (FAILED(hr))289{290std::wcout << L"The file '" << sourceFilename << L"' could not be opened. HRESULT=0x" << std::hex << hr291<< std::endl;292std::abort();293}294
295// The staging buffer size must be set before any queues are created.296std::cout << " " << stagingSizeMiB << " MiB staging buffer: ";297uint32_t stagingBufferSizeBytes = stagingSizeMiB * 1024 * 1024;298check_hresult(factory->SetStagingBufferSize(stagingBufferSizeBytes));299
300if (metadata.LargestCompressedChunkSize > stagingBufferSizeBytes)301{302std::cout << " SKIPPED! " << std::endl;303return {0, 0};304}305
306com_ptr<ID3D12Device> device;307check_hresult(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_12_1, IID_PPV_ARGS(&device)));308
309// Create a DirectStorage queue which will be used to load data into a310// buffer on the GPU.311DSTORAGE_QUEUE_DESC queueDesc{};312queueDesc.Capacity = DSTORAGE_MAX_QUEUE_CAPACITY;313queueDesc.Priority = DSTORAGE_PRIORITY_NORMAL;314queueDesc.SourceType = DSTORAGE_REQUEST_SOURCE_FILE;315queueDesc.Device = device.get();316
317com_ptr<IDStorageQueue> queue;318check_hresult(factory->CreateQueue(&queueDesc, IID_PPV_ARGS(queue.put())));319
320// Create the ID3D12Resource buffer which will be populated with the file's contents321D3D12_HEAP_PROPERTIES bufferHeapProps = {};322bufferHeapProps.Type = D3D12_HEAP_TYPE_DEFAULT;323
324D3D12_RESOURCE_DESC bufferDesc = {};325bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;326bufferDesc.Width = metadata.UncompressedSize;327bufferDesc.Height = 1;328bufferDesc.DepthOrArraySize = 1;329bufferDesc.MipLevels = 1;330bufferDesc.Format = DXGI_FORMAT_UNKNOWN;331bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;332bufferDesc.SampleDesc.Count = 1;333
334com_ptr<ID3D12Resource> bufferResource;335check_hresult(device->CreateCommittedResource(336&bufferHeapProps,337D3D12_HEAP_FLAG_NONE,338&bufferDesc,339D3D12_RESOURCE_STATE_COMMON,340nullptr,341IID_PPV_ARGS(bufferResource.put())));342
343// Configure a fence to be signaled when the request is completed344com_ptr<ID3D12Fence> fence;345check_hresult(device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(fence.put())));346
347ScopedHandle fenceEvent(CreateEvent(nullptr, FALSE, FALSE, nullptr));348uint64_t fenceValue = 1;349
350double meanBandwidth = 0;351uint64_t meanCycleTime = 0;352
353for (int i = 0; i < numRuns; ++i)354{355check_hresult(fence->SetEventOnCompletion(fenceValue, fenceEvent.get()));356
357// Enqueue requests to load each compressed chunk.358uint32_t destOffset = 0;359for (auto const& chunk : metadata.Chunks)360{361DSTORAGE_REQUEST request = {};362request.Options.SourceType = DSTORAGE_REQUEST_SOURCE_FILE;363request.Options.DestinationType = DSTORAGE_REQUEST_DESTINATION_BUFFER;364request.Options.CompressionFormat = compressionFormat;365request.Source.File.Source = file.get();366request.Source.File.Offset = chunk.Offset;367request.Source.File.Size = chunk.CompressedSize;368request.UncompressedSize = chunk.UncompressedSize;369request.Destination.Buffer.Resource = bufferResource.get();370request.Destination.Buffer.Offset = destOffset;371request.Destination.Buffer.Size = chunk.UncompressedSize;372queue->EnqueueRequest(&request);373destOffset += request.UncompressedSize;374}375
376// Signal the fence when done377queue->EnqueueSignal(fence.get(), fenceValue);378
379auto startTime = std::chrono::high_resolution_clock::now();380auto startCycleTime = GetProcessCycleTime();381
382// Tell DirectStorage to start executing all queued items.383queue->Submit();384
385// Wait for the submitted work to complete386WaitForSingleObject(fenceEvent.get(), INFINITE);387
388auto endCycleTime = GetProcessCycleTime();389auto endTime = std::chrono::high_resolution_clock::now();390
391if (fence->GetCompletedValue() == (uint64_t)-1)392{393// Device removed! Give DirectStorage a chance to detect the error.394Sleep(5);395}396
397// If an error was detected the first failure record398// can be retrieved to get more details.399DSTORAGE_ERROR_RECORD errorRecord{};400queue->RetrieveErrorRecord(&errorRecord);401if (FAILED(errorRecord.FirstFailure.HResult))402{403//404// errorRecord.FailureCount - The number of failed requests in the queue since the last405// RetrieveErrorRecord call.406// errorRecord.FirstFailure - Detailed record about the first failed command in the enqueue order.407//408std::cout << "The DirectStorage request failed! HRESULT=0x" << std::hex << errorRecord.FirstFailure.HResult409<< std::endl;410
411if (errorRecord.FirstFailure.CommandType == DSTORAGE_COMMAND_TYPE_REQUEST)412{413auto& r = errorRecord.FirstFailure.Request.Request;414
415std::cout << std::dec << " " << r.Source.File.Offset << " " << r.Source.File.Size << std::endl;416}417std::terminate();418}419else420{421auto duration = endTime - startTime;422
423using dseconds = std::chrono::duration<double>;424
425double durationInSeconds = std::chrono::duration_cast<dseconds>(duration).count();426double bandwidth = (metadata.UncompressedSize / durationInSeconds) / 1000.0 / 1000.0 / 1000.0;427meanBandwidth += bandwidth;428
429meanCycleTime += (endCycleTime - startCycleTime);430
431std::cout << ".";432}433++fenceValue;434}435
436meanBandwidth /= numRuns;437meanCycleTime /= numRuns;438
439std::cout << " " << meanBandwidth << " GB/s"440<< " mean cycle time: " << std::dec << meanCycleTime << std::endl;441
442return {meanBandwidth, meanCycleTime};443}
444
445int wmain(int argc, wchar_t* argv[])446{
447enum class TestCase448{449Uncompressed,450#if USE_ZLIB451CpuZLib,452#endif453CpuGDeflate,454GpuGDeflate
455};456
457TestCase testCases[] =458{ TestCase::Uncompressed,459#if USE_ZLIB460TestCase::CpuZLib,461#endif462TestCase::CpuGDeflate,463TestCase::GpuGDeflate };464
465if (argc < 2)466{467ShowHelpText();468return -1;469}470
471const wchar_t* originalFilename = argv[1];472std::wstring gdeflateFilename = std::wstring(originalFilename) + L".gdeflate";473
474#if USE_ZLIB475std::wstring zlibFilename = std::wstring(originalFilename) + L".zlib";476#endif477
478uint32_t chunkSizeMiB = 16;479if (argc > 2)480{481chunkSizeMiB = _wtoi(argv[2]);482if (chunkSizeMiB == 0)483{484ShowHelpText();485std::wcout << std::endl << L"Invalid chunk size: " << argv[2] << std::endl;486return -1;487}488}489uint32_t chunkSizeBytes = chunkSizeMiB * 1024 * 1024;490
491Metadata uncompressedMetadata = GenerateUncompressedMetadata(originalFilename, chunkSizeBytes);492Metadata gdeflateMetadata =493Compress(DSTORAGE_COMPRESSION_FORMAT_GDEFLATE, originalFilename, gdeflateFilename.c_str(), chunkSizeBytes);494
495#if USE_ZLIB496Metadata zlibMetadata =497Compress(DSTORAGE_CUSTOM_COMPRESSION_0, originalFilename, zlibFilename.c_str(), chunkSizeBytes);498#endif499
500constexpr uint32_t MAX_STAGING_BUFFER_SIZE = 1024;501
502struct Result503{504TestCase TestCase;505uint32_t StagingBufferSizeMiB;506TestResult Data;507};508
509std::vector<Result> results;510
511for (TestCase testCase : testCases)512{513DSTORAGE_COMPRESSION_FORMAT compressionFormat;514DSTORAGE_CONFIGURATION config{};515int numRuns = 0;516Metadata* metadata = nullptr;517wchar_t const* filename = nullptr;518
519switch (testCase)520{521case TestCase::Uncompressed:522compressionFormat = DSTORAGE_COMPRESSION_FORMAT_NONE;523numRuns = 10;524metadata = &uncompressedMetadata;525filename = originalFilename;526std::cout << "Uncompressed:" << std::endl;527break;528
529#if USE_ZLIB530case TestCase::CpuZLib:531compressionFormat = DSTORAGE_CUSTOM_COMPRESSION_0;532numRuns = 2;533metadata = &zlibMetadata;534filename = zlibFilename.c_str();535std::cout << "ZLib:" << std::endl;536break;537#endif538
539case TestCase::CpuGDeflate:540compressionFormat = DSTORAGE_COMPRESSION_FORMAT_GDEFLATE;541numRuns = 2;542
543// When forcing the CPU implementation of GDEFLATE we need to go544// through the custom decompression path so we can ensure that545// GDEFLATE doesn't try and decompress directly to an upload heap.546config.NumBuiltInCpuDecompressionThreads = DSTORAGE_DISABLE_BUILTIN_CPU_DECOMPRESSION;547config.DisableGpuDecompression = true;548
549metadata = &gdeflateMetadata;550filename = gdeflateFilename.c_str();551std::cout << "CPU GDEFLATE:" << std::endl;552break;553
554case TestCase::GpuGDeflate:555compressionFormat = DSTORAGE_COMPRESSION_FORMAT_GDEFLATE;556numRuns = 10;557metadata = &gdeflateMetadata;558filename = gdeflateFilename.c_str();559std::cout << "GPU GDEFLATE:" << std::endl;560break;561
562default:563std::terminate();564}565
566check_hresult(DStorageSetConfiguration(&config));567
568com_ptr<IDStorageFactory> factory;569check_hresult(DStorageGetFactory(IID_PPV_ARGS(factory.put())));570
571factory->SetDebugFlags(DSTORAGE_DEBUG_SHOW_ERRORS | DSTORAGE_DEBUG_BREAK_ON_ERROR);572
573CustomDecompression customDecompression(factory.get(), std::thread::hardware_concurrency());574
575for (uint32_t stagingSizeMiB = 1; stagingSizeMiB <= MAX_STAGING_BUFFER_SIZE; stagingSizeMiB *= 2)576{577if (stagingSizeMiB < chunkSizeMiB)578continue;579
580TestResult data = RunTest(factory.get(), stagingSizeMiB, filename, compressionFormat, *metadata, numRuns);581
582results.push_back({testCase, stagingSizeMiB, data});583}584}585
586std::cout << "\n\n";587
588std::wstringstream bandwidth;589std::wstringstream cycles;590
591std::wstring header =592L"\"Staging Buffer Size MiB\"\t\"Uncompressed\"\t\"ZLib\"\t\"CPU GDEFLATE\"\t\"GPU GDEFLATE\"";593bandwidth << header << std::endl;594cycles << header << std::endl;595
596for (uint32_t stagingBufferSize = 1; stagingBufferSize <= MAX_STAGING_BUFFER_SIZE; stagingBufferSize *= 2)597{598std::wstringstream bandwidthRow;599std::wstringstream cyclesRow;600
601bandwidthRow << stagingBufferSize << "\t";602cyclesRow << stagingBufferSize << "\t";603
604constexpr bool showEmptyRows = true;605
606bool foundOne = false;607
608for (auto& testCase : testCases)609{610auto it = std::find_if(611results.begin(),612results.end(),613[&](Result const& r) { return r.TestCase == testCase && r.StagingBufferSizeMiB == stagingBufferSize; });614
615if (it == results.end())616{617bandwidthRow << L"\t";618cyclesRow << L"\t";619}620else621{622bandwidthRow << it->Data.Bandwidth << L"\t";623cyclesRow << it->Data.ProcessCycles << L"\t";624foundOne = true;625}626}627
628if (showEmptyRows || foundOne)629{630bandwidth << bandwidthRow.str() << std::endl;631cycles << cyclesRow.str() << std::endl;632}633}634
635std::wstringstream combined;636combined << "Bandwidth" << std::endl637<< bandwidth.str() << std::endl638<< std::endl639<< "Cycles" << std::endl640<< cycles.str() << std::endl;641
642combined << std::endl << "Compression" << std::endl;643combined << "Case\tSize\tRatio" << std::endl;644
645auto ratioLine = [&](char const* name, Metadata const& metadata)646{647combined << name << "\t" << metadata.CompressedSize << "\t"648<< static_cast<double>(metadata.CompressedSize) / static_cast<double>(metadata.UncompressedSize)649<< std::endl;650};651
652ratioLine("Uncompressed", uncompressedMetadata);653#if USE_ZLIB654ratioLine("ZLib", zlibMetadata);655#else656combined << "ZLib" << "\tn/a\tn/a" << std::endl;657#endif658ratioLine("GDEFLATE", gdeflateMetadata);659
660combined << std::endl;661
662std::wcout << combined.str();663
664try665{666SetClipboardText(combined.str());667std::wcout << "\nThese results have been copied to the clipboard, ready to paste into Excel." << std::endl;668return 0;669}670catch (...)671{672std::wcout << "\nFailed to copy results to clipboard. Sorry." << std::endl;673}674
675return 0;676}
677
678void SetClipboardText(std::wstring const& str)679{
680using namespace winrt::Windows::ApplicationModel::DataTransfer;681
682DataPackage dataPackage;683dataPackage.SetText(str);684
685Clipboard::SetContent(dataPackage);686Clipboard::Flush();687}
688