Skip to content

Commit

Permalink
add gpu opt
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Jun 20, 2024
1 parent 8343894 commit 753e771
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ void packSharedMemoryAlloc(scf::ForallOp forallOp) {
allocs.push_back(allocOp);
}
});
llvm::errs() << "Found " << allocs.size() << " shared memory allocations\n";
// First sink the alloc as low as possible in the CFG.
sinkOpsInCFG(allocs, dominators);
SmallVector<AliasGroup> aliasGroups;
Expand Down
33 changes: 33 additions & 0 deletions compiler/lib/Pipelines/GPU/GPUOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,43 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
pm.addPass(createGpuKernelOutliningPass());
}

void createGemmGPUOptPipelineImpl(OpPassManager &pm) {
// TODO(YangXinyu): Get workgroup size from config!
GPUMappingForallOptions options;
options.funcAnchor = getByteIRMatmulEpilogueFusionAttrName().str();
options.annotatePrefix = "__byteir_gpu_gemm_tile";
createGPUMappingForallTransform(pm, options);
pm.addPass(createTransformDialectInterpreter(true));
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createGpuLauchSinkIndexComputationsPass());

{
OpPassManager anchoredPM(func::FuncOp::getOperationName());

anchoredPM.addPass(createPromoteBuffersToStackPass(
/*isSmallAlloc =*/[](Value value) {
return value.getParentRegion()->getParentOfType<gpu::LaunchOp>();
}));

pm.addNestedPass<func::FuncOp>(createAnchoredPipelinePass(
getByteIRMatmulEpilogueFusionAttrName(), anchoredPM));
}
pm.addPass(createGpuKernelOutliningPass());
{
OpPassManager anchoredPM(func::FuncOp::getOperationName());
// anchoredPM.addPass(createSetSharedMemorySizePass());

pm.addNestedPass<func::FuncOp>(createAnchoredPipelinePass(
getByteIRMatmulEpilogueFusionAttrName(), anchoredPM));
}
}

void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv,
const std::string &target) {
createElementwiseGPUOptPipelineImpl(pm, useBarePtrCallConv, target);
createReductionGPUOptPipelineImpl(pm);
createGemmGPUOptPipelineImpl(pm);
pm.addPass(createCollectGPUKernelPass("unified", false));
}

Expand Down

0 comments on commit 753e771

Please sign in to comment.