
torch.prims.convert_element_type to linalg bf16 to f16 fail #3962

Open
AmosLewis opened this issue Jan 16, 2025 · 3 comments

AmosLewis (Collaborator) commented Jan 16, 2025

'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible

This error is from the llama3_8b_fp8 model.
Small reproducer input IR, convert.torch.mlir:

func.func @convert(%652: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
  %int5 = torch.constant.int 5
  %0 = torch.prims.convert_element_type %652, %int5 : !torch.vtensor<[1,?,32,128],bf16>, !torch.int -> !torch.vtensor<[1,?,32,128],f16>
  return %0 : !torch.vtensor<[1,?,32,128],f16>
}

torch-mlir-opt --torch-decompose-complex-ops --cse --canonicalize convert.torch.mlir > todtype.torch.mlir

module {
  func.func @convert(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
    %int5 = torch.constant.int 5
    %false = torch.constant.bool false
    %none = torch.constant.none
    %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
    return %0 : !torch.vtensor<[1,?,32,128],f16>
  }
}

torch-mlir-opt --convert-torch-to-linalg todtype.torch.mlir

'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (!torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16>, sym_name = "convert"}> ({
^bb0(%arg0: !torch.vtensor<[1,?,32,128],bf16>):
  %0 = "builtin.unrealized_conversion_cast"(%arg0) : (!torch.vtensor<[1,?,32,128],bf16>) -> tensor<1x?x32x128xbf16>
  %1 = "torch.constant.int"() <{value = 5 : i64}> : () -> !torch.int
  %2 = "builtin.unrealized_conversion_cast"(%1) : (!torch.int) -> i64
  %3 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
  %4 = "builtin.unrealized_conversion_cast"(%3) : (!torch.bool) -> i1
  %5 = "torch.constant.none"() : () -> !torch.none
  %6 = "arith.constant"() <{value = 1 : index}> : () -> index
  %7 = "arith.constant"() <{value = 1 : index}> : () -> index
  %8 = "tensor.dim"(%0, %7) : (tensor<1x?x32x128xbf16>, index) -> index
  %9 = "arith.constant"() <{value = 2 : index}> : () -> index
  %10 = "arith.constant"() <{value = 32 : index}> : () -> index
  %11 = "arith.constant"() <{value = 3 : index}> : () -> index
  %12 = "arith.constant"() <{value = 128 : index}> : () -> index
  %13 = "tensor.empty"(%8) : (index) -> tensor<1x?x32x128xf16>
  %14 = "linalg.generic"(%0, %13) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: bf16, %arg2: f16):
    %17 = "arith.extf"(%arg1) : (bf16) -> f16
    "linalg.yield"(%17) : (f16) -> ()
  }) : (tensor<1x?x32x128xbf16>, tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
  %15 = "tensor.cast"(%14) : (tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
  %16 = "torch.aten.to.dtype"(%arg0, %1, %3, %3, %5) : (!torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none) -> !torch.vtensor<[1,?,32,128],f16>
  "func.return"(%16) : (!torch.vtensor<[1,?,32,128],f16>) -> ()
}) : () -> ()
AmosLewis self-assigned this Jan 16, 2025

AmosLewis (Collaborator, Author) commented Jan 16, 2025

The buggy code is here: https://github.com/llvm/torch-mlir/blob/09af3b6030d8d0c0ee8a80840734224d5c4b82a3/lib/Conversion/Utils/Utils.cpp#L337C1-L342C6

if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
  if (scalarFloat.getWidth() > dtypeFloat.getWidth())
    return b.create<arith::TruncFOp>(loc, dtype, scalar);
  // Only scalarFloat width < dtypeFloat width can reach here.
  return b.create<arith::ExtFOp>(loc, dtype, scalar);
}

When the input scalar float is bf16 and the target dtype is f16, their widths are equal, and the arith.extf op does not support this kind of cast. I also tried arith.truncf, and it failed as well. I did not find any other arith op for float-type casts. Does this mean the arith dialect does not support converting bf16 to f16?
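
For reference, a minimal sketch of why the width check in that helper cannot distinguish these two types (getBF16Type/getF16Type/getWidth are existing mlir::Builder and mlir::FloatType APIs; the function name here is just for illustration):

#include <cassert>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Both element types report a width of 16, so the helper above falls
// through to arith.extf, which the verifier rejects because
// bf16 -> f16 is not a widening cast.
void showWidths(mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);
  mlir::FloatType bf16 = b.getBF16Type(); // 1 sign, 8 exponent, 7 mantissa bits
  mlir::FloatType f16 = b.getF16Type();   // 1 sign, 5 exponent, 10 mantissa bits
  assert(bf16.getWidth() == 16 && f16.getWidth() == 16);
}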

AmosLewis (Collaborator, Author) commented

arith.bitcast might work when the widths are equal; its documentation says "Bitcast between values of equal bit width."
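
One caveat (a sketch reusing b, loc, and scalar from the Utils.cpp snippet above): arith.bitcast would type-check, since both types are 16 bits wide, but it only reinterprets the bit pattern, and bf16 and f16 lay out their bits differently, so the numeric value would change.

// Verifies (equal bit width), but numerically wrong: the bf16 bits
// (1 sign / 8 exponent / 7 mantissa) get reinterpreted as f16 bits
// (1 sign / 5 exponent / 10 mantissa) rather than converted.
Value reinterpreted = b.create<arith::BitcastOp>(loc, b.getF16Type(), scalar);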

stellaraccident (Collaborator) commented

I believe that the canonical way to convert between f16 <-> bf16 is to first upcast to f32. It's not really a general thing; it's specific to those two types.
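
Following that suggestion, a minimal sketch of how the Utils.cpp helper could route equal-width float casts through f32 (an illustration of the idea, not necessarily the exact patch in the commit below; scalarFloat, dtypeFloat, b, loc, dtype, and scalar are from the snippet quoted earlier):

if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
  if (scalarFloat.getWidth() > dtypeFloat.getWidth())
    return b.create<arith::TruncFOp>(loc, dtype, scalar);
  if (scalarFloat.getWidth() < dtypeFloat.getWidth())
    return b.create<arith::ExtFOp>(loc, dtype, scalar);
  // Equal width but different formats (e.g. bf16 <-> f16): arith.extf
  // and arith.truncf each require a strict width change, so round-trip
  // through f32 instead.
  Value asF32 = b.create<arith::ExtFOp>(loc, b.getF32Type(), scalar);
  return b.create<arith::TruncFOp>(loc, dtype, asF32);
}

In MLIR terms this emits arith.extf bf16 to f32 followed by arith.truncf f32 to f16, both of which pass the verifier's widening/narrowing checks.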

zjgarvey pushed a commit that referenced this issue Jan 17, 2025
To fix issue #3962 :
'arith.extf' op operand type 'bf16' and result type 'f16' are cast
incompatible