Skip to content

Commit

Permalink
Handle dynamic shared memory
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Jun 23, 2024
1 parent 0f2d6eb commit 298432c
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 20 deletions.
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/GPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef BYTEIR_DIALECT_GPU_PASSES_H
#define BYTEIR_DIALECT_GPU_PASSES_H

#include "byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h"
#include "byteir/Dialect/GPU/Transforms/GPUBlockSwizzle.h"
#include "byteir/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.h"
#include "byteir/Dialect/GPU/Transforms/GPUDistributeToWarp.h"
Expand Down
8 changes: 8 additions & 0 deletions compiler/include/byteir/Dialect/GPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,12 @@ def GPUVectorToGPU : Pass<"gpu-vector-to-gpu", "func::FuncOp"> {
"nvgpu::NVGPUDialect",
];
}

//===----------------------------------------------------------------------===//
// LegalizeGPULaunch
//===----------------------------------------------------------------------===//
def LegalizeGPULaunch : Pass<"legalize-gpu-launch", "func::FuncOp"> {
let summary = "Legalize GPU launch ops.";
let constructor = "mlir::createLegalizeGPULaunchPass()";
}
#endif // BYTEIR_DIALECT_GPU_PASSES
34 changes: 34 additions & 0 deletions compiler/include/byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- LegalizeGPULaunch.h ---------------------------------*--- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H
#define BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H

#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringRef.h"
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeGPULaunchPass();

} // namespace mlir

#endif // BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H
8 changes: 8 additions & 0 deletions compiler/lib/Conversion/FuncToByre/FuncToByre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ class ConvertGPULaunchFuncToByrePattern
computeOp->setAttr("BlockSize.y", rewriter.getI32IntegerAttr(by));
computeOp->setAttr("BlockSize.z", rewriter.getI32IntegerAttr(bz));

auto sharedMemorySize = launchOp.getDynamicSharedMemorySize();
if (sharedMemorySize) {
auto sharedMemorySizeValue =
cast<arith::ConstantOp>(sharedMemorySize.getDefiningOp());
IntegerAttr smem = cast<IntegerAttr>(sharedMemorySizeValue.getValue());
computeOp->setAttr("DynamicSharedMemorySize", smem);
}

if (useBarePtrCallConv) {
computeOp->setAttr(byre::getKernelCallConventionAttrName(),
rewriter.getStringAttr("bare_ptr"));
Expand Down
36 changes: 24 additions & 12 deletions compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ using namespace mlir::NVVM;

namespace {

void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
static void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
SymbolTableCollection symbolTableCollection;
// Collect all the adressOfOps to static shared memory globals.
// Collect all the addressOfOps to static shared memory globals.
SmallVector<LLVM::AddressOfOp> addressOfOps;
moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) {
// Check that the global associated with this addressOfOp has shared memory
Expand All @@ -80,17 +80,8 @@ void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
});
if (addressOfOps.size() == 0)
return;
OpBuilder builder(moduleOp);
builder.setInsertionPoint(&moduleOp.front());
auto type =
LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), 0);
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
moduleOp.getLoc(), type, /*isConstant=*/false, LLVM::Linkage::External,
"__dynamic_shared_memory__", Attribute(),
/*alignment=*/16, /*addr_space=*/3);

uint32_t numberOfBytes = 0;
// Replace the addressOfOps with correctly offseted pointers to dynamic
// shared memory.
llvm::SmallDenseMap<LLVM::GlobalOp, uint32_t> globalMemoryOffsetMap;
for (auto addressOfOp : addressOfOps) {
uint32_t offset = 0;
Expand All @@ -107,6 +98,26 @@ void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) {
DataLayout dataLayout = DataLayout::closest(addressOfOp);
numberOfBytes = offset + dataLayout.getTypeSizeInBits(thisarray) / 8;
}
}

// Check if numberOfBytes is less than 48 * 1024
if (numberOfBytes < 48 * 1024) {
return;
}

OpBuilder builder(moduleOp);
builder.setInsertionPoint(&moduleOp.front());
auto type =
LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), 0);
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
moduleOp.getLoc(), type, /*isConstant=*/false, LLVM::Linkage::External,
"__dynamic_shared_memory__", Attribute(),
/*alignment=*/16, /*addr_space=*/3);

// Replace the addressOfOps with correctly offseted pointers to dynamic
// shared memory.
for (auto addressOfOp : addressOfOps) {
uint32_t offset = globalMemoryOffsetMap[addressOfOp.getGlobal(symbolTableCollection)];
auto loc = addressOfOp.getLoc();
builder.setInsertionPoint(addressOfOp);
LLVM::AddressOfOp globalPtr =
Expand Down Expand Up @@ -416,6 +427,7 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase<GPUToNVVMExtPass> {
}
}
});
ConvertToDynamicSharedMemory(m);
}
};

