From 7a9d18c9e46b9ea02196899defa491e38d97954a Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Sun, 7 Apr 2024 10:57:44 +0800
Subject: [PATCH] [compiler] support shape reification for callOp

---
 compiler/include/byteir/Dialect/mhlo/Passes.h |   2 +-
 .../include/byteir/Dialect/mhlo/Passes.td     |  19 --
 compiler/include/byteir/Transforms/Passes.h   |   1 +
 compiler/include/byteir/Transforms/Passes.td  |  19 ++
 .../mhlo => }/Transforms/ShapeReification.h   |   2 +-
 compiler/lib/Analysis/SymbolicShape.cpp       |   2 +-
 compiler/lib/Dialect/mhlo/CMakeLists.txt      |   1 -
 .../lib/Dialect/mhlo/Util/ShapeInferUtil.cpp  | 178 ++++++++++++++++++
 compiler/lib/Pipelines/ByreTensorOpt.cpp      |   1 +
 compiler/lib/Transforms/CMakeLists.txt        |   1 +
 compiler/lib/Transforms/PassDetail.h          |   8 +
 .../mhlo => }/Transforms/ShapeReification.cpp |   4 +-
 .../FuncToByre/func_to_byre_tensor.mlir       |  17 ++
 .../test/Transforms/shapeReification.mlir     |  57 +++++-
 14 files changed, 286 insertions(+), 26 deletions(-)
 rename compiler/include/byteir/{Dialect/mhlo => }/Transforms/ShapeReification.h (94%)
 rename compiler/lib/{Dialect/mhlo => }/Transforms/ShapeReification.cpp (97%)

diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.h b/compiler/include/byteir/Dialect/mhlo/Passes.h
index 351b071df..e1505f367 100644
--- a/compiler/include/byteir/Dialect/mhlo/Passes.h
+++ b/compiler/include/byteir/Dialect/mhlo/Passes.h
@@ -36,7 +36,7 @@
 #include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
 #include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
 #include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h"
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h"
 #include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h"
diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.td b/compiler/include/byteir/Dialect/mhlo/Passes.td
index 7fe03f8f8..58f35033d 100644
--- a/compiler/include/byteir/Dialect/mhlo/Passes.td
+++ b/compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp
   let constructor = "mlir::createRewriteWithConstraintPass()";
 }
 
-//===----------------------------------------------------------------------===//
-// ShapeReification
-//===----------------------------------------------------------------------===//
-
-def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
-  let summary = "Iteratively reify all shape computations.";
-  let description = [{
-    If an operation has a shape reification implementation, that is to say, we
-    know how to express the outputs' shape by it's inputs' shape symbolicly,
-    then a tensor.dim or shape.shape_of on this type of operation could be
-    reified. And shape reification procedure could be handled recursively.
-  }];
-  let constructor = "mlir::createByteIRShapeReificationPass()";
-  let dependentDialects = [
-    "mlir::shape::ShapeDialect",
-    "mlir::tensor::TensorDialect"
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // Static Shape Inference
 //===----------------------------------------------------------------------===//

diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h
index 7a179a34e..a5ec9c79d 100644
--- a/compiler/include/byteir/Transforms/Passes.h
+++ b/compiler/include/byteir/Transforms/Passes.h
@@ -36,6 +36,7 @@
 #include "byteir/Transforms/SetArgShape.h"
 #include "byteir/Transforms/SetSpace.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "byteir/Transforms/TryCatchModulePipeline.h"
 
 namespace mlir {

diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index 97d69c022..17e88c1e6 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -425,4 +425,23 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeReification
+//===----------------------------------------------------------------------===//
+
+def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
+  let summary = "Iteratively reify all shape computations.";
+  let description = [{
+    If an operation has a shape reification implementation, that is to say, if
+    we know how to express the shapes of its outputs symbolically in terms of
+    the shapes of its inputs, then a tensor.dim or shape.shape_of on such an
+    operation can be reified. The reification procedure applies recursively.
+  }];
+  let constructor = "mlir::createByteIRShapeReificationPass()";
+  let dependentDialects = [
+    "mlir::shape::ShapeDialect",
+    "mlir::tensor::TensorDialect"
+  ];
+}
+
 #endif // BYTEIR_TRANSFORMS_PASSES
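As a quick illustration of what this pass does (this IR is not part of the patch; the function name @dim_of_add and the concrete shapes are invented, mirroring the infer_shape_using_dim_op test updated below), a tensor.dim on the result of an elementwise op:

func.func @dim_of_add(%arg0: tensor<?x4xf32>) -> index {
  %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x4xf32>
  return %dim : index
}

is reified by -byteir-shape-reification (followed by -canonicalize) into a query on the operand, because the result of mhlo.add has the same shape as its inputs:

  %dim = tensor.dim %arg0, %c0 : tensor<?x4xf32>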
diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
similarity index 94%
rename from compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
rename to compiler/include/byteir/Transforms/ShapeReification.h
index 19f338f22..7c4cb5043 100644
--- a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -1,6 +1,6 @@
 //===- ShapeReification.h -------------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at

diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index 1f5d9f499..703dec1c4 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -16,7 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "byteir/Analysis/SymbolicShape.h"
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/IRMapping.h"

diff --git a/compiler/lib/Dialect/mhlo/CMakeLists.txt b/compiler/lib/Dialect/mhlo/CMakeLists.txt
index 81667fb71..a6501cf0f 100644
--- a/compiler/lib/Dialect/mhlo/CMakeLists.txt
+++ b/compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
   Transforms/ReduceFusion.cpp
   Transforms/ReshapeGather.cpp
   Transforms/RewriteWithConstraint.cpp
-  Transforms/ShapeReification.cpp
   Transforms/StaticShapeInference.cpp
   Transforms/TrivialFusion.cpp
   Transforms/UnfuseBatchNorm.cpp

diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 0bf8250b5..965adc765 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -17,13 +17,19 @@
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
+#include <queue>
+#include <string>
 
 using namespace mlir;
@@ -177,6 +183,168 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
   return nullptr;
 }
 
+namespace {
+
+SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
+  llvm::DenseSet<Operation *> visitedOp;
+  std::queue<Operation *> opQueue;
+
+  opQueue.push(retOp);
+  while (!opQueue.empty()) {
+    auto frontOp = opQueue.front();
+    opQueue.pop();
+    if (visitedOp.find(frontOp) != visitedOp.end()) {
+      continue;
+    }
+    visitedOp.insert(frontOp);
+    for (Value operand : frontOp->getOperands()) {
+      if (!operand.getDefiningOp()) {
+        continue;
+      }
+      if (Operation *defOp = operand.getDefiningOp()) {
+        opQueue.push(defOp);
+      }
+    }
+  }
+  visitedOp.erase(retOp);
+  return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
+}
+
+bool deduceFromFuncArgShape(Value value) {
+  if (value.isa<BlockArgument>()) {
+    return false;
+  }
+
+  auto defOp = value.getDefiningOp();
+  if (!defOp) {
+    return false;
+  }
+
+  if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
+    return true;
+  }
+
+  if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
+    auto operand = defOp->getOperand(0);
+    if (operand.isa<BlockArgument>()) {
+      return true;
+    }
+    return false;
+  }
+
+  for (Value &&operand : defOp->getOperands()) {
+    if (!deduceFromFuncArgShape(operand)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &reifications) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto callOp = dyn_cast<func::CallOp>(op);
+  if (!callOp) {
+    return failure();
+  }
+
+  ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
+  // Auxiliary builder used to create operations in the shape function; the
+  // original builder may be a rewriter that is only allowed to create
+  // operations inside the specific pattern being rewritten.
+  OpBuilder auxiliaryBuilder(moduleOp);
+  StringRef funcName = callOp.getCallee();
+  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+
+  // Clone funcOp; the clone (newFuncOp) is used to deduce the output shapes.
+  std::string newFuncName = funcName.str() + "_Shape";
+  auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
+  auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
+      funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
+  newFuncOp.setPrivate();
+  IRMapping emptyBvm;
+  funcOp.cloneInto(newFuncOp, emptyBvm);
+
+  // Replace the operands of the returnOp with their corresponding shapes.
+  func::ReturnOp retOp = *newFuncOp.getOps<func::ReturnOp>().begin();
+  if (!retOp) {
+    newFuncOp->erase();
+    return failure();
+  }
+
+  SmallVector<Type> allResultTypes;
+  SmallVector<Value> allResults;
+
+  auxiliaryBuilder.setInsertionPoint(retOp);
+  for (Value &&retTensor : retOp.getOperands()) {
+    auto retShape =
+        auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
+    allResultTypes.emplace_back(retShape.getType());
+    allResults.emplace_back(retShape);
+  }
+
+  // Return the shapes of the tensors originally returned by the function.
+  auto newRetOp =
+      auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
+  auto newFuncType = auxiliaryBuilder.getFunctionType(
+      newFuncOp.getArgumentTypes(), allResultTypes);
+  newFuncOp.setFunctionType(newFuncType);
+  retOp->erase();
+
+  // Reify newFuncOp to get the shape computation for the current callOp.
+  {
+    PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+    pm.addPass(createByteIRShapeReificationPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+
+    if (mlir::failed(pm.run(newFuncOp))) {
+      newFuncOp->erase();
+      return failure();
+    }
+  }
+
+  // Collect all shape computation ops.
+  SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
+
+  // Each returned value must depend only on the shapes of the FuncArgs.
+  for (Value &&ret : newRetOp.getOperands()) {
+    if (!deduceFromFuncArgShape(ret)) {
+      newFuncOp->erase();
+      return failure();
+    }
+  }
+
+  // Map the shape computation ops into the caller and collect reifications.
+  {
+    mlir::computeTopologicalSorting(reificationOps);
+
+    IRMapping bvm;
+    size_t numArg = newFuncOp.getNumArguments();
+    for (size_t i = 0; i < numArg; ++i) {
+      bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
+    }
+
+    builder.setInsertionPoint(callOp);
+
+    for (Operation *oldOp : reificationOps) {
+      builder.clone(*oldOp, bvm);
+    }
+
+    for (Value &&ret : newRetOp.getOperands()) {
+      reifications.push_back(bvm.lookup(ret));
+    }
+  }
+
+  // Remove the auxiliary function.
+  newFuncOp->erase();
+  return success();
+}
+
+} // namespace
+
 LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &reifications) {
   if (!op)
@@ -207,6 +375,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
     }
     if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
       return failure();
+  } else if (auto callOp = dyn_cast<func::CallOp>(op)) {
+    if (failed(reifyCallOp(builder, op, reifications))) {
+      return failure();
+    }
+  } else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
+    for (OpResult &&result : op->getOpResults()) {
+      auto tiedOperand = dpsOp.getTiedOpOperand(result);
+      reifications.push_back(
+          builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
+    }
   } else {
     // Return failure if op doesn't have InferShapedTypeOpInterface and not
     // registered.
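The effect of reifyCallOp above can be sketched on a hypothetical caller/callee pair (the names @callee and @caller and the shapes are invented for this example):

func.func private @callee(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
  %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
  return %0 : tensor<?x4xf32>
}

func.func @caller(%arg0: tensor<?x4xf32>) -> tensor<2xindex> {
  %0 = call @callee(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xf32>
  %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
  return %1 : tensor<2xindex>
}

reifyCallOp clones @callee into a private @callee_Shape whose return values are replaced by shape.shape_of of the original results, runs canonicalize, CSE and byteir-shape-reification on the clone, and gives up unless every returned shape is computed purely from constants and the shapes of the clone's arguments (the deduceFromFuncArgShape check). The surviving shape ops are topologically sorted, cloned in front of the call with the callee arguments remapped to the call operands, and @callee_Shape is erased. After cleanup, the reified shape of %0 is simply:

  %1 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>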
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 5b1f710ad..4d5c2b5c6 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -47,6 +47,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace

diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt
index 9ac510696..3ab4a25ab 100644
--- a/compiler/lib/Transforms/CMakeLists.txt
+++ b/compiler/lib/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRTransforms
   RewriteOpToStdCall.cpp
   SetArgShape.cpp
   SetSpace.cpp
+  ShapeReification.cpp
   Utils.cpp
 
 ADDITIONAL_HEADER_DIRS

diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h
index f0cf6f3fa..05ade22f2 100644
--- a/compiler/lib/Transforms/PassDetail.h
+++ b/compiler/lib/Transforms/PassDetail.h
@@ -51,6 +51,14 @@ namespace scf {
 class SCFDialect;
 } // namespace scf
 
+namespace shape {
+class ShapeDialect;
+} // namespace shape
+
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
 #define GEN_PASS_CLASSES
 #include "byteir/Transforms/Passes.h.inc"

diff --git a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
similarity index 97%
rename from compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
rename to compiler/lib/Transforms/ShapeReification.cpp
index 7b6c1b548..382e2ed10 100644
--- a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -1,6 +1,6 @@
 //===- ShapeReification.cpp -----------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
@@ -15,7 +15,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"

diff --git a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
index 31a468e7d..4821a5915 100644
--- a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
+++ b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
@@ -20,3 +20,20 @@ func.func @test_normal_function_call(%arg0 : tensor<4xf32>) -> tensor<4xf32> att
 }
 // CHECK-LABEL: test_normal_function_call
 // CHECK: call @some_func
+
+
+// -----
+
+func.func private @Unknown0(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> attributes {__byteir_elementwise_fusion__, byre_compute_name = "Unknown0"} {
+  %0 = mhlo.add %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+func.func @forward(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> attributes {__placeholder__byre.entry_point} {
+  %1 = call @Unknown0(%arg1, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: tensor.empty
+// CHECK-NEXT: byre.compute_on_tensor @Unknown0

diff --git a/compiler/test/Transforms/shapeReification.mlir b/compiler/test/Transforms/shapeReification.mlir
index d1b3cd530..157d5e176 100644
--- a/compiler/test/Transforms/shapeReification.mlir
+++ b/compiler/test/Transforms/shapeReification.mlir
@@ -1,4 +1,4 @@
-// RUN: byteir-opt %s -byteir-shape-reification -canonicalize -cse | FileCheck %s
+// RUN: byteir-opt %s --split-input-file -byteir-shape-reification -canonicalize -cse | FileCheck %s
 
 func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape) {
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x2xf32>, tensor<2x4xf32>) -> tensor<?x4xf32>
@@ -26,6 +26,8 @@ func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: te
 // CHECK-DAG: %[[V3:.+]] = shape.value_as_shape %[[C2]] : tensor<1xindex> -> !shape.shape
 // CHECK-DAG: return %[[V2]], %[[V3]], %[[V2]], %[[V2]] : !shape.shape, !shape.shape, !shape.shape, !shape.shape
 
+// -----
+
 // CHECK-LABEL: @infer_shape_using_dim_op
 func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>, %arg2: tensor<4x4xf32>) -> !shape.shape {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
@@ -40,6 +42,8 @@ func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor
 
+// -----
+
 func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?x4xf32>, %arg3: tensor<?x4xf32>) -> tensor<?x4xf32> {
   %0 = "mhlo.custom_call"(%arg0, %arg1, %arg2, %arg3) {call_target_name = "tf.DynamicStitch", has_side_effect = false} : (tensor<?xi32>, tensor<?xi32>, tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
   %c0 = arith.constant 0 : index
@@ -52,6 +56,8 @@ func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: ten
   return %0 : tensor<?x4xf32>
 }
 
+// -----
+
 func.func @gelu(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   %0 = mhlo.custom_call @byteir.gelu(%arg0) {backend_config = "", byteir_attrs = {approximate = "erf"}} : (tensor<?xf32>) -> tensor<?xf32>
   %c0 = arith.constant 0 : index
@@ -62,6 +68,8 @@ func.func @gelu(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   return %0 : tensor<?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @dot_general
 func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<3xindex> {
   %c1 = arith.constant 1 : index
@@ -80,11 +88,14 @@ func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->
   return %3 : tensor<3xindex>
 }
 
+// -----
+
 // TODO: Check this after nested function call is supported
 func.func private @inner_func(%arg0 : tensor<?x4xf32>, %arg1 : tensor<?x4xf32>) -> tensor<?x4xf32> {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   return %0 : tensor<?x4xf32>
 }
+
 func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape.shape, !shape.shape) {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
   %2 = shape.value_as_shape %1 : tensor<2xindex> -> !shape.shape
   %3 = call @inner_func(%arg0, %arg1) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
   %4 = shape.shape_of %3 : tensor<?x4xf32> -> tensor<2xindex>
   %5 = shape.value_as_shape %4 : tensor<2xindex> -> !shape.shape
   return %2, %5 : !shape.shape, !shape.shape
 }
+// CHECK-LABEL: func.func @outer_func
+// CHECK: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK: %[[V1:.*]] = shape.value_as_shape %[[V0]] : tensor<2xindex> -> !shape.shape
+// CHECK: return %[[V1]], %[[V1]] : !shape.shape, !shape.shape
+
+// -----
+
+func.func private @Unknown1(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %2 = mhlo.add %1, %arg1 : tensor<?x20xf32>
+  return %2 : tensor<?x20xf32>
+}
+
+func.func private @Unknown0(%arg0: tensor<?x10xf32>, %arg1: tensor<20xf32>, %arg2: tensor<?x20xf32>, %arg3: tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>) {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %c20 = arith.constant 20 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+  %from_elements = tensor.from_elements %dim, %c20 : tensor<2xindex>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %from_elements) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<20xf32>, tensor<2xindex>) -> tensor<?x20xf32>
+  %2 = mhlo.add %arg2, %1 : tensor<?x20xf32>
+  %3 = call @Unknown1(%arg0, %2) : (tensor<?x10xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
+  %4 = mhlo.maximum %2, %3 : tensor<?x20xf32>
+  return %4, %3 : tensor<?x20xf32>, tensor<?x20xf32>
+}
+
+func.func @forward(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>, %arg2: tensor<20x?xf32>) -> tensor<2xindex> attributes {__placeholder__byre.entry_point} {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = mhlo.constant dense_resource<__elided__> : tensor<20xf32>
+  %2 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %3:2 = call @Unknown0(%arg0, %1, %2, %arg1) : (tensor<?x10xf32>, tensor<20xf32>, tensor<?x20xf32>, tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>)
+  %4 = "mhlo.dot"(%3#0, %arg2) : (tensor<?x20xf32>, tensor<20x?xf32>) -> tensor<?x?xf32>
+  %5 = shape.shape_of %4 : tensor<?x?xf32> -> tensor<2xindex>
+  return %5 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg2, %c1 : tensor<20x?xf32>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM]], %[[DIM0]] : tensor<2xindex>
+// CHECK: return %[[SHAPE]] : tensor<2xindex>
\ No newline at end of file
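For the @forward test above, the auxiliary function that reifyCallOp builds for the call to @Unknown0 would look roughly as follows before it is erased (a reconstruction for illustration; the _Shape naming comes from reifyCallOp, and the body is what canonicalize/CSE/reification are expected to leave behind):

func.func private @Unknown0_Shape(%arg0: tensor<?x10xf32>, %arg1: tensor<20xf32>, %arg2: tensor<?x20xf32>, %arg3: tensor<?x20xf32>) -> (tensor<2xindex>, tensor<2xindex>) {
  %c20 = arith.constant 20 : index
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
  %from_elements = tensor.from_elements %dim, %c20 : tensor<2xindex>
  return %from_elements, %from_elements : tensor<2xindex>, tensor<2xindex>
}

Both returned shapes depend only on the first dimension of %arg0, so they pass the deduceFromFuncArgShape check; this is why the CHECK lines for @forward expect the shape of the final mhlo.dot to collapse to a tensor.from_elements of two tensor.dim ops on the entry arguments.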