-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
35 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,8 @@ | ||
From 3a14562a78e712a1ea495ab0a6ba55b3afbc6ef0 Mon Sep 17 00:00:00 2001 | ||
From: "quanbo.liu" <[email protected]> | ||
Date: Mon, 22 Jan 2024 10:04:42 +0800 | ||
Subject: [PATCH] fix bug of create f16 const for HoistCwiseBinaryOutOfConcat | ||
|
||
--- | ||
.../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 22 +++++++++++++------ | ||
1 file changed, 15 insertions(+), 7 deletions(-) | ||
|
||
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | ||
index f66c996f32a..2b61a57f488 100644 | ||
index 36fb36a3d45..9417d8c2f5a 100644 | ||
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | ||
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | ||
@@ -1317,12 +1317,12 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( | ||
@@ -1339,12 +1339,12 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( | ||
// Process `exceptions`: For each value there, synthesize a binary op of the | ||
// above kind, so that the concat hoisting optimization can still apply. | ||
if (!exceptions.empty()) { | ||
|
@@ -27,30 +18,22 @@ index f66c996f32a..2b61a57f488 100644 | |
else | ||
return failure(); | ||
DenseElementsAttr const_attr; | ||
@@ -1331,11 +1331,19 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( | ||
.getType() | ||
.dyn_cast<ShapedType>(); | ||
Type scalar_dtype = scalar_tensor_type.getElementType(); | ||
- if (scalar_dtype.isa<FloatType>()) | ||
- const_attr = DenseElementsAttr::get(scalar_tensor_type, | ||
- static_cast<float>(identity_val)); | ||
@@ -1354,7 +1354,17 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( | ||
if (mlir::isa<FloatType>(scalar_dtype)) | ||
const_attr = DenseElementsAttr::get(scalar_tensor_type, | ||
static_cast<float>(identity_val)); | ||
- else | ||
+ if (scalar_dtype.isa<FloatType>()) { | ||
+ //const_attr = DenseFPElementsAttr::get(scalar_tensor_type, APFloat(identity_val)); | ||
+ if (mlir::isa<FloatType>(scalar_dtype)) { | ||
+ // const_attr = DenseFPElementsAttr::get(scalar_tensor_type, APFloat(identity_val)); | ||
+ APFloat epsilonFloat = APFloat(identity_val); | ||
+ bool losesInfo = false; | ||
+ auto status = epsilonFloat.convert( | ||
+ scalar_dtype.cast<FloatType>().getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); | ||
+ mlir::cast<FloatType>(scalar_dtype).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); | ||
+ if(losesInfo || status != llvm::APFloatBase::opStatus::opOK) { | ||
+ return op.emitError("float type conversion failed"); | ||
+ } | ||
+ const_attr = DenseElementsAttr::get(scalar_tensor_type, epsilonFloat); | ||
+ } else { | ||
return failure(); | ||
+ } | ||
|
||
// All checks are passes, and we now prepare for rewrite. | ||
auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr); | ||
-- | ||
2.30.2 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,8 @@ | ||
From 1bf4d5368ea4e97ac57fbdc1b772904969aff97e Mon Sep 17 00:00:00 2001 | ||
From: "quanbo.liu" <[email protected]> | ||
Date: Sun, 31 Dec 2023 11:38:26 +0800 | ||
Subject: [PATCH] [Fix] support tf shape inference | ||
|
||
--- | ||
.../tensorflow/transforms/shape_inference.cc | 107 +++++++++++------- | ||
1 file changed, 67 insertions(+), 40 deletions(-) | ||
|
||
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc | ||
index 92bbc1f5a99..1fe63ef49fe 100644 | ||
index 6a9527aea26..4295e2b8624 100644 | ||
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc | ||
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc | ||
@@ -1121,6 +1121,40 @@ bool ShapeInference::InferShapeForCast(Operation* op) { | ||
@@ -1149,6 +1149,40 @@ bool ShapeInference::InferShapeForCast(Operation* op) { | ||
return UpdateTypeAndInsertIncompatibleUseCasts(new_type, result); | ||
} | ||
|
||
|
@@ -20,10 +11,10 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
+Type GetElementTypeFromOperand(TensorType operand_type, | ||
+ TensorType result_type) { | ||
+ auto operand_handle_type = | ||
+ operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>(); | ||
+ mlir::dyn_cast<TensorFlowTypeWithSubtype>(operand_type.getElementType()); | ||
+ if (!operand_handle_type) return result_type.getElementType(); | ||
+ auto result_handle_type = | ||
+ result_type.getElementType().cast<TensorFlowTypeWithSubtype>(); | ||
+ mlir::cast<TensorFlowTypeWithSubtype>(result_type.getElementType()); | ||
+ if (operand_handle_type.GetSubtypes().empty() || | ||
+ !result_handle_type.GetSubtypes().empty()) | ||
+ return result_type.getElementType(); | ||
|
@@ -52,17 +43,17 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
bool ShapeInference::InferShapeForIf(IfOp op) { | ||
DCOMMENT_OP(op.getOperation(), "Infer shape for if "); | ||
bool changed = false; | ||
@@ -1130,8 +1164,22 @@ bool ShapeInference::InferShapeForIf(IfOp op) { | ||
@@ -1158,8 +1192,22 @@ bool ShapeInference::InferShapeForIf(IfOp op) { | ||
op.ResolveElseFunction(&symbol_table_).getFunctionType().getResults(); | ||
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) { | ||
// If then and else types do not match, skip refinement for that result. | ||
- if (std::get<1>(it) != std::get<2>(it)) continue; | ||
- changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed; | ||
+ //if (std::get<1>(it) != std::get<2>(it)) continue; | ||
+ // if (std::get<1>(it) != std::get<2>(it)) continue; | ||
+ auto lhs_type = std::get<1>(it); | ||
+ auto rhs_type = std::get<2>(it); | ||
+ auto lhs_rank_type = lhs_type.dyn_cast<RankedTensorType>(); | ||
+ auto rhs_rank_type = rhs_type.dyn_cast<RankedTensorType>(); | ||
+ auto lhs_rank_type = mlir::dyn_cast<RankedTensorType>(lhs_type); | ||
+ auto rhs_rank_type = mlir::dyn_cast<RankedTensorType>(rhs_type); | ||
+ auto expected_type = lhs_type; | ||
+ if (lhs_type != rhs_type) { | ||
+ if(lhs_rank_type && rhs_rank_type && | ||
|
@@ -77,7 +68,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
} | ||
return changed; | ||
} | ||
@@ -1141,12 +1189,25 @@ bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) { | ||
@@ -1169,12 +1217,25 @@ bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) { | ||
|
||
Operation* then_yield = op.getThenBranch().front().getTerminator(); | ||
Operation* else_yield = op.getElseBranch().front().getTerminator(); | ||
|
@@ -88,7 +79,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
- if (std::get<1>(result) != std::get<2>(result)) continue; | ||
- changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || | ||
- changed; | ||
+ //if (std::get<1>(it) != std::get<2>(it)) continue; | ||
+ // if (std::get<1>(it) != std::get<2>(it)) continue; | ||
+ auto lhs_type = std::get<1>(it); | ||
+ auto rhs_type = std::get<2>(it); | ||
+ auto lhs_rank_type = lhs_type.dyn_cast<RankedTensorType>(); | ||
|
@@ -107,7 +98,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
} | ||
return changed; | ||
} | ||
@@ -2318,21 +2379,6 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { | ||
@@ -2468,21 +2529,6 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { | ||
return false; | ||
} | ||
|
||
|
@@ -116,10 +107,10 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
-Type GetElementTypeFromOperand(TensorType operand_type, | ||
- TensorType result_type) { | ||
- auto operand_handle_type = | ||
- operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>(); | ||
- mlir::dyn_cast<TensorFlowTypeWithSubtype>(operand_type.getElementType()); | ||
- if (!operand_handle_type) return result_type.getElementType(); | ||
- auto result_handle_type = | ||
- result_type.getElementType().cast<TensorFlowTypeWithSubtype>(); | ||
- mlir::cast<TensorFlowTypeWithSubtype>(result_type.getElementType()); | ||
- if (operand_handle_type.GetSubtypes().empty() || | ||
- !result_handle_type.GetSubtypes().empty()) | ||
- return result_type.getElementType(); | ||
|
@@ -129,7 +120,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
// Checks if one tensor type can refine another type for tf.While/ | ||
// tf.WhileRegion. If rank differs or static dimensions can be lost, the other | ||
// type cannot be used for refinement. | ||
@@ -2692,25 +2738,6 @@ bool RankedAndSameRank(TensorType lhs, TensorType rhs) { | ||
@@ -2855,25 +2901,6 @@ bool RankedAndSameRank(TensorType lhs, TensorType rhs) { | ||
return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank(); | ||
} | ||
|
||
|
@@ -155,6 +146,3 @@ index 92bbc1f5a99..1fe63ef49fe 100644 | |
// Finds compatible types to propagate into functions/regions of a shape | ||
// invariant tf.While/tf.WhileRegion. If operand and result types are the same, | ||
// that type is returned. If operand and result types are of the same rank, a | ||
-- | ||
2.30.2 | ||
|
21 changes: 4 additions & 17 deletions
21
frontends/tf-frontend/external/patches/tensorflow/tf.Select_to_mhlo.select.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,14 @@ | ||
From 3f9ae8f63872a748974652833ae8dc6c47f18267 Mon Sep 17 00:00:00 2001 | ||
From: "quanbo.liu" <[email protected]> | ||
Date: Thu, 7 Sep 2023 19:02:24 +0800 | ||
Subject: [PATCH] Do not convert tf.Select to mhlo.select when type of input is | ||
tf_type.string | ||
|
||
--- | ||
tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc | 3 +++ | ||
1 file changed, 3 insertions(+) | ||
|
||
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc | ||
index 5853c59664c..98d725215ab 100644 | ||
index 13c9c3f9306..fc78a4420a5 100644 | ||
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc | ||
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc | ||
@@ -2838,6 +2838,9 @@ class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> { | ||
@@ -2868,6 +2868,9 @@ class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> { | ||
|
||
LogicalResult matchAndRewrite(TF::SelectOp op, | ||
PatternRewriter &rewriter) const override { | ||
+ if(op.getOutput().getType().getElementType().isa<mlir::TF::StringType>()) { | ||
+ return failure(); | ||
+ } | ||
// This lowering only works on ranked types. | ||
auto cond_type = op.getCondition().getType().dyn_cast<RankedTensorType>(); | ||
auto then_type = op.getThenValue().getType().dyn_cast<RankedTensorType>(); | ||
-- | ||
2.20.1 | ||
|
||
auto cond_type = | ||
mlir::dyn_cast<RankedTensorType>(op.getCondition().getType()); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
frontends/tf-frontend/external/patches/tensorflow/tf_dilated_conv.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters