diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td index 17e88c1e6..b3d3e9853 100644 --- a/compiler/include/byteir/Transforms/Passes.td +++ b/compiler/include/byteir/Transforms/Passes.td @@ -440,7 +440,8 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> { let constructor = "mlir::createByteIRShapeReificationPass()"; let dependentDialects = [ "mlir::shape::ShapeDialect", - "mlir::tensor::TensorDialect" + "mlir::tensor::TensorDialect", + "mlir::arith::ArithDialect", ]; } diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h index 05ade22f2..0a63eda78 100644 --- a/compiler/lib/Transforms/PassDetail.h +++ b/compiler/lib/Transforms/PassDetail.h @@ -43,6 +43,10 @@ namespace memref { class MemRefDialect; } // namespace memref +namespace arith { +class ArithDialect; +} // namespace arith + namespace mhlo { class MhloDialect; } // namespace mhlo diff --git a/compiler/lib/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp index 382e2ed10..f196d93a6 100644 --- a/compiler/lib/Transforms/ShapeReification.cpp +++ b/compiler/lib/Transforms/ShapeReification.cpp @@ -20,6 +20,7 @@ #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h" #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h" #include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -59,8 +60,8 @@ struct ShapeReificationOnTensorDimPattern // Insert cast, if needed. if (dimOfShape.getType() != op.getType()) { - dimOfShape = rewriter.create(op.getLoc(), op.getType(), - dimOfShape); + dimOfShape = rewriter.create( + op.getLoc(), op.getType(), dimOfShape); } rewriter.replaceOp(op, dimOfShape);