[Linalg] Add conversion between bf16 and f16 #3963
Conversation
If others are okay with it we can merge this, but I think for our case we probably want to find what causes this cast to get generated in the first place, because it is guaranteed to cost us in numerics. The datatypes differ substantially in which values they can express, and going from bf16 to fp16 is especially bad: FP16 has a 5-bit exponent, while BF16 has an 8-bit exponent.
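To make the layout difference concrete: both formats are 16 bits wide, but f16 splits them as 1 sign / 5 exponent / 10 mantissa bits while bf16 uses 1 / 8 / 7, so the same bit pattern decodes to two different values, and a bitcast reinterprets bits rather than converting. A quick standalone Python check (the helper names are just for illustration):

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    # bf16 is the high 16 bits of an IEEE f32 (truncating, ignoring rounding)
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def decode_bf16(bits: int) -> float:
    # widen back to f32 by appending 16 zero bits
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def decode_f16(bits: int) -> float:
    # struct's "e" format decodes IEEE binary16
    return struct.unpack("<e", struct.pack("<H", bits))[0]

bits = f32_to_bf16_bits(3.0)   # 0x4040, the bf16 encoding of 3.0
print(decode_bf16(bits))       # 3.0   -- decoded with bf16's 8-bit exponent
print(decode_f16(bits))        # 2.125 -- same bits read with f16's 5-bit exponent
```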
I agree with Dan. arith.bitcast between f16 and bf16 is probably not going to work.
Please add an e2e test that would cover this as a start.
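For reference, here is a sketch of what such an e2e test could look like, modeled on the existing elementwise dtype-conversion tests in torch-mlir (the module and test names are invented; and as noted later in this thread, the refbackend currently rejects bf16 memrefs, so this would need a backend with bf16 support):

```python
import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseToDtypeBF16ToF16Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.bfloat16, True)])
    def forward(self, x):
        # Exercises the bf16 -> f16 cast that this PR lowers to linalg.
        return x.to(torch.float16)


@register_test_case(module_factory=lambda: ElementwiseToDtypeBF16ToF16Module())
def ElementwiseToDtypeBF16ToF16Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4).to(torch.bfloat16))
```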
I'm not sure if this is desirable from a performance standpoint, but you are certainly able to get the conversion to work correctly by going through f32 (arith.extf to f32, then arith.truncf to f16). Here are some examples you can compile and run to see the outputs:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @convert(%arg0: tensor<1xbf16>) -> tensor<1xf16> {
    %0 = tensor.empty() : tensor<1xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xbf16>) outs(%0 : tensor<1xf16>) {
    ^bb0(%in: bf16, %out: f16):
      %2 = arith.extf %in : bf16 to f32
      %3 = arith.truncf %2 : f32 to f16
      linalg.yield %3 : f16
    } -> tensor<1xf16>
    return %1 : tensor<1xf16>
  }
}
```

Running this on a sample input produces the expected values.
But

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @convert(%arg0: tensor<1xbf16>) -> tensor<1xf16> {
    %0 = tensor.empty() : tensor<1xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xbf16>) outs(%0 : tensor<1xf16>) {
    ^bb0(%in: bf16, %out: f16):
      %2 = arith.bitcast %in : bf16 to f16
      linalg.yield %2 : f16
    } -> tensor<1xf16>
    return %1 : tensor<1xf16>
  }
}
```

yields incorrect values, because the bits are reinterpreted rather than converted.
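As a numerics sanity check, the extf-then-truncf path can be emulated in PyTorch (a standalone sketch, not part of this PR). Values within f16's range round-trip as expected, but bf16 values beyond f16's maximum overflow to inf, which is the accuracy cost mentioned above:

```python
import torch

x = torch.tensor([3.0], dtype=torch.bfloat16)
print(x.to(torch.float32).to(torch.float16))    # tensor([3.], dtype=torch.float16)

# bf16 can represent magnitudes up to ~3.4e38, but f16 tops out at 65504,
# so large bf16 values are lost when narrowing through f32 to f16.
big = torch.tensor([1.0e38], dtype=torch.bfloat16)
print(big.to(torch.float32).to(torch.float16))  # tensor([inf], dtype=torch.float16)
```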
@zjgarvey I tried to add an e2e test in torch-mlir, but it failed when lowering linalg to the refbackend: `argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<1x1xbf16>'`. That means bf16 cannot be used in an e2e test.
The created linalg:
I also tried to add an onnx.cast test with bf16 in the SHARK-TestSuite, but it also failed.
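One way the bf16 path could still be exercised despite that refbackend restriction (a hypothetical workaround, not something this PR does): keep the ABI boundary in f32 and perform the conversions inside the module, so unsupported element types never appear in the memref signature.

```python
import torch

class BF16ToF16InsideF32Boundary(torch.nn.Module):
    # Hypothetical test module: the refbackend only ever sees f32 memrefs,
    # while the bf16 -> f16 cast under test happens inside the body.
    def forward(self, x):
        return x.to(torch.bfloat16).to(torch.float16).to(torch.float32)
```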
Force-pushed from a1e685d to d3fe084.
Hey, Chi. I think a small change in the lit test would be preferable, along with renaming the PR to fix the spelling error in "conversion". Otherwise this looks good to me.
Thanks, Chi! LGTM
To fix issue #3962: `'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible`.