
Add multi-function support to FlatIR and stablehlo #171

Merged: merged 4 commits into NVIDIA:main from jhalakp-multi-func on Oct 2, 2024

Conversation

@jhalakpatel (Collaborator) commented on Sep 3, 2024

Addresses issue: #155

Major changes:

  1. Add a wraps_to_flat_ir_to_func decorator that wraps the FlatIR operations defined in a trace operation's to_flat_ir method (see the sketch after this list).
  2. Add a FlatIRFunction class to represent a FlatIR function.
  3. Update integrate_subgraph to iteratively add a function's nested ops and rebind the function's inputs/outputs.
  4. Update the to_mlir_impl method to emit function ops, and add logic to fix up function input and result types after type resolution in op.to_mlir.
  5. Implement function deduplication (illustrated in the "Function deduplication" example below).
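
To make items 1 and 2 concrete, here is a minimal Python sketch of how the decorator and the FlatIRFunction class might fit together. The attribute names and the assumption that to_flat_ir returns the ops it creates are illustrative only, not taken from the tripy sources.

# Minimal sketch only: hypothetical names and signatures, not the actual tripy code.

class FlatIRFunction:
    """A FlatIR function: a named group of ops with explicit inputs and outputs."""

    def __init__(self, name, inputs, outputs, ops):
        self.name = name        # e.g. "BinaryElementwise"
        self.inputs = inputs    # FlatIR tensors supplied by the caller
        self.outputs = outputs  # FlatIR tensors returned to the caller
        self.ops = ops          # nested FlatIR ops, in topological order


def wraps_to_flat_ir_to_func(to_flat_ir):
    """Group the ops produced by a trace op's to_flat_ir into a FlatIRFunction."""

    def wrapper(trace_op, inputs, outputs):
        # Run the original lowering. This sketch assumes it returns the ops it
        # created; the real method may instead append them to a builder.
        ops = to_flat_ir(trace_op, inputs, outputs)
        # Name the function after the trace op so dumps read "function BinaryElementwise(...)".
        return FlatIRFunction(type(trace_op).__name__, inputs, outputs, ops)

    return wrapper

With this grouping in place, integrate_subgraph can walk func.ops whenever the nested ops are needed and rebind func.inputs/func.outputs to the caller's tensors, which is what produces the "t2 = function BinaryElementwise(t0, t1)" call in the After dump below.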

Before:

==== Trace IR ====
t0 = storage(data=[2.0000, 3.0000], shape=(2,), dtype=float32, device=cpu:0)
t1 = storage(data=[1.0000, 1.0000], shape=(2,), dtype=float32, device=cpu:0)
t2 = t0 + t1
outputs:
    t2: [rank=(1), dtype=(float32), loc=(gpu:0)]

==== Flat IR ====
t0: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=<mlir_tensorrt.runtime._mlir_libs._api.MemRefValue object at 0x77c4a7f63cf0>)
t1: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=<mlir_tensorrt.runtime._mlir_libs._api.MemRefValue object at 0x77c49053fc30>)
t_inter8: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t0, dim=0)
t_inter9: [rank=(1), shape=([1]), dtype=(int32), loc=(<class 'tripy.common.device.device'>)] = ConstantOp(data=[1])
t_inter7: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter8, t_inter9)
t_inter6: [rank=(1), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter7, dim=0)
t_inter5: [rank=(1), shape=([1]), dtype=(bool), loc=(gpu:0)] = CompareOp.EQ(t_inter6, t_inter9, compare_direction='EQ')
t_inter13: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t1, dim=0)
t_inter12: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter13, t_inter9)
t_inter11: [rank=(1), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter12, dim=0)
t_inter4: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = SelectOp(t_inter5, t_inter11, t_inter6)
t_inter3: [rank=(1), dtype=(float32), loc=(gpu:0)] = DynamicBroadcastOp(t0, t_inter4, broadcast_dim=[0])
t_inter15: [rank=(1), dtype=(float32), loc=(gpu:0)] = DynamicBroadcastOp(t1, t_inter4, broadcast_dim=[0])
t2: [rank=(1), dtype=(float32), loc=(gpu:0)] = AddOp(t_inter3, t_inter15)
outputs:
    t2: [rank=(1), dtype=(float32), loc=(gpu:0)]

==== MLIR ====
module @outs_t2_1 {
  func.func @main() -> tensor<?xf32> {
    %cst = stablehlo.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2xf32>
    %c = stablehlo.constant dense<2> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<2> : tensor<1xi32>
    %0 = stablehlo.compare  EQ, %c_2, %c_1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %c_3 = stablehlo.constant dense<2> : tensor<i32>
    %c_4 = stablehlo.constant dense<2> : tensor<1xi32>
    %1 = stablehlo.select %0, %c_4, %c_2 : tensor<1xi1>, tensor<1xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [0] : (tensor<2xf32>, tensor<1xi32>) -> tensor<?xf32>
    %3 = stablehlo.dynamic_broadcast_in_dim %cst_0, %1, dims = [0] : (tensor<2xf32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.add %2, %3 : tensor<?xf32>
    return %4 : tensor<?xf32>
  }
}

After:

==== Trace IR ====
t0 = storage(data=[2.0000, 3.0000], shape=(2,), dtype=float32, device=cpu:0)
t1 = storage(data=[1.0000, 1.0000], shape=(2,), dtype=float32, device=cpu:0)
t2 = t0 + t1
outputs:
    t2: [rank=(1), dtype=(float32), loc=(gpu:0)]

==== Flat IR ====
function BinaryElementwise(
    t_inter3: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)]
    t_inter4: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)]
) -> (
    t_inter5: [rank=(1), dtype=(float32), loc=(gpu:0)]
) {
    t_inter11: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t_inter3, dim=0)
    t_inter12: [rank=(1), shape=([1]), dtype=(int32), loc=(<class 'tripy.common.device.device'>)] = ConstantOp(data=[1])
    t_inter10: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter11, t_inter12)
    t_inter9: [rank=(1), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter10, dim=0)
    t_inter13: [rank=(1), shape=([1]), dtype=(int32), loc=(<class 'tripy.common.device.device'>)] = ConstantOp(data=[1])
    t_inter8: [rank=(1), shape=([1]), dtype=(bool), loc=(gpu:0)] = CompareOp.EQ(t_inter9, t_inter13, compare_direction='EQ')
    t_inter16: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t_inter4, dim=0)
    t_inter17: [rank=(1), shape=([1]), dtype=(int32), loc=(<class 'tripy.common.device.device'>)] = ConstantOp(data=[1])
    t_inter15: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter16, t_inter17)
    t_inter14: [rank=(1), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter15, dim=0)
    t_inter7: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = SelectOp(t_inter8, t_inter14, t_inter9)
    t_inter6: [rank=(1), dtype=(float32), loc=(gpu:0)] = DynamicBroadcastOp(t_inter3, t_inter7, broadcast_dim=[0])
    t_inter18: [rank=(1), dtype=(float32), loc=(gpu:0)] = DynamicBroadcastOp(t_inter4, t_inter7, broadcast_dim=[0])
    t_inter5: [rank=(1), dtype=(float32), loc=(gpu:0)] = AddOp(t_inter6, t_inter18)
    return t_inter5
}

Main Function:
inputs:
t0: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=<mlir_tensorrt.runtime._mlir_libs._api.MemRefValue object at 0x75254079f730>)
t1: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=<mlir_tensorrt.runtime._mlir_libs._api.MemRefValue object at 0x75254073cff0>)
t2 = function BinaryElementwise(t0, t1)
outputs:
    t2: [rank=(1), dtype=(float32), loc=(gpu:0)]

==== MLIR ====
module @outs_t2_1 {
  func.func @main() -> tensor<?xf32> {
    %cst = stablehlo.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2xf32>
    %0 = call @BinaryElementwise(%cst, %cst_0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }
  func.func private @BinaryElementwise(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<?xf32> {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %c_0 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_1 = stablehlo.constant dense<2> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<1> : tensor<1xi32>
    %0 = stablehlo.compare  EQ, %c_1, %c_2 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %c_3 = stablehlo.constant dense<2> : tensor<i32>
    %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_5 = stablehlo.constant dense<2> : tensor<1xi32>
    %1 = stablehlo.select %0, %c_5, %c_1 : tensor<1xi1>, tensor<1xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %arg0, %1, dims = [0] : (tensor<2xf32>, tensor<1xi32>) -> tensor<?xf32>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %1, dims = [0] : (tensor<2xf32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.add %2, %3 : tensor<?xf32>
    return %4 : tensor<?xf32>
  }
}

Function deduplication: Slice is called 10 times in the main function below, yet only a single Slice function is emitted.
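
A minimal sketch of how that deduplication could work, reusing the FlatIRFunction sketch above; the structural key and the renaming scheme (Slice, Slice_1, ...) are assumptions for illustration, not tripy's actual logic.

def structural_key(func):
    # Ignore tensor names; keep op types plus input/output ranks and dtypes.
    io_sig = tuple((t.rank, str(t.dtype)) for t in (*func.inputs, *func.outputs))
    op_sig = tuple(type(op).__name__ for op in func.ops)
    return (io_sig, op_sig)


def deduplicate(functions):
    """Map every call site to one canonical function per structural key."""
    canonical = {}     # structural key -> canonical FlatIRFunction
    variants = {}      # base name -> number of structurally distinct variants so far
    call_targets = []  # i-th entry: the function the i-th call site should target
    for func in functions:
        key = structural_key(func)
        if key not in canonical:
            base = func.name
            count = variants.get(base, 0)
            # The first variant keeps the base name; later ones get _1, _2, ...
            func.name = base if count == 0 else f"{base}_{count}"
            variants[base] = count + 1
            canonical[key] = func
        call_targets.append(canonical[key])
    return call_targets

Under these assumptions, slices whose structural keys differ (for example, different result shapes) would be emitted as the separate Slice_1 and Slice_2 functions seen in the dump.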

FlatIR:

Main Function:
inputs:
t0: [rank=(2), shape=((2, 3)), dtype=(float32), loc=(gpu:0)] = ConstantOp()
t_inter4: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t0, dim=0)
t_inter5: [rank=(1), shape=([1]), dtype=(int32), loc=(<class 'tripy.common.device.device'>)] = ConstantOp(data=[1])
t_inter3: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter4, t_inter5)
t_inter7: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t0, dim=1)
t_inter6: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter7, t_inter5)
t1: [rank=(1), shape=([2]), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter3, t_inter6, dim=0)
t3: [rank=(2), dtype=(int32), loc=(gpu:0)] = DynamicIotaOp(t1, dim=1)
t4: [rank=(1), dtype=(int32), loc=(gpu:0)] = function ArgMinMax(t0, t3)
t_inter18: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t4, dim=0)
t_inter17: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter18, t_inter5)
t5: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter17, dim=0)
t_inter22: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t5, dim=0)
t_inter21: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter22, t_inter5)
t7: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter21, dim=0)
t11: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = ConstantOp(data=0)
t12: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = ConstantOp(data=1)
t14: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t7, t11, t12, t12)
t16: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t14)
t19: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison(t16, t11)
t25: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t7, t11, t12, t12)
t27: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t25)
t29: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = ConstantOp(data=[0])
t30: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Where(t19, t27, t29)
t45: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t7, t11, t12, t12)
t47: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t45)
t50: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison(t47, t12)
t56: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t7, t11, t12, t12)
t58: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t56)
t61: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Where(t50, t58, t_inter5)
t63: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = function Slice_1(t5, t30, t61, t12)
t_inter142: [rank=(0), shape=(()), dtype=(int32), loc=(gpu:0)] = GetDimensionSizeOp(t5, dim=0)
t_inter141: [rank=(1), shape=((1,)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t_inter142, t_inter5)
t66: [rank=(1), shape=([1]), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t_inter141, dim=0)
t73: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t75: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t73)
t78: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison(t75, t12)
t84: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t86: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t84)
t89: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Where(t78, t86, t_inter5)
t95: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t97: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t95)
t100: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison_1(t97, t11)
t106: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t108: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t106)
t110: [rank=(0), dtype=(int32), loc=(gpu:0)] = function BinaryElementwise(t108, t97)
t112: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Where_1(t100, t97, t110)
t114: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison(t112, t11)
t121: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t123: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t121)
t125: [rank=(0), dtype=(bool), loc=(gpu:0)] = function Comparison_2(t123, t112)
t131: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Slice(t66, t11, t12, t12)
t133: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Squeeze(t131)
t135: [rank=(0), dtype=(int32), loc=(gpu:0)] = function Where_1(t125, t133, t112)
t136: [rank=(1), dtype=(int32), loc=(gpu:0)] = function Where_2(t114, t29, t135)
t138: [rank=(1), shape=([0]), dtype=(int32), loc=(gpu:0)] = function Slice_2(t5, t89, t136, t12)
t140: [rank=(1), shape=([2]), dtype=(int32), loc=(gpu:0)] = ConcatenateOp(t63, t_inter5, t138, dim=0)
t142: [rank=(2), dtype=(int32), loc=(gpu:0)] = DynamicBroadcastOp(t4, t140, broadcast_dim=[0])
outputs:
    t142: [rank=(2), dtype=(int32), loc=(gpu:0)]

StableHLO:

module @outs_t142_1 {
  func.func @main() -> tensor<?x?xi32> {
    %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]]> : tensor<2x3xf32>
    %c = stablehlo.constant dense<2> : tensor<i32>
    %c_0 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_1 = stablehlo.constant dense<2> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<3> : tensor<i32>
    %c_3 = stablehlo.constant dense<3> : tensor<1xi32>
    %0 = stablehlo.concatenate %c_1, %c_3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %1 = stablehlo.dynamic_iota %0, dim = 1 : (tensor<2xi32>) -> tensor<?x?xi32>
    %2 = stablehlo.convert %cst : (tensor<2x3xf32>) -> tensor<?x?xf32>
    %3 = call @ArgMinMax(%2, %1) : (tensor<?x?xf32>, tensor<?x?xi32>) -> tensor<?xi32>
    %4 = stablehlo.get_dimension_size %3, dim = 0 : (tensor<?xi32>) -> tensor<i32>
    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
    %c_4 = stablehlo.constant dense<1> : tensor<i32>
    %c_5 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_6 = stablehlo.constant dense<0> : tensor<i32>
    %c_7 = stablehlo.constant dense<1> : tensor<i32>
    %6 = stablehlo.convert %c_5 : (tensor<1xi32>) -> tensor<?xi32>
    %7 = call @Slice(%6, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %8 = call @Squeeze(%7) : (tensor<?xi32>) -> tensor<i32>
    %9 = call @Comparison(%8, %c_6) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %10 = stablehlo.convert %c_5 : (tensor<1xi32>) -> tensor<?xi32>
    %11 = call @Slice(%10, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %12 = call @Squeeze(%11) : (tensor<?xi32>) -> tensor<i32>
    %c_8 = stablehlo.constant dense<0> : tensor<1xi32>
    %13 = stablehlo.convert %c_8 : (tensor<1xi32>) -> tensor<?xi32>
    %14 = call @Where(%9, %12, %13) : (tensor<i1>, tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
    %15 = stablehlo.convert %c_5 : (tensor<1xi32>) -> tensor<?xi32>
    %16 = call @Slice(%15, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %17 = call @Squeeze(%16) : (tensor<?xi32>) -> tensor<i32>
    %18 = call @Comparison(%17, %c_7) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %19 = stablehlo.convert %c_5 : (tensor<1xi32>) -> tensor<?xi32>
    %20 = call @Slice(%19, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %21 = call @Squeeze(%20) : (tensor<?xi32>) -> tensor<i32>
    %22 = stablehlo.convert %c_0 : (tensor<1xi32>) -> tensor<?xi32>
    %23 = call @Where(%18, %21, %22) : (tensor<i1>, tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
    %24 = stablehlo.convert %5 : (tensor<1xi32>) -> tensor<?xi32>
    %25 = call @Slice_1(%24, %14, %23, %c_7) : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<i32>) -> tensor<1xi32>
    %c_9 = stablehlo.constant dense<1> : tensor<i32>
    %c_10 = stablehlo.constant dense<1> : tensor<1xi32>
    %26 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %27 = call @Slice(%26, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %28 = call @Squeeze(%27) : (tensor<?xi32>) -> tensor<i32>
    %29 = call @Comparison(%28, %c_7) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %30 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %31 = call @Slice(%30, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %32 = call @Squeeze(%31) : (tensor<?xi32>) -> tensor<i32>
    %33 = stablehlo.convert %c_0 : (tensor<1xi32>) -> tensor<?xi32>
    %34 = call @Where(%29, %32, %33) : (tensor<i1>, tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
    %35 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %36 = call @Slice(%35, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %37 = call @Squeeze(%36) : (tensor<?xi32>) -> tensor<i32>
    %38 = call @Comparison_1(%37, %c_6) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %39 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %40 = call @Slice(%39, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %41 = call @Squeeze(%40) : (tensor<?xi32>) -> tensor<i32>
    %42 = call @BinaryElementwise(%41, %37) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %43 = call @Where_1(%38, %37, %42) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %44 = call @Comparison(%43, %c_6) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %45 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %46 = call @Slice(%45, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %47 = call @Squeeze(%46) : (tensor<?xi32>) -> tensor<i32>
    %48 = call @Comparison_2(%47, %43) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %49 = stablehlo.convert %c_10 : (tensor<1xi32>) -> tensor<?xi32>
    %50 = call @Slice(%49, %c_6, %c_7, %c_7) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %51 = call @Squeeze(%50) : (tensor<?xi32>) -> tensor<i32>
    %52 = call @Where_1(%48, %51, %43) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %53 = stablehlo.convert %c_8 : (tensor<1xi32>) -> tensor<?xi32>
    %54 = call @Where_2(%44, %53, %52) : (tensor<i1>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
    %55 = stablehlo.convert %5 : (tensor<1xi32>) -> tensor<?xi32>
    %56 = call @Slice_2(%55, %34, %54, %c_7) : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<i32>) -> tensor<0xi32>
    %57 = stablehlo.concatenate %25, %c_0, %56, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<0xi32>) -> tensor<2xi32>
    %58 = stablehlo.dynamic_broadcast_in_dim %3, %57, dims = [0] : (tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xi32>
    return %58 : tensor<?x?xi32>
  }

@jhalakpatel force-pushed the jhalakp-multi-func branch 4 times, most recently from 383a933 to 8bb43ab, on September 3, 2024 03:40
@jhalakpatel force-pushed the jhalakp-multi-func branch 11 times, most recently from f1181c4 to 07c164e, on September 9, 2024 20:01
@jhalakpatel force-pushed the jhalakp-multi-func branch 3 times, most recently from 811c458 to b944abf, on September 10, 2024 01:21
@jhalakpatel changed the title from "Add flat ir function" to "Add mulit function support to FlatIR and stablehlo" on Sep 10, 2024
@jhalakpatel marked this pull request as ready for review on September 10, 2024 01:44
Review thread on tripy/tests/helper.py (outdated, resolved)
@pranavm-nvidia changed the title from "Add mulit function support to FlatIR and stablehlo" to "Add multi-function support to FlatIR and stablehlo" on Sep 11, 2024
@jhalakpatel force-pushed the jhalakp-multi-func branch 11 times, most recently from 2d612b8 to 9d36ff8, on September 27, 2024 19:35
Review thread on tripy/tripy/backend/mlir/executor.py (resolved)
Review thread on tripy/tripy/backend/mlir/memref.py (resolved)
Review thread on tripy/tripy/flat_ir/flat_ir.py (outdated, resolved)
@jhalakpatel force-pushed the jhalakp-multi-func branch 5 times, most recently from bbb44fa to 353fd30, on October 2, 2024 21:25
@jhalakpatel merged commit 59b9536 into NVIDIA:main on Oct 2, 2024
1 check passed