Skip to content

Commit

Permalink
update tf patches
Browse files Browse the repository at this point in the history
  • Loading branch information
Vremold committed Jul 5, 2024
1 parent 81bb52e commit 844e17f
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
From 3a14562a78e712a1ea495ab0a6ba55b3afbc6ef0 Mon Sep 17 00:00:00 2001
From: "quanbo.liu" <[email protected]>
Date: Mon, 22 Jan 2024 10:04:42 +0800
Subject: [PATCH] fix bug of create f16 const for HoistCwiseBinaryOutOfConcat

---
.../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 22 +++++++++++++------
1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index f66c996f32a..2b61a57f488 100644
index 36fb36a3d45..9417d8c2f5a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1317,12 +1317,12 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
@@ -1339,12 +1339,12 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
// Process `exceptions`: For each value there, synthesize a binary op of the
// above kind, so that the concat hoisting optimization can still apply.
if (!exceptions.empty()) {
Expand All @@ -27,30 +18,22 @@ index f66c996f32a..2b61a57f488 100644
else
return failure();
DenseElementsAttr const_attr;
@@ -1331,11 +1331,19 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
.getType()
.dyn_cast<ShapedType>();
Type scalar_dtype = scalar_tensor_type.getElementType();
- if (scalar_dtype.isa<FloatType>())
- const_attr = DenseElementsAttr::get(scalar_tensor_type,
- static_cast<float>(identity_val));
@@ -1354,7 +1354,17 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
if (mlir::isa<FloatType>(scalar_dtype))
const_attr = DenseElementsAttr::get(scalar_tensor_type,
static_cast<float>(identity_val));
- else
+ if (scalar_dtype.isa<FloatType>()) {
+ //const_attr = DenseFPElementsAttr::get(scalar_tensor_type, APFloat(identity_val));
+ if (mlir::isa<FloatType>(scalar_dtype)) {
+ // const_attr = DenseFPElementsAttr::get(scalar_tensor_type, APFloat(identity_val));
+ APFloat epsilonFloat = APFloat(identity_val);
+ bool losesInfo = false;
+ auto status = epsilonFloat.convert(
+ scalar_dtype.cast<FloatType>().getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
+ mlir::cast<FloatType>(scalar_dtype).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
+ if(losesInfo || status != llvm::APFloatBase::opStatus::opOK) {
+ return op.emitError("float type conversion failed");
+ }
+ const_attr = DenseElementsAttr::get(scalar_tensor_type, epsilonFloat);
+ } else {
return failure();
+ }

// All checks are passes, and we now prepare for rewrite.
auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
--
2.30.2

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 2ed3c0519d8..885e64dea93 100644
index 606be04a0f7..a3936717de4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -2073,9 +2073,10 @@ struct FuseReshapeAndTransposeAroundBatchMatmul
@@ -2236,9 +2236,10 @@ struct FuseReshapeAndTransposeAroundBatchMatmul
SmallVector<int, 3> new_shape = {
static_cast<int>(transpose_input.getType().getDimSize(0)),
static_cast<int>(transpose_input.getType().getDimSize(1)),
Expand All @@ -14,5 +14,5 @@ index 2ed3c0519d8..885e64dea93 100644
+ transpose_input.getType().getShape().end(), 1,
+ std::multiplies<int64_t>()))};
auto shape_constant = rewriter.create<ConstOp>(
batch_matmul.getLoc(),
DenseIntElementsAttr::get(
batch_matmul.getLoc(), GetI32ElementsAttr(new_shape, &rewriter));
auto reshaped_input = rewriter.create<ReshapeOp>(
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
From 1bf4d5368ea4e97ac57fbdc1b772904969aff97e Mon Sep 17 00:00:00 2001
From: "quanbo.liu" <[email protected]>
Date: Sun, 31 Dec 2023 11:38:26 +0800
Subject: [PATCH] [Fix] support tf shape inference

---
.../tensorflow/transforms/shape_inference.cc | 107 +++++++++++-------
1 file changed, 67 insertions(+), 40 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 92bbc1f5a99..1fe63ef49fe 100644
index 6a9527aea26..4295e2b8624 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -1121,6 +1121,40 @@ bool ShapeInference::InferShapeForCast(Operation* op) {
@@ -1149,6 +1149,40 @@ bool ShapeInference::InferShapeForCast(Operation* op) {
return UpdateTypeAndInsertIncompatibleUseCasts(new_type, result);
}

Expand All @@ -20,10 +11,10 @@ index 92bbc1f5a99..1fe63ef49fe 100644
+Type GetElementTypeFromOperand(TensorType operand_type,
+ TensorType result_type) {
+ auto operand_handle_type =
+ operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
+ mlir::dyn_cast<TensorFlowTypeWithSubtype>(operand_type.getElementType());
+ if (!operand_handle_type) return result_type.getElementType();
+ auto result_handle_type =
+ result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
+ mlir::cast<TensorFlowTypeWithSubtype>(result_type.getElementType());
+ if (operand_handle_type.GetSubtypes().empty() ||
+ !result_handle_type.GetSubtypes().empty())
+ return result_type.getElementType();
Expand Down Expand Up @@ -52,17 +43,17 @@ index 92bbc1f5a99..1fe63ef49fe 100644
bool ShapeInference::InferShapeForIf(IfOp op) {
DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
bool changed = false;
@@ -1130,8 +1164,22 @@ bool ShapeInference::InferShapeForIf(IfOp op) {
@@ -1158,8 +1192,22 @@ bool ShapeInference::InferShapeForIf(IfOp op) {
op.ResolveElseFunction(&symbol_table_).getFunctionType().getResults();
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
// If then and else types do not match, skip refinement for that result.
- if (std::get<1>(it) != std::get<2>(it)) continue;
- changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
+ //if (std::get<1>(it) != std::get<2>(it)) continue;
+ // if (std::get<1>(it) != std::get<2>(it)) continue;
+ auto lhs_type = std::get<1>(it);
+ auto rhs_type = std::get<2>(it);
+ auto lhs_rank_type = lhs_type.dyn_cast<RankedTensorType>();
+ auto rhs_rank_type = rhs_type.dyn_cast<RankedTensorType>();
+ auto lhs_rank_type = mlir::dyn_cast<RankedTensorType>(lhs_type);
+ auto rhs_rank_type = mlir::dyn_cast<RankedTensorType>(rhs_type);
+ auto expected_type = lhs_type;
+ if (lhs_type != rhs_type) {
+ if(lhs_rank_type && rhs_rank_type &&
Expand All @@ -77,7 +68,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644
}
return changed;
}
@@ -1141,12 +1189,25 @@ bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
@@ -1169,12 +1217,25 @@ bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {

Operation* then_yield = op.getThenBranch().front().getTerminator();
Operation* else_yield = op.getElseBranch().front().getTerminator();
Expand All @@ -88,7 +79,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644
- if (std::get<1>(result) != std::get<2>(result)) continue;
- changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
- changed;
+ //if (std::get<1>(it) != std::get<2>(it)) continue;
+ // if (std::get<1>(it) != std::get<2>(it)) continue;
+ auto lhs_type = std::get<1>(it);
+ auto rhs_type = std::get<2>(it);
+ auto lhs_rank_type = lhs_type.dyn_cast<RankedTensorType>();
Expand All @@ -107,7 +98,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644
}
return changed;
}
@@ -2318,21 +2379,6 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
@@ -2468,21 +2529,6 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
return false;
}

Expand All @@ -116,10 +107,10 @@ index 92bbc1f5a99..1fe63ef49fe 100644
-Type GetElementTypeFromOperand(TensorType operand_type,
- TensorType result_type) {
- auto operand_handle_type =
- operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
- mlir::dyn_cast<TensorFlowTypeWithSubtype>(operand_type.getElementType());
- if (!operand_handle_type) return result_type.getElementType();
- auto result_handle_type =
- result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
- mlir::cast<TensorFlowTypeWithSubtype>(result_type.getElementType());
- if (operand_handle_type.GetSubtypes().empty() ||
- !result_handle_type.GetSubtypes().empty())
- return result_type.getElementType();
Expand All @@ -129,7 +120,7 @@ index 92bbc1f5a99..1fe63ef49fe 100644
// Checks if one tensor type can refine another type for tf.While/
// tf.WhileRegion. If rank differs or static dimensions can be lost, the other
// type cannot be used for refinement.
@@ -2692,25 +2738,6 @@ bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
@@ -2855,25 +2901,6 @@ bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
}

Expand All @@ -155,6 +146,3 @@ index 92bbc1f5a99..1fe63ef49fe 100644
// Finds compatible types to propagate into functions/regions of a shape
// invariant tf.While/tf.WhileRegion. If operand and result types are the same,
// that type is returned. If operand and result types are of the same rank, a
--
2.30.2

Original file line number Diff line number Diff line change
@@ -1,27 +1,14 @@
From 3f9ae8f63872a748974652833ae8dc6c47f18267 Mon Sep 17 00:00:00 2001
From: "quanbo.liu" <[email protected]>
Date: Thu, 7 Sep 2023 19:02:24 +0800
Subject: [PATCH] Do not convert tf.Select to mhlo.select when type of input is
tf_type.string

---
tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc | 3 +++
1 file changed, 3 insertions(+)

diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
index 5853c59664c..98d725215ab 100644
index 13c9c3f9306..fc78a4420a5 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
@@ -2838,6 +2838,9 @@ class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {
@@ -2868,6 +2868,9 @@ class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {

LogicalResult matchAndRewrite(TF::SelectOp op,
PatternRewriter &rewriter) const override {
+ if(op.getOutput().getType().getElementType().isa<mlir::TF::StringType>()) {
+ return failure();
+ }
// This lowering only works on ranked types.
auto cond_type = op.getCondition().getType().dyn_cast<RankedTensorType>();
auto then_type = op.getThenValue().getType().dyn_cast<RankedTensorType>();
--
2.20.1

auto cond_type =
mlir::dyn_cast<RankedTensorType>(op.getCondition().getType());
17 changes: 0 additions & 17 deletions frontends/tf-frontend/external/patches/tensorflow/tf_build.patch
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,6 @@ index 02c9f486e8e..ff76e59f788 100644
deps = [
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:QuantOps",
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index b6d406f040a..e657dcdabaf 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -8,11 +8,7 @@ package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
# TODO(jpienaar): Make the visibility more restrictive.
- ":friends",
- "//learning/brain/mobile/programmability:__subpackages__",
- "//tensorflow/lite/experimental/tf_runtime:__subpackages__",
- "//tensorflow/lite/testing:__subpackages__",
- "//third_party/odml/infra/genai/conversion/per_layer:__subpackages__",
+ "//visibility:public",
],
licenses = ["notice"],
)
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index cb24a361353..b2771263628 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
index 51068fcf4ac..01e59529f84 100644
index fe8bb7d2ca1..38caa663968 100644
--- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
+++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h
@@ -112,9 +112,9 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
@@ -113,9 +113,9 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
// Allow dynamic width and height dimensions only.
auto result_ty = op.getResult().getType().template cast<TensorType>();
auto result_ty = mlir::cast<TensorType>(op.getResult().getType());
if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
- result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
+ result_ty.isDynamicDim(3)) {
Expand Down
2 changes: 1 addition & 1 deletion frontends/tf-frontend/scripts/build_and_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ PROJ_DIR="$CUR_DIR/.."
bash $PROJ_DIR/scripts/prepare.sh

pushd $PROJ_DIR
python3 -m pip install /data00/mhlo_libraries/mhlo_tools-1.3.0-cp39-cp39-linux_x86_64.whl
python3 -m pip install /data00/mhlo_libraries/mhlo_tools-1.4.0-cp39-cp39-linux_x86_64.whl
$PROJ_DIR/bazel --output_user_root=./build build //tools:tf-frontend //tools:tf-ext-opt
$PROJ_DIR/bazel --output_user_root=./build test --test_output=errors //tf_mlir_ext/tests:all --java_runtime_version=remotejdk_11
$PROJ_DIR/bazel --output_user_root=./build test --test_output=errors //tf_mlir_ext/numerical:all --java_runtime_version=remotejdk_11
Expand Down

0 comments on commit 844e17f

Please sign in to comment.