Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] add inferReturnType and shape constraint function for byteir.one_hot #463

Merged
merged 8 commits into from
Oct 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void registerRealDynamicSliceInferReturnTypeComponents();
void registerReduceInferReturnTypeComponents();
void registerSoftmaxInferReturnTypeComponents();
void registerAddNInferReturnTypeComponents();
void registerOneHotInferReturnTypeComponents();
void registerTorchIndexSelectInferReturnTypeComponents();
void registerGeLUInferReturnTypeComponents();
void registerLayerNormInferReturnTypeComponents();
Expand All @@ -49,6 +50,7 @@ inline void registerAllMhloInferReturnTypeComponents() {
registerReduceInferReturnTypeComponents();
registerSoftmaxInferReturnTypeComponents();
registerAddNInferReturnTypeComponents();
registerOneHotInferReturnTypeComponents();
registerTorchIndexSelectInferReturnTypeComponents();
registerGeLUInferReturnTypeComponents();
registerLayerNormInferReturnTypeComponents();
Expand Down Expand Up @@ -102,6 +104,7 @@ void registerDynamicPartitionShapeConstraints();
void registerDynamicReshapeShapeConstraints();
void registerEinsumShapeConstraints();
void registerReshapeShapeConstraints();
void registerOneHotShapeConstraints();

inline void registerAllMhloShapeConstraints() {
registerConcatenateShapeConstraints();
Expand All @@ -110,6 +113,7 @@ inline void registerAllMhloShapeConstraints() {
registerDynamicReshapeShapeConstraints();
registerEinsumShapeConstraints();
registerReshapeShapeConstraints();
registerOneHotShapeConstraints();
}

} // namespace mlir
Expand Down
1 change: 0 additions & 1 deletion compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ constexpr llvm::StringRef getStridedSliceName() {
// Returns the TF_NAME_PREFIX-qualified custom-call target name for the
// TF BatchMatMul op (string is concatenated at compile time).
constexpr llvm::StringRef getBatchMatMulName() {
return TF_NAME_PREFIX "BatchMatMul";
}

} // namespace mlir

#undef TF_NAME_PREFIX
Expand Down
1 change: 1 addition & 0 deletions compiler/lib/Dialect/mhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_byteir_dialect_library(ByteIRMhloDynamicShapeOpRegister
DynamicShapeOpRegister/ReshapeLike.cpp
DynamicShapeOpRegister/Softmax.cpp
DynamicShapeOpRegister/AddN.cpp
DynamicShapeOpRegister/OneHot.cpp
DynamicShapeOpRegister/TorchIndexSelect.cpp
DynamicShapeOpRegister/ScatterNd.cpp
DynamicShapeOpRegister/StridedSlice.cpp
Expand Down
116 changes: 116 additions & 0 deletions compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/OneHot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===- OneHot.cpp ---------------------------------------------*--- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "byteir/Dialect/Shape/IR/ShapeExtOps.h"
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "dynamic-shape-op-register"

using namespace mlir;

/// Insert shape_ext.meet constraints tying each input dimension of a
/// byteir.one_hot custom call to the corresponding result dimension.
/// The result has one extra dimension (holding the one-hot depth) inserted
/// at `axis`; every other result dimension must equal the matching input
/// dimension.
LogicalResult InsertOneHotShapeConstraints(Operation *op, OpBuilder &builder) {
builder.setInsertionPointAfter(op);
auto operand = op->getOperand(0);
auto result = op->getResult(0);
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
auto resType = dyn_cast<RankedTensorType>(result.getType());
if (!operandType || !resType)
return failure();
auto inputShape = operandType.getShape();
auto outputShape = resType.getShape();
if (inputShape.empty())
return failure();

// The one_hot attributes live in a nested dictionary under the byteir
// custom-call attribute name; fail gracefully instead of dereferencing a
// null attribute when anything is missing.
DictionaryAttr attr = op->getAttrDictionary();
auto byteirAttrs = attr.getAs<DictionaryAttr>(getCustomCallAttrName());
if (!byteirAttrs)
return failure();
auto axisAttr = byteirAttrs.getAs<IntegerAttr>("axis");
if (!axisAttr)
return failure();
int64_t axis = axisAttr.getInt();
// Normalize a negative axis; e.g. -1 places the depth dim at the end.
if (axis < 0)
axis += static_cast<int64_t>(outputShape.size());

int64_t inputRank = static_cast<int64_t>(inputShape.size());
for (int64_t inputDim = 0, outputDim = 0; inputDim < inputRank;
++inputDim, ++outputDim) {
// Skip over the result dimension that holds the one-hot depth.
if (inputDim == axis)
++outputDim;
Value oprSize =
builder.create<tensor::DimOp>(op->getLoc(), operand, inputDim);
Value resSize =
builder.create<tensor::DimOp>(op->getLoc(), result, outputDim);
builder.create<shape_ext::MeetOp>(op->getLoc(), oprSize, resSize);
}

return success();
}

/// Register the shape-constraint insertion hook for the byteir.one_hot
/// custom call (runs once; the static guarantees single registration).
void mlir::registerOneHotShapeConstraints() {
  static InsertShapeConstraintRegistration oneHotConstraintRegistration(
      getOneHotName(), InsertOneHotShapeConstraints);
}

/// Register the return-type inference hook for the byteir.one_hot custom
/// call. The inferred result shape is the operand shape with an extra
/// dimension of size `depth` inserted at `axis`; the result element type is
/// taken from the `on_value` attribute (integer or float).
void mlir::registerOneHotInferReturnTypeComponents() {
  static InferReturnTypeComponentsRegistration shapeRegister(
      getOneHotName(),
      [](MLIRContext *context, std::optional<Location> loc,
         ValueShapeRange operands, DictionaryAttr attr,
         OpaqueProperties properties, RegionRange,
         SmallVectorImpl<ShapedTypeComponents> &inferredReturnTypes) {
        ShapedType dataType = dyn_cast<ShapedType>(operands[0].getType());
        if (!dataType) {
          LLVM_DEBUG(llvm::dbgs() << loc << ": get dataType failed\n");
          return failure();
        }
        // All one_hot attributes live in a nested dictionary under the
        // byteir custom-call attribute name; bail out if anything is
        // missing instead of dereferencing a null attribute.
        auto byteirAttrs = attr.getAs<DictionaryAttr>(getCustomCallAttrName());
        if (!byteirAttrs) {
          LLVM_DEBUG(llvm::dbgs() << loc << ": missing byteir attrs\n");
          return failure();
        }
        auto axisAttr = byteirAttrs.getAs<IntegerAttr>("axis");
        auto depthAttr = byteirAttrs.getAs<IntegerAttr>("depth");
        Attribute onValue = byteirAttrs.get("on_value");
        if (!axisAttr || !depthAttr || !onValue) {
          LLVM_DEBUG(llvm::dbgs()
                     << loc << ": missing axis/depth/on_value attr\n");
          return failure();
        }
        int64_t axis = axisAttr.getInt();
        int64_t depth = depthAttr.getInt();

        // The result element type comes from on_value, which must be an
        // integer or float scalar attribute.
        Type onValueType;
        if (auto intAttr = dyn_cast<IntegerAttr>(onValue)) {
          onValueType = intAttr.getType();
        } else if (auto floatAttr = dyn_cast<FloatAttr>(onValue)) {
          onValueType = floatAttr.getType();
        } else {
          LLVM_DEBUG(llvm::dbgs()
                     << loc << ": get output element type failed\n");
          return failure();
        }

        auto dataShape = dataType.getShape();
        int64_t dataRank = static_cast<int64_t>(dataShape.size());
        llvm::SmallVector<int64_t> outShape;
        outShape.reserve(dataRank + 1);
        for (int64_t i = 0; i < dataRank; ++i) {
          if (axis == i) {
            outShape.push_back(depth);
          }
          outShape.push_back(dataShape[i]);
        }
        // axis == -1 (or an axis past the last input dim) appends the
        // depth dimension at the end.
        if (-1 == axis || axis >= dataRank) {
          outShape.push_back(depth);
        }
        inferredReturnTypes.emplace_back(outShape, onValueType);
        return success();
      });
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "byteir/Dialect/Shape/IR/ShapeExtOps.h"
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -98,7 +99,7 @@ LogicalResult InsertReshapeShapeConstraints(Operation *op, OpBuilder &builder) {
builder.create<shape_ext::MeetOp>(op->getLoc(), oprSize, resSize);

return success();
};
}

void mlir::registerReshapeShapeConstraints() {
static InsertShapeConstraintRegistration shapeRegister(
Expand Down
2 changes: 1 addition & 1 deletion frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func.func @replace_where_3D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24x8xf16
}
// CHECK-LABEL: func.func @replace_where_3D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24x8xf16>) -> tensor<?x8xf16> {
// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[-9223372036854775808, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[-1, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<28800> : tensor<i64>}> : () -> tensor<i64>
// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() <{value = dense<86400> : tensor<i64>}> : () -> tensor<i64>
Expand Down
17 changes: 17 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/tests/mhlo_legalize_tf_ext.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,20 @@ func.func @tile_right_dynamic(%arg0: tensor<1x64xf16>, %arg1: tensor<2xi32>) ->
// CHECK-LABEL: %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %from_elements) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<1x64xf16>, tensor<4xindex>) -> tensor<?x1x?x64xf16>
// CHECK-LABEL: %4 = mhlo.dynamic_reshape %3, %from_elements_1 : (tensor<?x1x?x64xf16>, tensor<2xindex>) -> tensor<1x?xf16>
// CHECK-LABEL: return %4 : tensor<1x?xf16>

// tf.Reshape with a single dynamic (-1) target dim is lowered to
// mhlo.dynamic_reshape with a runtime-computed extent.
func.func @reshape_case0(%arg0: tensor<?x24xf16>) -> tensor<?x24x1xf16> {
  %cst = "tf.Const"() <{value = dense<[-1, 24, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
  %0 = "tf.Reshape"(%arg0, %cst) : (tensor<?x24xf16>, tensor<3xi64>) -> tensor<?x24x1xf16>
  return %0 : tensor<?x24x1xf16>
}
// CHECK-LABEL: func.func @reshape_case0
// CHECK-DAG: %c1 = shape.const_size 1
// CHECK-DAG: %c24 = shape.const_size 24
// CHECK: %0 = shape.shape_of %arg0 : tensor<?x24xf16> -> tensor<2xindex>
// CHECK: %1 = shape.num_elements %0 : tensor<2xindex> -> index
// CHECK: %2 = shape.index_to_size %1
// CHECK: %3 = shape.div %2, %c24 : !shape.size, !shape.size -> !shape.size
// CHECK: %4 = shape.from_extents %3, %c24, %c1 : !shape.size, !shape.size, !shape.size
// CHECK: %5 = shape.to_extent_tensor %4 : !shape.shape -> tensor<3xindex>
// CHECK: %6 = mhlo.dynamic_reshape %arg0, %5 : (tensor<?x24xf16>, tensor<3xindex>) -> tensor<?x24x1xf16>
// CHECK: return %6 : tensor<?x24x1xf16>
6 changes: 5 additions & 1 deletion frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ Value replaceWhereStatic(PatternRewriter &rewriter, Location loc, Value input,
}
auto shapeType =
RankedTensorType::get({inputShape.size()}, rewriter.getIntegerType(64));
auto shapeAttr = DenseIntElementsAttr::get(shapeType, oneHotShape);
SmallVector<int64_t> shapeVec;
for (auto s : oneHotShape) {
shapeVec.push_back((s < 0) ? -1 : s);
}
auto shapeAttr = DenseIntElementsAttr::get(shapeType, shapeVec);
Value shape = rewriter.create<TF::ConstOp>(loc, shapeAttr);
oneHotOutputType = oneHotOutputType.clone(oneHotShape);
oneHotOutput = rewriter.create<TF::ReshapeOp>(loc, oneHotOutputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -626,6 +627,85 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
}
};

/// Lower a tf.Reshape whose constant target shape contains exactly one
/// dynamic (-1) dimension to mhlo.dynamic_reshape, computing the dynamic
/// extent at runtime as numElements(input) / product(static target dims).
class ConvertReshapeOp : public OpRewritePattern<TF::ReshapeOp> {
public:
  using OpRewritePattern<TF::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::ReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto loc = reshapeOp->getLoc();
    auto input = reshapeOp.getTensor();
    auto shape = reshapeOp.getShape();
    auto output = reshapeOp.getOutput();
    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!inputType || !outputType) {
      return failure();
    }
    // Fully static shapes are handled by the regular tf->mhlo lowering.
    if (inputType.hasStaticShape() || outputType.hasStaticShape()) {
      return failure();
    }
    DenseIntElementsAttr shapeAttr;
    if (!matchPattern(shape, m_Constant(&shapeAttr))) {
      return failure();
    }
    SmallVector<int64_t> shapeVec;
    shapeVec.reserve(shapeAttr.getNumElements());
    for (auto intAttr : shapeAttr.getValues<IntegerAttr>()) {
      shapeVec.push_back(intAttr.getInt());
    }

    // Count dynamic dims with count_if. (An all_of-based counter would
    // short-circuit at the first negative entry and could never observe a
    // second -1, silently accepting ill-formed shapes like [-1, -1, 4].)
    int64_t negativeNum =
        llvm::count_if(shapeVec, [](int64_t s) { return s < 0; });
    if (negativeNum == 0) {
      // Nothing dynamic to materialize.
      return failure();
    }
    if (negativeNum != 1) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "const shape operand has multiple dynamic dims");
    }
    // A zero-sized static dim would make the division below ill-defined
    // (and would otherwise be mistaken for a dynamic dim).
    if (llvm::is_contained(shapeVec, 0)) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "const shape operand has a zero-sized dim");
    }
    // Product of the known (positive) target dims.
    int64_t staticNum = 1;
    for (auto s : shapeVec) {
      if (s > 0) {
        staticNum *= s;
      }
    }

    // dynamicSize = numElements(input) / staticNum, computed at runtime.
    Value shapeOf = rewriter.create<shape::ShapeOfOp>(loc, input);
    Value numberElements = rewriter.create<shape::NumElementsOp>(loc, shapeOf);
    numberElements = rewriter.create<shape::IndexToSizeOp>(loc, numberElements);
    Value staticElementsNum =
        rewriter.create<shape::ConstSizeOp>(loc, staticNum);
    Value dynamicSize =
        rewriter.create<shape::DivOp>(loc, numberElements, staticElementsNum);
    SmallVector<Value> newShapeVec;
    newShapeVec.reserve(shapeVec.size());
    for (auto s : shapeVec) {
      Value dimSize = dynamicSize;
      if (s > 0) {
        dimSize = rewriter.create<shape::ConstSizeOp>(loc, s);
      }
      newShapeVec.push_back(dimSize);
    }
    Value newShape = rewriter.create<shape::FromExtentsOp>(loc, newShapeVec);
    auto newShapeType = RankedTensorType::get(
        {static_cast<int64_t>(newShapeVec.size())}, rewriter.getIndexType());
    newShape =
        rewriter.create<shape::ToExtentTensorOp>(loc, newShapeType, newShape);
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(reshapeOp, outputType,
                                                        input, newShape);
    return success();
  }
};

