
Commit

[sync] migrate tf-frontend/onnx-frontend to stablehlo, update llvm to b2cdf3c, add bounded shape inference and other updates (#101)

Sync internal ByteIR from commit id cdd3a5b to 840e671. Detailed changes are as follows:
1. Migrate tf-frontend/onnx-frontend to stablehlo rather than mhlo.
2. Update llvm to b2cdf3c.
3. Add bounded shape inference for tf.StridedSlice.
4. Register InferReturnTypeComponents for byteir ops, e.g., byteir.layer_norm and byteir.addn.
5. Update the version numbers of the compiler and runtime.

---------

Co-authored-by: yan.xu0210 <[email protected]>
Vremold and Connor-XY authored Jan 23, 2024
1 parent 117a920 commit c452f75
Showing 177 changed files with 1,663 additions and 4,786 deletions.
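
Change 1 of the commit message moves the frontends from emitting mhlo to emitting stablehlo. As a rough illustration of what that entails on the pipeline side, the sketch below appends the mlir-hlo legalization pass that rewrites residual mhlo ops into their stablehlo equivalents. This is not code from this commit; the header path and pass factory name are assumptions based on the mlir-hlo repository.

```cpp
// Illustrative only: convert any remaining mhlo ops into stablehlo before
// handing the module to downstream consumers. Header path and pass factory
// name are assumed, not taken from this commit.
#include "mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult lowerToStablehlo(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Run the mhlo -> stablehlo legalization on every function in the module.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createHloLegalizeToStablehloPass());
  return pm.run(module);
}
```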
3 changes: 1 addition & 2 deletions compiler/cmake/mhlo.cmake
@@ -2,8 +2,7 @@ add_subdirectory(${BYTEIR_SRC_DIR}/../external/mlir-hlo ${CMAKE_CURRENT_BINARY_D
 
 # FIXME: remove this when upstream fix
 target_link_libraries(MhloDialect PUBLIC StablehloTypeInference StablehloAssemblyFormat)
-target_link_libraries(GmlStPasses PUBLIC MLIRGmlStUtils)
-target_link_libraries(MLIRBufferTransforms PUBLIC DeallocationDialect DeallocationPasses)
+target_link_libraries(MLIRBufferTransforms PUBLIC DeallocationPasses)
 
 include_directories(${BYTEIR_SRC_DIR}/../external/mlir-hlo)
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo)
5 changes: 3 additions & 2 deletions compiler/include/byteir/Analysis/ShapeAnalysis.h
@@ -116,9 +116,10 @@ struct ValueTypeModificatoinRAII {
 
 using ShapeLattice = dataflow::Lattice<shape_analysis::ValueKnowledge>;
 
-class ShapeAnalysis : public dataflow::SparseDataFlowAnalysis<ShapeLattice> {
+class ShapeAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<ShapeLattice> {
 public:
-  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
 
   void visitOperation(Operation *op, ArrayRef<const ShapeLattice *> operands,
                       ArrayRef<ShapeLattice *> results) override;
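
The base-class change tracks the upstream MLIR rename of `SparseDataFlowAnalysis` to `SparseForwardDataFlowAnalysis`; how such an analysis is driven is unchanged. A minimal driver sketch using the standard `DataFlowSolver` API follows (assuming `ShapeAnalysis` and `ShapeLattice` live in namespace `mlir`, as the header above suggests):

```cpp
#include "byteir/Analysis/ShapeAnalysis.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"

// Run the forward shape analysis over `root` and query one SSA value.
mlir::LogicalResult runShapeAnalysis(mlir::Operation *root, mlir::Value v) {
  mlir::DataFlowSolver solver;
  // Sparse forward analyses rely on liveness (and usually constant) info.
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::ShapeAnalysis>();
  if (mlir::failed(solver.initializeAndRun(root)))
    return mlir::failure();
  // Inspect the shape knowledge attached to `v`, if any was computed.
  if (const auto *lattice = solver.lookupState<mlir::ShapeLattice>(v))
    (void)lattice->getValue();
  return mlir::success();
}
```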
4 changes: 3 additions & 1 deletion compiler/include/byteir/Dialect/Byre/ByreDialect.h
@@ -27,11 +27,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
@@ -47,6 +47,8 @@ class AsyncTokenType
 public:
   // Used for generic hooks in TypeBase.
   using Base::Base;
+
+  static constexpr StringLiteral name = "byre.async_token";
 };
 
 // Adds a `byre.async.token` to the front of the argument list.
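
The new `name` constant on `AsyncTokenType` follows what appears to be an upstream requirement introduced around this LLVM version: C++-defined types and attributes carry a unique `"<dialect>.<typename>"` string. A minimal sketch of the pattern, with hypothetical names (not from this commit):

```cpp
#include "mlir/IR/Types.h"
#include "llvm/ADT/StringRef.h"

namespace example {
// Hypothetical token type following the same pattern as byre's AsyncTokenType.
class TokenType
    : public mlir::Type::TypeBase<TokenType, mlir::Type, mlir::TypeStorage> {
public:
  using Base::Base;
  // Unique name string the newer MLIR type machinery expects.
  static constexpr llvm::StringLiteral name = "example.token";
};
} // namespace example
```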
@@ -73,6 +73,8 @@ struct Version {
            std::to_string(patch);
   }
 
+  std::string getBytecodeProducerString() const;
+
   uint32_t getBytecodeVersion() const;
 
   static ArrayRef<Version> getSupportedVersions();
3 changes: 2 additions & 1 deletion compiler/include/byteir/Dialect/Cat/IR/CatDialect.h
@@ -22,13 +22,14 @@
 #ifndef BYTEIR_DIALECT_CAT_IR_CATDIALECT_H
 #define BYTEIR_DIALECT_CAT_IR_CATDIALECT_H
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "byteir/Dialect/Cat/IR/CatOpInterfaces.h.inc"
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/Ccl/IR/CclOps.h
@@ -18,6 +18,7 @@
 #ifndef BYTEIR_DIALECT_CCL_CCLOPS_H
 #define BYTEIR_DIALECT_CCL_CCLOPS_H
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/Lace/LaceDialect.h
@@ -18,6 +18,7 @@
 #ifndef BYTEIR_DIALECT_LACE_LACEDIALECT_H
 #define BYTEIR_DIALECT_LACE_LACEDIALECT_H
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
@@ -36,6 +36,7 @@
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
+
 namespace linalg_ext {
 class LinalgExtOp;
 
32 changes: 16 additions & 16 deletions compiler/include/byteir/Dialect/Linalg/IR/LinalgExtInterfaces.td
@@ -99,13 +99,13 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the input operands.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getInputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         int64_t numInputs = getNumInputs();
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(numInputs);
         llvm::transform(
           this->getOperation()->getOpOperands().take_front(numInputs),
@@ -131,12 +131,12 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the subset of input operands that are of buffer type.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getInputBufferOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(getNumInputs());
         llvm::copy_if(getInputOperands(),
           std::back_inserter(result),
@@ -150,12 +150,12 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the subset of input operands that are of tensor type.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getInputTensorOperands",
      /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(getNumInputs());
         llvm::copy_if(getInputOperands(),
           std::back_inserter(result),
@@ -172,13 +172,13 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the output operands.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getOutputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         int64_t numOutputs = getNumOutputs();
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(numOutputs);
         llvm::transform(
           this->getOperation()->getOpOperands()
@@ -219,12 +219,12 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the subset of output operands that are of buffer type.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getOutputBufferOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(getNumOutputs());
         llvm::copy_if(getOutputOperands(),
           std::back_inserter(result),
@@ -238,12 +238,12 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the subset of output operands that are of tensor type.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getOutputTensorOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(getNumOutputs());
         llvm::copy_if(getOutputOperands(),
           std::back_inserter(result),
@@ -298,13 +298,13 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return the range over input and output operands.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getInputAndOutputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         int64_t numInputsAndOutputs = getNumInputsAndOutputs();
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(numInputsAndOutputs);
         llvm::transform(
           this->getOperation()->getOpOperands()
@@ -397,7 +397,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
       /*desc=*/[{
         Return operands that are neither inputs nor outputs.
       }],
-      /*retTy=*/"OpOperandVector",
+      /*retTy=*/"SmallVector<OpOperand *>",
       /*methodName=*/"getNonInputOrOutputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
@@ -407,7 +407,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
         assert(numInputsAndOutputs <= numOperands);
         if (numInputsAndOutputs == numOperands)
           return {};
-        OpOperandVector result;
+        SmallVector<OpOperand *> result;
         result.reserve(numOperands - numInputsAndOutputs);
         llvm::transform(
           this->getOperation()->getOpOperands()
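
These interface methods now return `SmallVector<OpOperand *>` because upstream MLIR dropped the `OpOperandVector` helper (which converted to values implicitly). Call sites therefore dereference each operand explicitly; a small usage sketch under that assumption:

```cpp
#include "byteir/Dialect/Linalg/IR/LinalgExtInterfaces.h"
#include "llvm/ADT/SmallVector.h"

// Collect the SSA values behind an op's input operands; with OpOperandVector
// gone, callers receive OpOperand pointers and call get() to reach the Value.
void collectInputValues(mlir::linalg_ext::LinalgExtOp op,
                        llvm::SmallVectorImpl<mlir::Value> &values) {
  for (mlir::OpOperand *operand : op.getInputOperands())
    values.push_back(operand->get());
}
```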
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/Linalg/IR/LinalgExtOps.h
@@ -27,6 +27,7 @@
 #define BYTEIR_DIALECT_LINALG_IR_LINALGEXTOPS_H
 
 #include "byteir/Dialect/Linalg/IR/LinalgExtInterfaces.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
3 changes: 3 additions & 0 deletions compiler/include/byteir/Dialect/Linalg/IR/LinalgExtOps.td
@@ -70,6 +70,7 @@ class LinalgExt_Op<string mnemonic, list<Trait> traits = []> :
                                 outputsIndexAndLength.first + outputsIndexAndLength.second);
     }
 
+    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
   }];
 }
 
@@ -738,6 +739,8 @@ def LinalgExt_BatchMatmulOp : LinalgExtStructuredBase_Op<"batch_matmul",
       return {getNumOperands - 1, getNumOperands};
     }
 
+    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+
     // Additional functions
     int64_t getFullRank() {
       return getInit().getType().cast<ShapedType>().getRank() + 1;
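
The added `getDpsInitsMutable()` hooks match the updated `DestinationStyleOpInterface`, which after this LLVM bump derives the init (destination) operands from a single `MutableOperandRange` accessor instead of the older per-operand methods. A hedged usage sketch (method names assumed from upstream MLIR of roughly this vintage):

```cpp
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_ostream.h"

// Print the destination (init) operands of any destination-style op; the
// generic getDpsInits() helper is derived from the op's getDpsInitsMutable().
void printInits(mlir::DestinationStyleOpInterface dpsOp) {
  for (mlir::Value init : dpsOp.getDpsInits())
    llvm::outs() << "init: " << init << "\n";
}
```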
@@ -34,12 +34,12 @@ struct LinalgExtBufferizableOpInterfaceImpl {
                                   const bufferization::AnalysisState &state) const;
 
   bufferization::AliasingOpOperandList
-  getAliasingOpOperands(Operation *op, OpResult opResult,
+  getAliasingOpOperands(Operation *op, Value value,
                         const bufferization::AnalysisState &) const;
 
-  bufferization::AliasingOpResultList
-  getAliasingOpResults(Operation *op, OpOperand &opOperand,
-                       const bufferization::AnalysisState &) const;
+  bufferization::AliasingValueList
+  getAliasingValues(Operation *op, OpOperand &opOperand,
+                    const bufferization::AnalysisState &) const;
 
   bufferization::BufferRelation
   bufferRelation(Operation *op, OpResult opResult,
@@ -60,7 +60,7 @@ struct LinalgExtBufferizableOpInterface
   using LinalgExtBufferizableOpInterfaceImpl::bufferizesToMemoryWrite;
   using LinalgExtBufferizableOpInterfaceImpl::bufferRelation;
   using LinalgExtBufferizableOpInterfaceImpl::getAliasingOpOperands;
-  using LinalgExtBufferizableOpInterfaceImpl::getAliasingOpResults;
+  using LinalgExtBufferizableOpInterfaceImpl::getAliasingValues;
 };
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
26 changes: 24 additions & 2 deletions compiler/include/byteir/Dialect/Linalg/Util/Util.h
@@ -71,8 +71,7 @@ ParseResult parseCommonStructuredOpParts(OpAsmParser &parser,
 void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, const OpOperandVector &inputOperands,
-    const OpOperandVector &outputOperands);
+    ValueRange results, ValueRange inputs, ValueRange outputs);
 
 void calculateTileOffsetsAndSizes(
     RewriterBase &b, Location loc, ValueRange inductionVars,
@@ -82,6 +81,29 @@ void calculateTileOffsetsAndSizes(
     SmallVector<OpFoldResult> &tiledOffsets,
     SmallVector<OpFoldResult> &tiledSizes);
 
+/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
+template <typename SrcOpTy>
+static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
+}
+template <typename SrcOpTy>
+static SmallVector<Operation *>
+getAsOperations(const SmallVector<SrcOpTy> &ops) {
+  return getAsOperations(ArrayRef<SrcOpTy>(ops));
+}
+
+/// Convert a list of `Operation *` to a list of `DstOpTy.
+template <typename DstOpTy>
+static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
+}
+template <typename DstOpTy>
+static SmallVector<DstOpTy>
+castToTypedOperations(const SmallVector<Operation *> &ops) {
+  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
+}
 } // namespace mlir
 
 #endif // BYTEIR_DIALECT_LINALG_UTIL_UTIL_H
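
The new helpers round-trip between typed op lists and generic `Operation *` lists, which is convenient when an API (for example, tiling utilities) traffics in `Operation *`. A small usage sketch with an assumed op type:

```cpp
#include "byteir/Dialect/Linalg/Util/Util.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

void roundTrip(llvm::ArrayRef<mlir::linalg::GenericOp> generics) {
  // Erase the concrete op type...
  llvm::SmallVector<mlir::Operation *> ops = mlir::getAsOperations(generics);
  // ...and recover it again (cast asserts if an op is not a GenericOp).
  llvm::SmallVector<mlir::linalg::GenericOp> typed =
      mlir::castToTypedOperations<mlir::linalg::GenericOp>(ops);
  (void)typed;
}
```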
@@ -18,6 +18,7 @@
 #ifndef BYTEIR_DIALECT_TRANSFORM_IR_TRANSFORMEXTOPS_H
 #define BYTEIR_DIALECT_TRANSFORM_IR_TRANSFORMEXTOPS_H
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -33,8 +33,10 @@ void registerDynamicReshapeInferReturnTypeComponents();
 void registerRealDynamicSliceInferReturnTypeComponents();
 void registerReduceInferReturnTypeComponents();
 void registerSoftmaxInferReturnTypeComponents();
+void registerAddNInferReturnTypeComponents();
 void registerTorchIndexSelectInferReturnTypeComponents();
 void registerGeLUInferReturnTypeComponents();
+void registerLayerNormInferReturnTypeComponents();
 
 inline void registerAllMhloInferReturnTypeComponents() {
   registerConvolutionInferReturnTypeComponents();
@@ -45,8 +47,10 @@ inline void registerAllMhloInferReturnTypeComponents() {
   registerRealDynamicSliceInferReturnTypeComponents();
   registerReduceInferReturnTypeComponents();
   registerSoftmaxInferReturnTypeComponents();
+  registerAddNInferReturnTypeComponents();
   registerTorchIndexSelectInferReturnTypeComponents();
   registerGeLUInferReturnTypeComponents();
+  registerLayerNormInferReturnTypeComponents();
 }
 
 //===----------------------------------------------------------------------===//
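
The two new registrations cover change 4 of the commit message: shape inference for byteir's `byteir.addn` and `byteir.layer_norm` custom calls. The registry itself is internal to ByteIR, so the following only sketches the kind of rule such a registration installs, with all names hypothetical: for `addn`, every operand shares the result shape, so the inference can forward the first operand's shaped type.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

// Hypothetical rule: byteir.addn is elementwise over N inputs, so the result
// components simply mirror the first operand's shape and element type.
static mlir::LogicalResult inferAddNReturnTypeComponents(
    mlir::ValueRange operands,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents> &inferred) {
  if (operands.empty())
    return mlir::failure();
  auto shaped = mlir::dyn_cast<mlir::ShapedType>(operands.front().getType());
  if (!shaped)
    return mlir::failure();
  inferred.emplace_back(shaped);
  return mlir::success();
}
```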
6 changes: 3 additions & 3 deletions compiler/include/byteir/Utils/Utils.h
@@ -20,8 +20,8 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
@@ -225,7 +225,7 @@ SmallVector<OpFoldResult> canonicalizeOpFoldResult(ArrayRef<OpFoldResult> ofrs,
                                                    bool enableFold = false);
 
 // Return true if block contains single op
-template <typename Op> bool isBlockSingleOp(Block *block) {
+template <typename... Ops> bool isBlockSingleOp(Block *block) {
   if (block == nullptr)
     return false;
 
@@ -234,7 +234,7 @@ template <typename Op> bool isBlockSingleOp(Block *block) {
     return false;
 
   auto computeOp = retOp->getOperand(0).getDefiningOp();
-  if (isa_and_nonnull<Op>(computeOp)) {
+  if (computeOp && (isa<Ops>(computeOp) || ...)) {
    return (computeOp->getOperand(0) == block->getArgument(0) &&
            computeOp->getOperand(1) == block->getArgument(1)) ||
           (computeOp->getOperand(0) == block->getArgument(1) &&
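
`isBlockSingleOp` now takes a parameter pack, and the fold expression `(isa<Ops>(computeOp) || ...)` matches any of the listed op types, so one call can recognize several commutative compute ops. A usage sketch with assumed arith ops (the helper's namespace is assumed to be `mlir`):

```cpp
#include "byteir/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

// True if `block` is just "%r = arith.addf/addi %a, %b" followed by a
// terminator returning %r, in either operand order, as the template checks.
bool isAddReduction(mlir::Block *block) {
  return mlir::isBlockSingleOp<mlir::arith::AddFOp, mlir::arith::AddIOp>(block);
}
```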
17 changes: 15 additions & 2 deletions compiler/lib/Analysis/ShapeAnalysis.cpp
@@ -363,8 +363,21 @@ void ShapeValueAnalysis::visitOperation(
   LLVM_DEBUG(llvm::dbgs() << "shape value analysis on " << *op << "\n");
   TypeSwitch<Operation *>(op)
       .Case<shape::ShapeOfOp>([&](Operation *op) {
-        auto shapeLattice = getOrCreate<ShapeLattice>(op->getOperand(0));
-        visitOperation(op, operands, {shapeLattice}, results);
+        auto shapeLattice = getOrCreateFor<ShapeLattice>(op, op->getOperand(0));
+        auto shapeKnowledge = shapeLattice->getValue();
+        if (!shapeKnowledge.isUninitialized() && shapeKnowledge) {
+          if (auto shapedType =
+                  llvm::dyn_cast<ShapedType>(shapeKnowledge.getType())) {
+            if (shapedType.hasStaticShape()) {
+              ShapeValueLattice *result = results[0];
+              Builder builder(op->getContext());
+              auto staticShape =
+                  builder.getIndexTensorAttr(shapedType.getShape());
+              propagateIfChanged(result, result->join(dataflow::ConstantValue(
+                                             staticShape, op->getDialect())));
+            }
+          }
+        }
       })
       .Case<tensor::DimOp>([&](Operation *op) {
         SmallVector<const ShapeLattice *> shapeLattices(op->getNumOperands(),
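
The new `shape::ShapeOfOp` case queries the shape lattice of the operand and, when the operand type is fully static, joins a constant index tensor into the shape-value lattice. The attribute construction uses the standard builder API; a tiny sketch of just that piece:

```cpp
#include "mlir/IR/Builders.h"

// For a value of type tensor<4x8xf32>, the analysis records the constant
// shape [4, 8] as a dense index tensor attribute, via the same builder call
// used in the hunk above.
mlir::DenseIntElementsAttr makeStaticShapeAttr(mlir::MLIRContext *ctx,
                                               llvm::ArrayRef<int64_t> shape) {
  mlir::Builder builder(ctx);
  return builder.getIndexTensorAttr(shape);
}
```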