
Commit

a lot of hacks, but it runs
yangxinyu committed Jun 20, 2024
1 parent 753e771 commit 0f2d6eb
Showing 3 changed files with 92 additions and 2 deletions.
91 changes: 91 additions & 0 deletions compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
@@ -39,6 +39,7 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -49,6 +50,9 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -64,6 +68,63 @@ using namespace mlir::NVVM;

namespace {

void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
  SymbolTableCollection symbolTableCollection;
  // Collect all the addressOfOps that point to static shared memory globals.
  SmallVector<LLVM::AddressOfOp> addressOfOps;
  moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) {
    // Check that the global associated with this addressOfOp has shared memory
    // space.
    if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3)
      addressOfOps.push_back(addressOfOp);
  });
  if (addressOfOps.empty())
    return;
  OpBuilder builder(moduleOp);
  builder.setInsertionPoint(&moduleOp.front());
  auto type =
      LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), 0);
  LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
      moduleOp.getLoc(), type, /*isConstant=*/false, LLVM::Linkage::External,
      "__dynamic_shared_memory__", Attribute(),
      /*alignment=*/16, /*addr_space=*/3);
  uint32_t numberOfBytes = 0;
  // Replace the addressOfOps with correctly offset pointers into dynamic
  // shared memory.
  llvm::SmallDenseMap<LLVM::GlobalOp, uint32_t> globalMemoryOffsetMap;
  for (auto addressOfOp : addressOfOps) {
    uint32_t offset = 0;
    auto globalOp = addressOfOp.getGlobal(symbolTableCollection);
    if (globalMemoryOffsetMap.count(globalOp)) {
      offset = globalMemoryOffsetMap[globalOp];
    } else {
      offset = numberOfBytes;
      if (std::optional<uint64_t> alignment = globalOp.getAlignment()) {
        offset = llvm::alignTo(offset, *alignment);
      }
      globalMemoryOffsetMap[globalOp] = offset;
      auto arrayType = globalOp.getType();
      DataLayout dataLayout = DataLayout::closest(addressOfOp);
      numberOfBytes = offset + dataLayout.getTypeSizeInBits(arrayType) / 8;
    }
    auto loc = addressOfOp.getLoc();
    builder.setInsertionPoint(addressOfOp);
    LLVM::AddressOfOp globalPtr =
        builder.create<LLVM::AddressOfOp>(loc, global);
    Value zero = builder.create<LLVM::ConstantOp>(
        loc, IntegerType::get(builder.getContext(), 64),
        builder.getI64IntegerAttr(0));
    Value offsetValue = builder.create<LLVM::ConstantOp>(
        loc, IntegerType::get(builder.getContext(), 64),
        builder.getI64IntegerAttr(offset));
    Value shiftedPtr = builder.create<LLVM::GEPOp>(
        loc, globalPtr.getType(), global.getGlobalType(), globalPtr,
        ValueRange({zero, offsetValue}));
    addressOfOp.replaceAllUsesWith(shiftedPtr);
    addressOfOp.erase();
  }
}
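
In effect, the helper folds every static workgroup-memory global into one dynamically sized buffer. A rough IR-level sketch of the rewrite (global names, sizes, and exact attribute/printing syntax below are illustrative only, not taken from this repository):

  // Before: two static shared-memory (addrspace 3) globals and their addressof ops.
  llvm.mlir.global private @smem_a() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<1024 x i8>
  llvm.mlir.global private @smem_b() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<256 x i8>
  %a = llvm.mlir.addressof @smem_a : !llvm.ptr<3>
  %b = llvm.mlir.addressof @smem_b : !llvm.ptr<3>

  // After: one external zero-sized byte array; each addressof becomes a byte-offset
  // GEP into it (@smem_a at offset 0, @smem_b at offset 1024, 1280 bytes in total,
  // which would have to be provided as dynamic shared memory at kernel launch).
  llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
  %base = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
  %c0 = llvm.mlir.constant(0 : i64) : i64
  %off = llvm.mlir.constant(1024 : i64) : i64
  %b = llvm.getelementptr %base[%c0, %off] : (!llvm.ptr<3>, i64, i64) -> !llvm.ptr<3>, !llvm.array<0 x i8>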

template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
@@ -253,6 +314,30 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase<GPUToNVVMExtPass> {
    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    // Run Vector -> Vector transformations ahead of conversion to LLVM.
    {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorToVectorCanonicalizationPatterns(patterns);
      vector::populateVectorBroadcastLoweringPatterns(patterns);
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::OuterProduct));
      vector::populateVectorMaskOpLoweringPatterns(patterns);
      // We currently always use 64-bit indices, thus ensure the bit width of
      // the mask compare is consistent.
      vector::populateVectorMaskMaterializationPatterns(
          patterns, /*force32BitVectorIndices=*/false);
      vector::populateVectorShapeCastLoweringPatterns(patterns);
      // TODO: doubtful that the "default" does what one wants here; it is
      // likely better to use something else.
      vector::populateVectorTransposeLoweringPatterns(
          patterns, vector::VectorTransformsOptions());
      vector::populateVectorTransferLoweringPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
@@ -289,13 +374,19 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase<GPUToNVVMExtPass> {
    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
      return convertMMAToLLVMType(type);
    });
    // Convert dummy tokens.
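    // (Device-side async tokens cannot be materialized in NVVM; they are mapped
    // to a dummy i32 so they can simply be dropped during conversion.)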
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
#if 0
    // FIXME: enable if gpu arch >= sm_75
1 change: 0 additions & 1 deletion compiler/lib/Pipelines/GPU/GPUOpt.cpp
@@ -127,7 +127,6 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
}

void createGemmGPUOptPipelineImpl(OpPassManager &pm) {
  // TODO(YangXinyu): Get workgroup size from config!
  GPUMappingForallOptions options;
  options.funcAnchor = getByteIRMatmulEpilogueFusionAttrName().str();
  options.annotatePrefix = "__byteir_gpu_gemm_tile";
2 changes: 1 addition & 1 deletion compiler/lib/Pipelines/GPU/NVVMCodegen.cpp
@@ -53,7 +53,7 @@ void createNVVMCodegenPipelineImpl(OpPassManager &pm,
  pm.addPass(createSimplifyLinearizedIndexPass());
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createCSEPass());
  pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToLLVMPass());
  // pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToLLVMPass());
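  // (Vector-to-LLVM lowering is instead handled inside GPUToNVVMExtPass, which
  // now runs the vector lowering patterns and
  // populateVectorToLLVMConversionPatterns itself.)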
  pm.addNestedPass<gpu::GPUModuleOp>(createGPUToNVVMExtPass(
      useBarePtrCallConv, mlir::kDeriveIndexBitwidthFromDataLayout, gpuArch));
  pm.addPass(createCSEPass());
