Rewrite Index Tree Dialect #79

Draft
wants to merge 77 commits into base: master

Commits (77)
4dac7bc
Begin working on IndexTree transformations
AK2000 Nov 1, 2023
20f916a
Fixing some of the problems introduced on merge
AK2000 Nov 1, 2023
864c9f8
Resolved including of device mapping attribute
AK2000 Nov 1, 2023
0447561
Fixed type inclusion, parsing and printing
AK2000 Nov 1, 2023
cf505bd
V1 - Lower TA to new IndexTree ops, but removed everything else
AK2000 Nov 6, 2023
3f4e849
Fixes to TA to change how file is included
AK2000 Nov 6, 2023
d6bc7fd
Creating new block for index tree
AK2000 Nov 8, 2023
48fe5fd
Implement domain inference pass, fix to index ordering
AK2000 Nov 10, 2023
2d6ba90
[WIP] Fragile version of index tree to SCF lowering
AK2000 Nov 30, 2023
972ee7e
Fix carrying tensors inside loop, refactor domain concretization
AK2000 Dec 12, 2023
c29ee1f
Adding TA to index tree patterns for elementwise operations
AK2000 Dec 13, 2023
8d44653
[WIP] Trying to implement intersection op lowering
AK2000 Dec 19, 2023
f25fb52
[WIP] Got domain intersection working, but only with dense output
AK2000 Jan 3, 2024
35cf1c5
[WIP] Minor fix to ordering of reduce args
AK2000 Jan 3, 2024
04b31ba
[WIP] Beginning support for sparse output tensors with new index tree
AK2000 Jan 10, 2024
1c06a26
[WIP] Inlined itree op, got hacky version of removing set op working
AK2000 Jan 11, 2024
00801b1
[WIP] Included lowering to LLVM, lowering print op does not work
AK2000 Jan 12, 2024
352da6a
[WIP] Almost got print op lowering working
AK2000 Jan 15, 2024
9436939
[WIP] Fixed bufferization
AK2000 Jan 18, 2024
ac98adf
[WIP] Generate symbolic pass for sparse tensor declarations
AK2000 Jan 24, 2024
935079d
[WIP] Lots of changes for first try at symbolic domain pass and works…
AK2000 Feb 13, 2024
0c2d98d
[WIP] Broke everything trying to redo tensor conversion infrastructure
AK2000 Feb 21, 2024
bb32d70
Changing a lot to create new sparse tensor types, and appropriate lowe…
AK2000 May 15, 2024
ac27cfb
Fixing some problems with tests, add pure ops
AK2000 Jun 5, 2024
52495d4
Fixed inconsistencies in test suite
AK2000 Jun 5, 2024
27dd373
Fixing more of the test cases
AK2000 Jun 13, 2024
8bc6187
Fixed dense transpose and print elapsed time
AK2000 Jun 17, 2024
3f777dd
Fixing errors in typing
AK2000 Jun 19, 2024
8cb0d2a
Fixing errors in typing and set op
AK2000 Jun 19, 2024
dc0e445
Adding back ttgt pass
AK2000 Jun 20, 2024
ffe69f2
Fixed delete before use errors
AK2000 Oct 14, 2024
1a7eb01
Another bug found with asan
AK2000 Oct 14, 2024
c32a1cd
Rebased pull request with master
pthomadakis Oct 15, 2024
2153358
1) Fixed minor bugs in IndexTreeToSCF conversion
pthomadakis Oct 16, 2024
e0c95d7
Fixed another use-after-erase bug
pthomadakis Oct 17, 2024
ebc2efc
Fixed some more bugs coming from merging
pthomadakis Oct 17, 2024
5ac3316
Enforce tensor semantics on scalar op
AK2000 Oct 17, 2024
cdfa968
Added semiring operations to work correct assumption of multiplicativ…
AK2000 Oct 29, 2024
83660c7
Merge pull request #67 from AK2000/fix-scalar-op
AK2000 Oct 29, 2024
2705351
Merge pull request #68 from AK2000/fix-semiring-tests
AK2000 Oct 29, 2024
aad9afc
[WIP] Enabling function support
pthomadakis Oct 25, 2024
870b0e6
[WIP] Enabling function support. Dense transpose compound expressions…
pthomadakis Oct 29, 2024
8dd185f
[WIP] Enabling function support. Made transpose operation a little mo…
pthomadakis Oct 29, 2024
27eacee
Refactored SpConstructOp
pthomadakis Oct 30, 2024
cc73ecd
Fixed func.return when returning SparseTensorType
pthomadakis Oct 30, 2024
aaca5f9
Changed ta.sum to return a scalar instead of a memref.
pthomadakis Nov 1, 2024
e565571
Decouple semiring attribute from domain inference
AK2000 Nov 5, 2024
5b66b30
"fixed" triangle counting test :)
AK2000 Nov 5, 2024
6efa6e3
Merge pull request #72 from AK2000/fix-semiring-tests
gkestor Nov 6, 2024
e4a987a
Merge pull request #73 from AK2000/fix-triangle-counting
gkestor Nov 6, 2024
880c9cc
[COMETPY] Fixed issue with COO formats
pthomadakis Nov 7, 2024
de6469d
[COMETPY] Update tests to make use of pytest and removed custom test …
pthomadakis Nov 7, 2024
5cfd397
[COMETPY] Added data directory to tests
pthomadakis Nov 7, 2024
bfd5993
[COMETPY] Added missing file
pthomadakis Nov 7, 2024
f4fe43a
Added support for f32 data #69
pthomadakis Nov 1, 2024
836f8ce
Updated numpy-scipy side to support f32 #69
pthomadakis Nov 2, 2024
664a7df
Updated SparseTensor type to use the respective enum instead of raw n…
pthomadakis Nov 7, 2024
878c711
Sparse tensor indices crd(_tile),pos(_tile) now have an explicit inte…
pthomadakis Nov 9, 2024
89ee8c4
[COMETPY][WIP] Enabling i32 indices for sparse matrices
pthomadakis Nov 10, 2024
001cf29
Fixed bugs in Index to Integer casts and vice-versa
pthomadakis Nov 10, 2024
6aacaa5
[COMETPY] Finished support for i32,i64 sparse matrix indices.
pthomadakis Nov 10, 2024
978b77c
[WIP] reviewing new index tree dialect and related test cases (added …
gkestor Oct 25, 2024
ac744a9
blis_interface is updated - debug prints removed
gkestor Oct 29, 2024
304f20e
Bug is fixed in the dense transpose optimization. Testcases are organ…
gkestor Nov 13, 2024
4433b75
Fixed path in CMakeLists.txt
pthomadakis Nov 15, 2024
33621c7
Changing how index tree to SCF conversion pass is structured
AK2000 Aug 19, 2024
5557cb3
[WIP] adding semiring zero
AK2000 Nov 7, 2024
4176e61
[WIP] Symbolic pass and workspace transform not working
AK2000 Nov 11, 2024
814240f
Fixed issues with workspace transformation and symbolic pass
AK2000 Nov 11, 2024
5112041
Resolved conflicts with restructure SCF conversion
pthomadakis Nov 15, 2024
a25befd
[COMETPY] Updated sparse tensor representation
pthomadakis Nov 17, 2024
4253d33
[WIP] Got masking partially working but broke other things
AK2000 Nov 18, 2024
b93b044
[WIP] Fixed other problems, but still issues with Sandia_LL with masking
AK2000 Nov 18, 2024
d3551f4
Helped compilation for SandiaLL w Masking
AK2000 Nov 18, 2024
28f5cf7
Merge pull request #76 from AK2000/rewrite_masking
AK2000 Nov 20, 2024
b0fa0ae
Rearranging comet.cpp to emit it at a different stage
AK2000 Dec 3, 2024
3ed0cf5
Initial implementation of parallel loops (#77)
AK2000 Dec 10, 2024
10 changes: 9 additions & 1 deletion CMakeLists.txt
@@ -107,6 +107,8 @@ add_custom_target(comet-headers)
set_target_properties(comet-headers PROPERTIES FOLDER "Misc")
add_custom_target(comet-doc)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

# Add MLIR, LLVM and BLIS headers to the include path
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
@@ -141,7 +143,7 @@ endif()
add_subdirectory(include/comet)
add_subdirectory(lib)
add_subdirectory(frontends/comet_dsl)
add_subdirectory(integration_test)
add_subdirectory(test/integration)


option(COMET_INCLUDE_DOCS "Generate build targets for the COMET docs.")
@@ -202,3 +204,9 @@ if (STANDALONE_INSTALL)
message(STATUS "Setting an $ORIGIN-based RPATH on all executables")
set_rpath_all_targets(${CMAKE_CURRENT_SOURCE_DIR})
endif()

option(DEBUG_MODE "Create an installation with debug information" off)
if (DEBUG_MODE)
message(STATUS "Building comet in debug mode")
add_compile_options(-DCOMET_DEBUG_MODE)
endif()
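
Note on the DEBUG_MODE option above: it only adds a -DCOMET_DEBUG_MODE compile definition and does not change CMAKE_BUILD_TYPE. Below is a minimal sketch of how code might consume that definition; the COMET_DEBUG macro and the example program are hypothetical illustrations, not part of this PR.

// Minimal sketch (hypothetical, not from this PR): gating debug diagnostics
// on the COMET_DEBUG_MODE definition that the DEBUG_MODE option adds.
#include <iostream>

#ifdef COMET_DEBUG_MODE
#define COMET_DEBUG(msg) (std::cerr << "[comet-debug] " << (msg) << "\n")
#else
#define COMET_DEBUG(msg) ((void)0)
#endif

int main() {
  COMET_DEBUG("extra diagnostics are enabled in this build");
  return 0;
}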
18 changes: 9 additions & 9 deletions frontends/comet_dsl/CMakeLists.txt
@@ -22,15 +22,15 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
llvm_update_compile_flags(comet-opt)

set(LIBS
MLIRAnalysis
MLIRIR
MLIRParser
MLIRPass
MLIRTransforms
COMETUtils
COMETTensorAlgebraDialect
COMETIndexTreeDialect
COMETIndexTreeToSCF
MLIRAnalysis
MLIRIR
MLIRParser
MLIRPass
MLIRTransforms
COMETUtils
COMETTensorAlgebraDialect
COMETIndexTreeDialect
# COMETIndexTreeToSCF
)

if(ENABLE_GPU_TARGET)
112 changes: 65 additions & 47 deletions frontends/comet_dsl/comet.cpp
@@ -44,6 +44,9 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"


#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -339,6 +342,14 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();

/// Check to see if we are dumping to TA dialect.
if (emitTA)
{
if (mlir::failed(pm.run(*module)))
return 4;
return 0;
}

/// =============================================================================
/// High-level optimization at the TA dialect
/// Such as finding the optimal ordering of dense tensor contractions, or reformulating tensor contractions
@@ -375,25 +386,14 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
/// Generate the index tree IR
optPM.addPass(mlir::comet::createLowerTensorAlgebraToIndexTreePass(CodegenTarget));

if (OptKernelFusion)
{
/// Apply partial fusion on index tree dialect for some compound expressions.
optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass());
}
// Create new pass manager to optimize the index tree dialect
optPM.addPass(mlir::comet::createIndexTreeDomainInferencePass());

if (OptWorkspace)
{
/// Optimized workspace transformations, reduce iteration space for nonzero elements
optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass());
}

/// Dump index tree dialect.
if (emitIT)
{
if (mlir::failed(pm.run(*module)))
return 4;
return 0;
}
// if (OptKernelFusion)
// {
// /// Apply partial fusion on index tree dialect for some compound expressions.
// optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass());
// }
}

/// =============================================================================
Expand All @@ -408,7 +408,10 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
/// input and output sparse tensor declaration lowering are distant and need different information
optPM.addPass(mlir::comet::createSparseTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createTensorFillLoweringPass());
optPM.addPass(mlir::comet::createDimOpLoweringPass());

/// =============================================================================

@@ -419,9 +422,9 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
optPM.addPass(mlir::comet::createLoweringTTGTPass(IsSelectBestPermTTGT, selectedPermNum, IsPrintFlops));
}

/// =============================================================================
/// Operation based optimizations
/// =============================================================================
// /// =============================================================================
// /// Operation based optimizations
// /// =============================================================================
if (OptMatmulTiling)
{
optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass());
@@ -435,34 +438,39 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
/// =============================================================================
/// Lowering all the operations to loops
/// =============================================================================
if (IsLoweringtoSCF || emitLoops || emitTriton_ || emitLLVM )
{

/// Workspace transformations will create new dense tensor declarations, so we need to call createDenseTensorDeclLoweringPass
optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass()); /// lowers dense input/output tensor declaration
optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass()); /// Temporary sparse output tensor declarations introduced by compound expressions
/// should be lowered before sparse output tensor declarations
optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass()); /// lowering for sparse output tensor declarations
//(sparse_output_tensor_decl and temp_sparse_output_tensor_decl)

optPM.addPass(mlir::comet::createDimOpLoweringPass());

/// The partial Fusion pass might add new tensor.fill operations
optPM.addPass(mlir::comet::createTensorFillLoweringPass());
optPM.addPass(mlir::comet::createPCToLoopsLoweringPass());

if (IsLoweringtoSCF || emitLoops || emitLLVM)
{
/// =============================================================================
/// Lowering of other operations such as transpose, sum, etc. to SCF dialect
/// =============================================================================
/// If it is a transpose of dense tensor, the rewrites rules replaces ta.transpose with linalg.copy.
/// If it is a transpose of sparse tensor, it lowers the code to make a runtime call to specific sorting algorithm
optPM.addPass(mlir::comet::createLowerTensorAlgebraToSCFPass());

/// Concretize the domains of all the index variables
optPM.addPass(mlir::comet::createIndexTreeDomainConcretizationPass());

if (OptWorkspace) {
/// Optimized workspace transformations, reduce iteration space for nonzero elements
optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass());
}

optPM.addPass(mlir::comet::createIndexTreeSymbolicComputePass());

/// Dump index tree dialect.
if (emitIT)
{
if (mlir::failed(pm.run(*module)))
return 4;
return 0;
}

/// Finally lowering index tree to SCF dialect
optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
optPM.addPass(mlir::tensor::createTensorBufferizePass());
pm.addPass(mlir::func::createFuncBufferizePass()); /// Needed for func
pm.addPass(mlir::createConvertLinalgToLoopsPass());
optPM.addPass(mlir::comet::createConvertSymbolicDomainsPass());
optPM.addPass(mlir::comet::createSparseTensorConversionPass());
optPM.addPass(mlir::comet::createIndexTreeInliningPass());
optPM.addPass(mlir::createCanonicalizerPass());

if (OptDenseTransposeOp) /// Optimize Dense Transpose operation
{
@@ -487,14 +495,23 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,



/// =============================================================================
/// Late lowering passes
/// =============================================================================
// /// =============================================================================
// /// Late lowering passes
// /// =============================================================================
// pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass());
pm.addPass(mlir::comet::createTABufferizeFunc());
pm.addPass(mlir::createCanonicalizerPass());

mlir::bufferization::OneShotBufferizationOptions opts;
opts.allowUnknownOps = true;
pm.addPass(mlir::bufferization::createOneShotBufferizePass(opts));

optPM.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
optPM.addPass(mlir::comet::createLateLoweringPass());
// pm.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
mlir::OpPassManager &late_lowering_pm = pm.nest<mlir::func::FuncOp>();
late_lowering_pm.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
late_lowering_pm.addPass(mlir::comet::createLateLoweringPass());

pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());

#ifdef ENABLE_GPU_TARGET
if (CodegenTarget == TargetDevice::GPU && (emitTriton_ || emitLLVM || IsLoweringtoTriton))
@@ -616,6 +633,7 @@ int main(int argc, char **argv)
context.loadDialect<mlir::linalg::LinalgDialect>();
context.loadDialect<mlir::scf::SCFDialect>();
context.loadDialect<mlir::bufferization::BufferizationDialect>();
context.loadDialect<mlir::index::IndexDialect>();

mlir::OwningOpRef<mlir::ModuleOp> module;

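
For orientation, the comet.cpp hunks above restructure the lowering flow so that index-tree domain inference, domain concretization, the workspace transformation, and the symbolic-domain passes all run before the index tree is lowered to SCF, and module-level bufferization now goes through one-shot bufferize. The sketch below condenses that ordering into one helper; it is a reading of the diff rather than verbatim code, the COMET pass declarations are assumed to come from headers comet.cpp already includes (paths elided), and option handling (emitTA/emitIT dumps, TTGT, tiling, GPU paths) is omitted.

// Condensed sketch of the pass ordering introduced by this PR (assumption:
// derived from the diff above, not verbatim from comet.cpp).
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
// COMET pass headers as included by comet.cpp (exact paths assumed, elided here).

static void buildIndexTreePipeline(mlir::PassManager &pm, bool optWorkspace) {
  mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();

  // After TA is lowered to the index tree dialect, infer and then
  // concretize the domains of all index variables.
  optPM.addPass(mlir::comet::createIndexTreeDomainInferencePass());
  optPM.addPass(mlir::comet::createIndexTreeDomainConcretizationPass());

  // Workspace transformation now runs on concretized domains.
  if (optWorkspace)
    optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass());

  // Symbolic phase, then index tree -> SCF, symbolic-domain and sparse
  // tensor conversion, and inlining of the itree op.
  optPM.addPass(mlir::comet::createIndexTreeSymbolicComputePass());
  optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
  optPM.addPass(mlir::comet::createConvertSymbolicDomainsPass());
  optPM.addPass(mlir::comet::createSparseTensorConversionPass());
  optPM.addPass(mlir::comet::createIndexTreeInliningPass());
  optPM.addPass(mlir::createCanonicalizerPass());

  // Module-level bufferization goes through one-shot bufferize.
  pm.addPass(mlir::comet::createTABufferizeFunc());
  pm.addPass(mlir::createCanonicalizerPass());
  mlir::bufferization::OneShotBufferizationOptions opts;
  opts.allowUnknownOps = true;
  pm.addPass(mlir::bufferization::createOneShotBufferizePass(opts));

  // Late lowering runs in a fresh func-level pass manager.
  mlir::OpPassManager &latePM = pm.nest<mlir::func::FuncOp>();
  latePM.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
  latePM.addPass(mlir::comet::createLateLoweringPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
}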