From f10d40a0a25dc6709b8cbd0a6793175434db6472 Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Tue, 22 Mar 2022 01:22:21 -0400
Subject: [PATCH] shader_recompiler/dead_code_elimination: Add
 DeadBranchElimination pass

This adds a pass to eliminate if(false) branches within the shader code
---
 .../ir_opt/dead_code_elimination_pass.cpp     | 71 ++++++++++++++++---
 1 file changed, 62 insertions(+), 9 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/dead_code_elimination_pass.cpp b/src/shader_recompiler/ir_opt/dead_code_elimination_pass.cpp
index 400836301..6c7c7b32d 100644
--- a/src/shader_recompiler/ir_opt/dead_code_elimination_pass.cpp
+++ b/src/shader_recompiler/ir_opt/dead_code_elimination_pass.cpp
@@ -7,19 +7,72 @@
 #include "shader_recompiler/ir_opt/passes.h"
 
 namespace Shader::Optimization {
-
-void DeadCodeEliminationPass(IR::Program& program) {
+namespace {
+template <bool TEST_USES>
+void DeadInstElimination(IR::Block* const block) {
     // We iterate over the instructions in reverse order.
     // This is because removing an instruction reduces the number of uses for earlier instructions.
-    for (IR::Block* const block : program.post_order_blocks) {
-        auto it{block->end()};
-        while (it != block->begin()) {
-            --it;
-            if (!it->HasUses() && !it->MayHaveSideEffects()) {
-                it->Invalidate();
-                it = block->Instructions().erase(it);
+    auto it{block->end()};
+    while (it != block->begin()) {
+        --it;
+        if constexpr (TEST_USES) {
+            if (it->HasUses() || it->MayHaveSideEffects()) {
+                continue;
             }
         }
+        it->Invalidate();
+        it = block->Instructions().erase(it);
+    }
+}
+
+void DeadBranchElimination(IR::Program& program) {
+    const auto begin_it{program.syntax_list.begin()};
+    for (auto node_it = begin_it; node_it != program.syntax_list.end(); ++node_it) {
+        if (node_it->type != IR::AbstractSyntaxNode::Type::If) {
+            continue;
+        }
+        IR::Inst* const cond_ref{node_it->data.if_node.cond.Inst()};
+        const IR::U1 cond{cond_ref->Arg(0)};
+        if (!cond.IsImmediate()) {
+            continue;
+        }
+        if (cond.U1()) {
+            continue;
+        }
+        // False immediate condition. Remove condition ref, erase the entire branch.
+        cond_ref->Invalidate();
+        // Account for nested if-statements within the if(false) branch
+        u32 nested_ifs{1u};
+        while (node_it->type != IR::AbstractSyntaxNode::Type::EndIf || nested_ifs > 0) {
+            node_it = program.syntax_list.erase(node_it);
+            switch (node_it->type) {
+            case IR::AbstractSyntaxNode::Type::If:
+                ++nested_ifs;
+                break;
+            case IR::AbstractSyntaxNode::Type::EndIf:
+                --nested_ifs;
+                break;
+            case IR::AbstractSyntaxNode::Type::Block: {
+                IR::Block* const block{node_it->data.block};
+                DeadInstElimination<false>(block);
+                break;
+            }
+            default:
+                break;
+            }
+        }
+        // Erase EndIf node of the if(false) branch
+        node_it = program.syntax_list.erase(node_it);
+        // Account for loop increment
+        --node_it;
+    }
+}
+} // namespace
+
+void DeadCodeEliminationPass(IR::Program& program) {
+    DeadBranchElimination(program);
+    for (IR::Block* const block : program.post_order_blocks) {
+        DeadInstElimination<true>(block);
     }
 }