class ConvertScfIfOp : public OpRewritePattern<scf::IfOp> {
public:
using OpRewritePattern<scf::IfOp>::OpRewritePattern;
Expand Down Expand Up @@ -675,6 +755,7 @@ void PopulateMhloLegalizeTfExtPatterns(MLIRContext *context,
patterns->add(std::make_unique<ConvertBatchMatMulV2Op>(context));
patterns->add(std::make_unique<ConvertRoundOp>(context));
patterns->add(std::make_unique<ConvertTileOp>(context));
patterns->add(std::make_unique<ConvertReshapeOp>(context));
// patterns->add(std::make_unique<ConvertScfIfOp>(context));
}

Expand Down
1 change: 1 addition & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def MhloLegalizeTfExt: Pass<"mhlo-legalize-tf-ext", "func::FuncOp"> {
"mlir::TF::TensorFlowDialect",
"mlir::mhlo::MhloDialect",
"mlir::tensor::TensorDialect",
"mlir::shape::ShapeDialect",
"mlir::scf::SCFDialect",
];
}
Expand Down
5 changes: 5 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/passes_detail.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ namespace tensor {
class TensorDialect;
} // namespace tensor

// Forward declaration so the generated pass declarations can reference
// ShapeDialect without including the full Shape dialect header.
namespace shape {
class ShapeDialect;
} // namespace shape

namespace scf {
class SCFDialect;
} // namespace scf
Expand Down
Loading