//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//
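//
// For example (an illustrative sketch; the exact IR depends on the pass
// options and the surrounding pipeline), a host-side
//
//   %token = gpu.wait async
//
// roughly becomes
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//
// with later synchronous gpu.wait ops turning into mgpuStreamSynchronize and
// mgpuStreamDestroy calls on that pointer.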

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
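  // Computes the number of elements in `desc`. For a statically shaped memref
  // this folds to a constant; for a dynamically shaped memref with an identity
  // layout it is stride[0] * size[0] (e.g. for memref<?x4xf32>, size[0] * 4).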
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexAttrConstant(
                     rewriter, loc, indexType, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP)
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operation on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
  class Convert##op_name##ToGpuRuntimeCallPattern                             \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
  public:                                                                     \
    Convert##op_name##ToGpuRuntimeCallPattern(                                \
        const LLVMTypeConverter &typeConverter)                               \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
                                                                              \
  private:                                                                    \
    LogicalResult                                                             \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
                    ConversionPatternRewriter &rewriter) const override;      \
  };

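// Each DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN line below instantiates
// the pattern class defined by the macro above; e.g. the SpMVOp line declares
// ConvertSpMVOpToGpuRuntimeCallPattern, whose matchAndRewrite is defined
// later in this file.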
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();
  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
  }

  // Preserve GPU modules and binaries. Modules are preserved as they can be
  // converted later by `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept LaunchFuncOps as legal only if their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
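
// Typical usage (a sketch; the option name assumes the upstream pass
// registration):
//
//   mlir-opt --gpu-to-llvm='use-bare-pointers-for-kernels=1' input.mlir
//
// runs this pass with the bare-pointer calling convention enabled for kernel
// arguments.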

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}
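
// For example, streamCreateCallBuilder.create(loc, builder, {}) uses the
// helper above to declare the runtime shim at most once per module and then
// call it, roughly:
//
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
//   %0 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr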

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3; // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a compile-time (in the MLIR compiler) disablement or
// warning: cusparseLt currently does not work for CUDA architectures < 8.0
// and will trigger a runtime (in the CUDA program) error. It would be good to
// at least emit a warning when the target architecture is < 8.0 and the user
// still wants to use cusparseLt, i.e. to disable the cusparseLt calls when
// lowering the GPU sparse dialect to LLVM calls for such architectures.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the defining op of spMat before aborting.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the sparse operand has 2:4 (50%) sparsity, we should use cusparseLt.
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within an async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere; otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
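
// For example (an illustrative sketch), `%t1 = gpu.wait async [%t0]` where
// %t0 was lowered to an event roughly becomes:
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//       : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%event) : (!llvm.ptr) -> ()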

// Legalize the op's operands.
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty())
    stream = adaptor.getAsyncDependencies().front();
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  // Lower the kernel operands to match kernel parameters.
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  rewriter.create<gpu::LaunchFuncOp>(
      launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}

static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

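  // Compute the size in bytes with the "GEP from a null pointer, then
  // ptrtoint" idiom: &((ElemTy *)nullptr)[numElements] is the byte size.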
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  // Ints and floats of 16 or 32 bit width are allowed.
  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}

LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track the use of the handle and lower it to a cusparse
  // or cusparseLt call accordingly. If both cusparse and cusparseLt are used
  // in the same block, two separate creation ops are required for the logic
  // to be correct. In the future, we may add support for using a single
  // handle with both cusparse and cusparseLt in the sparse tensor / GPU
  // dialects. Use the cusparseLt create call if the dnmat is used with a
  // spmat with 2:4 sparsity.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(11032));
      handle = rewriter.create<LLVM::AllocaOp>(
          loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
      handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
  SmallVector<Value, 4> dims;
  for (Value dim : definingOp.getDims()) {
    dims.push_back(dim);
  }
  if (dims.size() == 2) {
    // Use the cusparseLt destroy call if the dnmat is used with a spmat with
    // 2:4 sparsity.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
                                           {adaptor.getDnTensor(), stream});
    } else {
      destroyDnMatCallBuilder.create(loc, rewriter,
                                     {adaptor.getDnTensor(), stream});
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    destroyDnVecCallBuilder.create(loc, rewriter,
                                   {adaptor.getDnTensor(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooAoSCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  // CUDA runner asserts the size is 44104 bytes.
  auto handleSz = rewriter.create<LLVM::ConstantOp>(
      loc, getIndexType(), rewriter.getIndexAttr(44104));
  Value handle = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

  create2To4SpMatCallBuilder
      .create(loc, rewriter,
              {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
      .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroySpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmat())) {
    destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
                                         {adaptor.getSpmat(), stream});
  } else {
    destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize = spMVBufferSizeCallBuilder
                        .create(loc, rewriter,
                                {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                                 adaptor.getDnY(), computeType, stream})
                        .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  spMVCallBuilder.create(loc, rewriter,
                         {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                          adaptor.getDnY(), computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value bufferSize;
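  // In the 2:4 case below, the cusparseLt runtime wrapper reports three
  // buffer sizes through the alloca; as far as we understand these correspond
  // to the workspace and the compressed-matrix buffers.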
  if (is2To4Sparsity(op.getSpmatA())) {
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(3));
    auto bufferSize = rewriter.create<LLVM::AllocaOp>(
        loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
    createCuSparseLtSpMMBufferSizeBuilder
        .create(loc, rewriter,
                {bufferSize, modeA, modeB, adaptor.getSpmatA(),
                 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
                 pruneFlag, stream})
        .getResult();

    auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
    auto bufferSize1 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}
1422
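// Lowers gpu.sddmm_buffer_size to the corresponding workspace-size query
// for sampled dense-dense matrix multiplication.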
LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

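// Lowers gpu.spmm. The cusparseLt path expects exactly three workspace
// buffers (matching the three sizes reported by the buffer-size query
// above), while the plain cuSPARSE path takes a single buffer.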
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cusparseLt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

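// Registers a type conversion that maps the given GPU dialect handle type
// (async tokens, sparse matrix/tensor handles, ...) to the opaque LLVM
// pointer type, since the runtime layer represents all such handles as
// untyped pointers.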
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}

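// Lowers gpu.sddmm to a runtime SDDMM call, again passing the workspace
// buffer by its raw allocated pointer.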
LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

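// Lowers gpu.spgemm_create_descr; the returned opaque descriptor threads
// the state of a multi-step SpGEMM computation through subsequent ops.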
LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}

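// Lowers gpu.spgemm_destroy_descr, releasing the descriptor created above.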
LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

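// Lowers gpu.spgemm_work_estimation_or_compute. This mirrors cuSPARSE's
// two-phase SpGEMM protocol: a work-estimation call followed by a compute
// call, each of which may report that a larger buffer is required, hence
// the updated buffer size returned by both branches.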
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}

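// Lowers gpu.spgemm_copy, which moves the SpGEMM result held by the
// descriptor into the destination sparse matrix C.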
LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

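// Lowers gpu.spmat_get_size. Rows, cols, and nnz are returned through a
// three-slot stack buffer. Note that the GEPs stride by pointer width while
// the slots are read as i64; the two agree on the 64-bit targets this
// lowering assumes.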
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(3));
  auto buffer = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);

  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(0))});
  auto colsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(1))});
  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}

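// Lowers gpu.set_csr_pointers, rebinding the positions/coordinates/values
// buffers of an existing sparse matrix handle.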
LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

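// Lowers gpu.create_csc. The memref element types are translated into the
// runtime's index- and data-type enums so the callee can interpret the raw
// pointers.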
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

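// Lowers gpu.create_bsr. Identical in structure to the CSC case, except
// that BSR additionally carries the row/column block sizes.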
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

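// Populates the pattern set with all patterns defined above and installs
// the opaque-pointer conversions for the GPU handle types they traffic in.
// kernelBarePtrCallConv selects the bare-pointer calling convention for
// kernel launches handled by LegalizeLaunchFuncOpPattern.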
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns,
                                               bool kernelBarePtrCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
}