diff --git a/tests/unit/compiler/venom/test_branch_optimizer.py b/tests/unit/compiler/venom/test_branch_optimizer.py index 82dff4777d..4c46127e1d 100644 --- a/tests/unit/compiler/venom/test_branch_optimizer.py +++ b/tests/unit/compiler/venom/test_branch_optimizer.py @@ -22,9 +22,9 @@ def test_simple_jump_case(): jnz_input = bb.append_instruction("iszero", op3) bb.append_instruction("jnz", jnz_input, br1.label, br2.label) - br1.append_instruction("add", op3, 10) + br1.append_instruction("add", op3, p1) br1.append_instruction("stop") - br2.append_instruction("add", op3, p1) + br2.append_instruction("add", op3, 10) br2.append_instruction("stop") term_inst = bb.instructions[-1] @@ -47,6 +47,6 @@ def test_simple_jump_case(): # Test that the dfg is updated correctly dfg = ac.request_analysis(DFGAnalysis) - assert dfg is old_dfg, "DFG should not be invalidated by BranchOptimizationPass" + assert dfg is not old_dfg, "DFG should be invalidated by BranchOptimizationPass" assert term_inst in dfg.get_uses(op3), "jnz not using the new condition" assert term_inst not in dfg.get_uses(jnz_input), "jnz still using the old condition" diff --git a/vyper/venom/passes/branch_optimization.py b/vyper/venom/passes/branch_optimization.py index d5b0ed9809..920dc5e431 100644 --- a/vyper/venom/passes/branch_optimization.py +++ b/vyper/venom/passes/branch_optimization.py @@ -1,4 +1,5 @@ -from vyper.venom.analysis import DFGAnalysis +from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRInstruction from vyper.venom.passes.base_pass import IRPass @@ -14,17 +15,30 @@ def _optimize_branches(self) -> None: if term_inst.opcode != "jnz": continue - prev_inst = self.dfg.get_producing_instruction(term_inst.operands[0]) - if prev_inst.opcode == "iszero": + fst, snd = bb.cfg_out + + fst_liveness = fst.instructions[0].liveness + snd_liveness = snd.instructions[0].liveness + + cost_a, cost_b = len(fst_liveness), len(snd_liveness) + + cond = term_inst.operands[0] + prev_inst = self.dfg.get_producing_instruction(cond) + if cost_a >= cost_b and prev_inst.opcode == "iszero": new_cond = prev_inst.operands[0] term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] - - # Since the DFG update is simple we do in place to avoid invalidating the DFG - # and having to recompute it (which is expensive(er)) - self.dfg.remove_use(prev_inst.output, term_inst) - self.dfg.add_use(new_cond, term_inst) + elif cost_a > cost_b: + new_cond = fn.get_next_variable() + inst = IRInstruction("iszero", [term_inst.operands[0]], output=new_cond) + bb.insert_instruction(inst, index=-1) + term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] def run_pass(self): + self.liveness = self.analyses_cache.request_analysis(LivenessAnalysis) + self.cfg = self.analyses_cache.request_analysis(CFGAnalysis) self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) self._optimize_branches() + + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(CFGAnalysis)