-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Fold umax(nuw_mul(x, C0), x + 1) into (x == 0 ? 1 : nuw_mul(x, C0)) #123468
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms Author: None (Ruhung) ChangesThis PR introduces the following transformations:
Fixes : #122388 Full diff: https://github.com/llvm/llvm-project/pull/123468.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c55c40c88bc845..0db7d8818fbd0b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1847,6 +1847,37 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
}
}
+ // If C0 is not 0:
+ // umax(nuw_shl(x, C0), x + 1) -> x == 0 ? 1 : nuw_shl(x, C0)
+ // If C0 is not 0 or 1:
+ // umax(nuw_mul(x, C0), x + 1) -> x == 0 ? 1 : nuw_mul(x, C0)
+ ConstantInt *C0;
+ bool isShl = false;
+ BinaryOperator *Op = nullptr;
+ auto matchShiftOrMul = [&](Value *I) {
+ if (match(I, m_OneUse(m_NUWShl(m_Value(X), m_ConstantInt(C0))))) {
+ isShl = true;
+ return true;
+ } else if (match(I, m_OneUse(m_NUWMul(m_Value(X), m_ConstantInt(C0)))) &&
+ C0 && !C0->isOne()) {
+ isShl = false;
+ return true;
+ }
+ return false;
+ };
+ if (((matchShiftOrMul(I0) &&
+ match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) ||
+ (matchShiftOrMul(I1) &&
+ match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) &&
+ C0 && !C0->isZero()) {
+ Op = isShl ? BinaryOperator::CreateNUWShl(X, C0)
+ : BinaryOperator::CreateNUWMul(X, C0);
+ Builder.Insert(Op);
+ Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(X->getType(), 0));
+ Value *NewSelect =
+ Builder.CreateSelect(Cmp, ConstantInt::get(X->getType(), 1), Op);
+ return replaceInstUsesWith(*II, NewSelect);
+ }
// If both operands of unsigned min/max are sign-extended, it is still ok
// to narrow the operation.
[[fallthrough]];
diff --git a/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll b/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll
new file mode 100644
index 00000000000000..86ce5dc06ce031
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll
@@ -0,0 +1,300 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+; When C0 is neither 0 nor 1:
+; umax(nuw_mul(x, C0), x + 1) is optimized to:
+; x == 0 ? 1 : nuw_mul(x, C0)
+; When C0 is not 0:
+; umax(nuw_shl(x, C0), x + 1) is optimized to:
+; x == 0 ? 1 : nuw_shl(x, C0)
+
+; Positive Test Cases for `shl`
+
+define i64 @test_shl_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_shl_by_5(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_5(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 5
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 5
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+; Commuted Test Cases for `shl`
+
+define i64 @test_shl_umax_commuted(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_umax_commuted(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[SHL]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %x1, i64 %shl)
+ ret i64 %max
+}
+
+; Negative Test Cases for `shl`
+
+define i64 @test_shl_by_zero(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_zero(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 0
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_shl_add_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_add_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 2
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 2
+ %shl = shl nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_shl_without_nuw(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_without_nuw(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[SHL:%.*]] = shl i64 [[X]], 2
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+; Multi-use Test Cases for `shl`
+declare void @use(i64)
+
+define i64 @test_shl_multi_use_add(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_add(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: call void @use(i64 [[X1]])
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 3
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ call void @use(i64 %x1)
+ %shl = shl nuw i64 %x, 3
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_shl_multi_use_shl(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_shl(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT: call void @use(i64 [[SHL]])
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 2
+ call void @use(i64 %shl)
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_shl_multi_use_max(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_max(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 3
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT: call void @use(i64 [[MAX]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %shl = shl nuw i64 %x, 3
+ %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+ call void @use(i64 %max)
+ ret i64 %max
+}
+
+; Positive Test Cases for `mul`
+
+define i64 @test_mul_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_by_5(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_5(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 5
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[MUL]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 5
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+; Commuted Test Cases for `mul`
+
+define i64 @test_mul_max_commuted(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_max_commuted(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[MUL]]
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %x1, i64 %mul)
+ ret i64 %max
+}
+
+; Negative Test Cases for `mul`
+
+define i64 @test_mul_by_zero(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_zero(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: ret i64 [[X1]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 0
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_by_1(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_1(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 1
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_add_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_add_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 2
+; CHECK-NEXT: [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 2
+ %mul = mul nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_without_nuw(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_without_nuw(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[MUL:%.*]] = shl i64 [[X]], 1
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+; Multi-use Test Cases for `mul`
+
+define i64 @test_mul_multi_use_add(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_add(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: call void @use(i64 [[X1]])
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ call void @use(i64 %x1)
+ %mul = mul nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_multi_use_mul(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_mul(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT: [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: call void @use(i64 [[MUL]])
+; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 2
+ call void @use(i64 %mul)
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ ret i64 %max
+}
+
+define i64 @test_mul_multi_use_max(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_max(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT: call void @use(i64 [[MAX]])
+; CHECK-NEXT: ret i64 [[MAX]]
+;
+ %x1 = add i64 %x, 1
+ %mul = mul nuw i64 %x, 2
+ %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+ call void @use(i64 %max)
+ ret i64 %max
+}
|
match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) && | ||
C0 && !C0->isZero()) { | ||
Op = isShl ? BinaryOperator::CreateNUWShl(X, C0) | ||
: BinaryOperator::CreateNUWMul(X, C0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't you reuse the original mul/shl? Then you also don't need the one-use restriction. This will also preserve the nsw flag, which is valid per https://alive2.llvm.org/ce/z/BfYZmT. (There should be a test for the nuw nsw combination.)
match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) || | ||
(matchShiftOrMul(I1) && | ||
match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) && | ||
C0 && !C0->isZero()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Checking C0
is isn't nullptr
is redundant here and above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the check.
%mul = mul nuw i64 %x, 2 | ||
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1) | ||
call void @use(i64 %max) | ||
ret i64 %max |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mul
tests need to have non-power of two multipliers, otherwise they are actually going through the shl
pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks!
const APInt *C0; | ||
auto matchShiftOrMul = [&](Value *I) { | ||
return ((match(I, m_OneUse(m_NUWShl(m_Value(X), m_APInt(C0))))) || | ||
(match(I, m_OneUse(m_NUWMul(m_Value(X), m_APInt(C0)))) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The one-use restriction on mul/shl are unnecessary.
(((matchI0 = matchShiftOrMul(I0)) && | ||
match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) || | ||
(matchShiftOrMul(I1) && | ||
match(I0, m_OneUse(m_Add(m_Specific(X), m_One())))))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be cleaner if you structured the fold like moveNotAfterMinMax below (i.e. implement it in a lambda, call with commuted args).
This PR introduces the following transformations:
umax(nuw_shl(x, C0), x + 1) -> x == 0 ? 1 : nuw_shl(x, C0)
umax(nuw_mul(x, C0), x + 1) -> x == 0 ? 1 : nuw_mul(x, C0)
Fixes : #122388
Alive2 proof : https://alive2.llvm.org/ce/z/rkp_8U