diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index 8425a398cc..16d9fd5636 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -167,7 +167,6 @@ bool DeadBranchElimPass::MarkLiveBlocks( } if (simplify) { - modified = true; conditions_to_simplify.push_back({block, live_lab_id}); stack.push_back(GetParentBlock(live_lab_id)); } else { @@ -179,24 +178,29 @@ bool DeadBranchElimPass::MarkLiveBlocks( } } - // Traverse |conditions_to_simplify in reverse order. This is done so that we - // simplify nested constructs before simplifying the constructs that contain - // them. + // Traverse |conditions_to_simplify| in reverse order. This is done so that + // we simplify nested constructs before simplifying the constructs that + // contain them. for (auto b = conditions_to_simplify.rbegin(); b != conditions_to_simplify.rend(); ++b) { - SimplifyBranch(b->first, b->second); + modified |= SimplifyBranch(b->first, b->second); } return modified; } -void DeadBranchElimPass::SimplifyBranch(BasicBlock* block, +bool DeadBranchElimPass::SimplifyBranch(BasicBlock* block, uint32_t live_lab_id) { Instruction* merge_inst = block->GetMergeInst(); Instruction* terminator = block->terminator(); if (merge_inst && merge_inst->opcode() == SpvOpSelectionMerge) { if (merge_inst->NextNode()->opcode() == SpvOpSwitch && SwitchHasNestedBreak(block->id())) { + if (terminator->NumInOperands() == 2) { + // We cannot remove the branch, and it already has a single case, so no + // work to do. + return false; + } // We have to keep the switch because it has a nest break, so we // remove all cases except for the live one. Instruction::OperandList new_operands; @@ -231,6 +235,7 @@ void DeadBranchElimPass::SimplifyBranch(BasicBlock* block, AddBranch(live_lab_id, block); context()->KillInst(terminator); } + return true; } void DeadBranchElimPass::MarkUnreachableStructuredTargets( @@ -643,7 +648,8 @@ bool DeadBranchElimPass::SwitchHasNestedBreak(uint32_t switch_header_id) { if (bb->id() == switch_header_id) { return true; } - return (cfg_analysis->ContainingConstruct(inst) == switch_header_id); + return (cfg_analysis->ContainingConstruct(inst) == switch_header_id && + bb->GetMergeInst() == nullptr); }); } diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h index a50933fdb3..7841bc4705 100644 --- a/source/opt/dead_branch_elim_pass.h +++ b/source/opt/dead_branch_elim_pass.h @@ -159,14 +159,15 @@ class DeadBranchElimPass : public MemPass { std::unordered_set* blocks_with_back_edges); // Returns true if there is a brach to the merge node of the selection - // construct |switch_header_id| that is inside a nested selection construct. + // construct |switch_header_id| that is inside a nested selection construct or + // in the header of the nested selection construct. bool SwitchHasNestedBreak(uint32_t switch_header_id); - // Replaces the terminator of |block| with a branch to |live_lab_id|. The - // merge instruction is deleted or moved as needed to maintain structured - // control flow. Assumes that the StructuredCFGAnalysis is valid for the - // constructs containing |block|. - void SimplifyBranch(BasicBlock* block, uint32_t live_lab_id); + // Return true of the terminator of |block| is successfully replaced with a + // branch to |live_lab_id|. The merge instruction is deleted or moved as + // needed to maintain structured control flow. Assumes that the + // StructuredCFGAnalysis is valid for the constructs containing |block|. + bool SimplifyBranch(BasicBlock* block, uint32_t live_lab_id); }; } // namespace opt diff --git a/test/opt/dead_branch_elim_test.cpp b/test/opt/dead_branch_elim_test.cpp index e612867620..3dcc0f7701 100644 --- a/test/opt/dead_branch_elim_test.cpp +++ b/test/opt/dead_branch_elim_test.cpp @@ -1378,6 +1378,7 @@ OpFunctionEnd SinglePassRunAndMatch(text, true); } + TEST_F(DeadBranchElimTest, LeaveContinueBackedgeExtraBlock) { const std::string text = R"( ; CHECK: OpBranch [[header:%\w+]] @@ -3161,6 +3162,73 @@ OpFunctionEnd SinglePassRunAndCheck(before, after, true, true); } +TEST_F(DeadBranchElimTest, BreakInNestedHeaderWithSingleCase) { + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%8 = OpUndef %bool +%main = OpFunction %void None %4 +%9 = OpLabel +OpSelectionMerge %10 None +OpSwitch %uint_0 %11 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %8 %10 %12 +%12 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(DeadBranchElimTest, BreakInNestedHeaderWithTwoCases) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK-NEXT: OpSwitch %uint_0 [[bb:%\w+\n]] +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%8 = OpUndef %bool +%main = OpFunction %void None %4 +%9 = OpLabel +OpSelectionMerge %10 None +OpSwitch %uint_0 %11 1 %12 +%11 = OpLabel +OpSelectionMerge %13 None +OpBranchConditional %8 %10 %13 +%13 = OpLabel +OpBranch %10 +%12 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // More complex control flow