From 0f2d6eb87aa21e78901bcdc248361f3fd3dd49eb Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Thu, 20 Jun 2024 19:54:11 +0000
Subject: [PATCH] a lot hack but run

---
 .../lib/Conversion/GPUToNVVM/GPUToNVVM.cpp | 91 +++++++++++++++++++
 compiler/lib/Pipelines/GPU/GPUOpt.cpp      |  1 -
 compiler/lib/Pipelines/GPU/NVVMCodegen.cpp |  2 +-
 3 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp b/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
index 61f0ac02e..c3e510cee 100644
--- a/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
+++ b/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
@@ -39,6 +39,7 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -49,6 +50,9 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -64,6 +68,63 @@
 using namespace mlir::NVVM;
 
 namespace {
 
+void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
+  SymbolTableCollection symbolTableCollection;
+  // Collect all the addressOfOps to static shared memory globals.
+  SmallVector<LLVM::AddressOfOp> addressOfOps;
+  moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) {
+    // Check that the global associated with this addressOfOp has shared memory
+    // space.
+    if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3)
+      addressOfOps.push_back(addressOfOp);
+  });
+  if (addressOfOps.size() == 0)
+    return;
+  OpBuilder builder(moduleOp);
+  builder.setInsertionPoint(&moduleOp.front());
+  auto type =
+      LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), 0);
+  LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
+      moduleOp.getLoc(), type, /*isConstant=*/false, LLVM::Linkage::External,
+      "__dynamic_shared_memory__", Attribute(),
+      /*alignment=*/16, /*addr_space=*/3);
+  uint32_t numberOfBytes = 0;
+  // Replace the addressOfOps with correctly offset pointers to dynamic
+  // shared memory.
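+  // Offsets are assigned greedily in walk order: a global that was already
+  // seen reuses its cached offset; otherwise it is placed at the current end
+  // of the buffer, rounded up to the global's alignment when one is set.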
+  llvm::SmallDenseMap<LLVM::GlobalOp, uint32_t> globalMemoryOffsetMap;
+  for (auto addressOfOp : addressOfOps) {
+    uint32_t offset = 0;
+    auto globalOp = addressOfOp.getGlobal(symbolTableCollection);
+    if (globalMemoryOffsetMap.count(globalOp)) {
+      offset = globalMemoryOffsetMap[globalOp];
+    } else {
+      offset = numberOfBytes;
+      if (std::optional<uint64_t> alignment = globalOp.getAlignment()) {
+        offset = llvm::alignTo(offset, *alignment);
+      }
+      globalMemoryOffsetMap[globalOp] = offset;
+      auto thisarray = globalOp.getType();
+      DataLayout dataLayout = DataLayout::closest(addressOfOp);
+      numberOfBytes = offset + dataLayout.getTypeSizeInBits(thisarray) / 8;
+    }
+    auto loc = addressOfOp.getLoc();
+    builder.setInsertionPoint(addressOfOp);
+    LLVM::AddressOfOp globalPtr =
+        builder.create<LLVM::AddressOfOp>(loc, global);
+    Value zero = builder.create<LLVM::ConstantOp>(
+        loc, IntegerType::get(builder.getContext(), 64),
+        builder.getI64IntegerAttr(0));
+    Value offsetValue = builder.create<LLVM::ConstantOp>(
+        loc, IntegerType::get(builder.getContext(), 64),
+        builder.getI64IntegerAttr(offset));
+    Value shiftedPtr = builder.create<LLVM::GEPOp>(
+        loc, globalPtr.getType(), global.getGlobalType(), globalPtr,
+        ValueRange({zero, offsetValue}));
+    addressOfOp.replaceAllUsesWith(shiftedPtr);
+    addressOfOp.erase();
+  }
+}
+
 template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
@@ -253,6 +314,30 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase<GPUToNVVMExtPass> {
     // Apply in-dialect lowering. In-dialect lowering will replace
     // ops which need to be lowered further, which is not supported by a
     // single conversion pass.
+    // Run Vector -> Vector transformations ahead of conversion to LLVM.
+    {
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorContractLoweringPatterns(
+          patterns,
+          vector::VectorTransformsOptions().setVectorTransformsOptions(
+              vector::VectorContractLowering::OuterProduct));
+      vector::populateVectorMaskOpLoweringPatterns(patterns);
+      // We currently always use 64-bit indices, so ensure the bit width of
+      // the mask compare is consistent.
+      vector::populateVectorMaskMaterializationPatterns(
+          patterns, /*force32BitVectorIndices=*/false);
+      vector::populateVectorShapeCastLoweringPatterns(patterns);
+      // TODO: doubtful that the "default" does what one wants here; it is likely
+      // better to use something else.
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, vector::VectorTransformsOptions());
+      vector::populateVectorTransferLoweringPatterns(patterns);
+      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
     {
       RewritePatternSet patterns(m.getContext());
       populateGpuRewritePatterns(patterns);
@@ -289,13 +374,19 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase<GPUToNVVMExtPass> {
     converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
       return convertMMAToLLVMType(type);
     });
+    // Convert dummy tokens.
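+    // nvgpu.device_async_token values carry no payload once lowered, so map
+    // them to a placeholder i32 during type conversion.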
+    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
+      return converter.convertType(IntegerType::get(type.getContext(), 32));
+    });
 
     RewritePatternSet llvmPatterns(m.getContext());
     arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
+    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
     populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
     populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+    populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns);
     populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
 #if 0
     // FIXME: enable if gpu arch >= sm_75
diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
index 70e47bfbe..3426f2350 100644
--- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp
+++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
@@ -127,7 +127,6 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
 }
 
 void createGemmGPUOptPipelineImpl(OpPassManager &pm) {
-  // TODO(YangXinyu): Get workgroup size from config!
   GPUMappingForallOptions options;
   options.funcAnchor = getByteIRMatmulEpilogueFusionAttrName().str();
   options.annotatePrefix = "__byteir_gpu_gemm_tile";
diff --git a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp
index b791546b8..1b1f904d3 100644
--- a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp
+++ b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp
@@ -53,7 +53,7 @@ void createNVVMCodegenPipelineImpl(OpPassManager &pm,
   pm.addPass(createSimplifyLinearizedIndexPass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
-  pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToLLVMPass());
+  // pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToLLVMPass());
   pm.addNestedPass<gpu::GPUModuleOp>(createGPUToNVVMExtPass(
       useBarePtrCallConv, mlir::kDeriveIndexBitwidthFromDataLayout, gpuArch));
   pm.addPass(createCSEPass());