Expand Down
1 change: 1 addition & 0 deletions compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(ByteIRGPUPasses
LegalizeGPULaunch.cpp
GPUBlockSwizzle.cpp
GPUDistributeSharedMemoryCopy.cpp
GPUDistributeToWarp.cpp
Expand Down
77 changes: 77 additions & 0 deletions compiler/lib/Dialect/GPU/Transforms/LegalizeGPULaunch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===- LegalizeGPULaunch.cpp-*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h"
#include "byteir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Visitors.h"
#include <string>

#include "PassDetail.h"

using namespace llvm;
using namespace mlir;

namespace {

static int64_t getSharedMemorySizeInGPULaunch(gpu::LaunchOp op) {
int64_t sharedMemSizeInBytes = 0;
op->walk([&](memref::AllocaOp allocaOp) {
sharedMemSizeInBytes +=
allocaOp.getType().getNumElements() *
allocaOp.getType().getElementType().getIntOrFloatBitWidth() / 8;
});
op->walk([&](memref::AllocOp allocOp) {
sharedMemSizeInBytes +=
allocOp.getType().getNumElements() *
allocOp.getType().getElementType().getIntOrFloatBitWidth() / 8;
});
return sharedMemSizeInBytes;
}

struct LegalizeGPULaunchPass
: public LegalizeGPULaunchBase<LegalizeGPULaunchPass> {
LegalizeGPULaunchPass() : LegalizeGPULaunchBase<LegalizeGPULaunchPass>() {}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
OpBuilder builder(funcOp.getContext());
auto launchOps = funcOp.getOps<gpu::LaunchOp>();
for (auto launchOp : launchOps) {
int64_t sharedMemSize = getSharedMemorySizeInGPULaunch(launchOp);
if (sharedMemSize < 48 * 1024) // 48kB
continue;
builder.setInsertionPoint(launchOp);
Value sharedMemSizeValue = builder.create<arith::ConstantOp>(
launchOp.getLoc(), builder.getI32IntegerAttr(sharedMemSize));
if (!launchOp.getDynamicSharedMemorySizeMutable().empty()) {
continue;
}
launchOp.getDynamicSharedMemorySizeMutable().append(
ValueRange{sharedMemSizeValue});
}
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLegalizeGPULaunchPass() {
return std::make_unique<LegalizeGPULaunchPass>();
}
18 changes: 10 additions & 8 deletions compiler/lib/Pipelines/GPU/GPUOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) {

createGPUMappingForallTransform(pm, options);
pm.addPass(createTransformDialectInterpreter(true));
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createGpuLauchSinkIndexComputationsPass());

{
OpPassManager anchoredPM(func::FuncOp::getOperationName());

anchoredPM.addPass(createCSEPass());
anchoredPM.addPass(createCanonicalizerPass());
anchoredPM.addPass(createGpuLauchSinkIndexComputationsPass());
anchoredPM.addPass(createPromoteBuffersToStackPass(
/*isSmallAlloc =*/[](Value value) {
return value.getParentRegion()->getParentOfType<gpu::LaunchOp>();
Expand All @@ -132,13 +132,13 @@ void createGemmGPUOptPipelineImpl(OpPassManager &pm) {
options.annotatePrefix = "__byteir_gpu_gemm_tile";
createGPUMappingForallTransform(pm, options);
pm.addPass(createTransformDialectInterpreter(true));
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createGpuLauchSinkIndexComputationsPass());

{
OpPassManager anchoredPM(func::FuncOp::getOperationName());

anchoredPM.addPass(createCSEPass());
anchoredPM.addPass(createCanonicalizerPass());
anchoredPM.addPass(createGpuLauchSinkIndexComputationsPass());

anchoredPM.addPass(createPromoteBuffersToStackPass(
/*isSmallAlloc =*/[](Value value) {
return value.getParentRegion()->getParentOfType<gpu::LaunchOp>();
Expand All @@ -147,14 +147,16 @@ void createGemmGPUOptPipelineImpl(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createAnchoredPipelinePass(
getByteIRMatmulEpilogueFusionAttrName(), anchoredPM));
}
pm.addPass(createGpuKernelOutliningPass());
{
OpPassManager anchoredPM(func::FuncOp::getOperationName());

anchoredPM.addPass(createLegalizeGPULaunchPass());
// anchoredPM.addPass(createSetSharedMemorySizePass());

pm.addNestedPass<func::FuncOp>(createAnchoredPipelinePass(
getByteIRMatmulEpilogueFusionAttrName(), anchoredPM));
}
pm.addPass(createGpuKernelOutliningPass());
}

void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv,
Expand Down

0 comments on commit 298432c

Please sign in to comment.