From ecb30c907266921818d5b6b03e341028fa2ea082 Mon Sep 17 00:00:00 2001
From: FernandoS27 <fsahmkow27@gmail.com>
Date: Thu, 1 Apr 2021 22:20:57 +0200
Subject: [PATCH] shader: Improve VOTE.VTG stub

---
 .../backend/spirv/emit_spirv.h                |  8 +++
 .../spirv/emit_spirv_context_get_set.cpp      | 32 ++++++++++++
 .../frontend/ir/ir_emitter.cpp                | 37 +++++++++++++-
 .../frontend/ir/ir_emitter.h                  | 10 ++++
 src/shader_recompiler/frontend/ir/opcodes.inc |  8 +++
 .../frontend/maxwell/translate/impl/vote.cpp  |  5 +-
 .../ir_opt/ssa_rewrite_pass.cpp               | 51 ++++++++++++++++++-
 7 files changed, 147 insertions(+), 4 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 9c9e0c5dd8..d2eda1f8ea 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -59,6 +59,14 @@ void EmitSetZFlag(EmitContext& ctx);
 void EmitSetSFlag(EmitContext& ctx);
 void EmitSetCFlag(EmitContext& ctx);
 void EmitSetOFlag(EmitContext& ctx);
+void EmitGetFCSMFlag(EmitContext& ctx);
+void EmitGetTAFlag(EmitContext& ctx);
+void EmitGetTRFlag(EmitContext& ctx);
+void EmitGetMXFlag(EmitContext& ctx);
+void EmitSetFCSMFlag(EmitContext& ctx);
+void EmitSetTAFlag(EmitContext& ctx);
+void EmitSetTRFlag(EmitContext& ctx);
+void EmitSetMXFlag(EmitContext& ctx);
 Id EmitWorkgroupId(EmitContext& ctx);
 Id EmitLocalInvocationId(EmitContext& ctx);
 Id EmitLoadLocal(EmitContext& ctx, Id word_offset);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index e42407f1fb..a96ee6f0de 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -263,6 +263,38 @@ void EmitSetOFlag(EmitContext&) {
     throw NotImplementedException("SPIR-V Instruction");
 }
 
+void EmitGetFCSMFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitGetTAFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitGetTRFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitGetMXFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitSetFCSMFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitSetTAFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitSetTRFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
+void EmitSetMXFlag(EmitContext&) {
+    throw NotImplementedException("SPIR-V Instruction");
+}
+
 Id EmitWorkgroupId(EmitContext& ctx) {
     return ctx.OpLoad(ctx.U32[3], ctx.workgroup_id);
 }
diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.cpp b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
index 5258ede094..ddaa873f26 100644
--- a/src/shader_recompiler/frontend/ir/ir_emitter.cpp
+++ b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
@@ -198,6 +198,38 @@ void IREmitter::SetOFlag(const U1& value) {
     Inst(Opcode::SetOFlag, value);
 }
 
+U1 IREmitter::GetFCSMFlag() {
+    return Inst<U1>(Opcode::GetFCSMFlag);
+}
+
+U1 IREmitter::GetTAFlag() {
+    return Inst<U1>(Opcode::GetTAFlag);
+}
+
+U1 IREmitter::GetTRFlag() {
+    return Inst<U1>(Opcode::GetTRFlag);
+}
+
+U1 IREmitter::GetMXFlag() {
+    return Inst<U1>(Opcode::GetMXFlag);
+}
+
+void IREmitter::SetFCSMFlag(const U1& value) {
+    Inst(Opcode::SetFCSMFlag, value);
+}
+
+void IREmitter::SetTAFlag(const U1& value) {
+    Inst(Opcode::SetTAFlag, value);
+}
+
+void IREmitter::SetTRFlag(const U1& value) {
+    Inst(Opcode::SetTRFlag, value);
+}
+
+void IREmitter::SetMXFlag(const U1& value) {
+    Inst(Opcode::SetMXFlag, value);
+}
+
 static U1 GetFlowTest(IREmitter& ir, FlowTest flow_test) {
     switch (flow_test) {
     case FlowTest::F:
@@ -256,13 +288,14 @@ static U1 GetFlowTest(IREmitter& ir, FlowTest flow_test) {
         return ir.LogicalOr(ir.GetSFlag(), ir.GetZFlag());
     case FlowTest::RGT:
         return ir.LogicalAnd(ir.LogicalNot(ir.GetSFlag()), ir.LogicalNot(ir.GetZFlag()));
+
+    case FlowTest::FCSM_TR:
+        return ir.LogicalAnd(ir.GetFCSMFlag(), ir.GetTRFlag());
     case FlowTest::CSM_TA:
     case FlowTest::CSM_TR:
     case FlowTest::CSM_MX:
     case FlowTest::FCSM_TA:
-    case FlowTest::FCSM_TR:
     case FlowTest::FCSM_MX:
-        return ir.Imm1(false);
     default:
         throw NotImplementedException("Flow test {}", flow_test);
     }
diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.h b/src/shader_recompiler/frontend/ir/ir_emitter.h
index a4616e2474..6e04eec7f3 100644
--- a/src/shader_recompiler/frontend/ir/ir_emitter.h
+++ b/src/shader_recompiler/frontend/ir/ir_emitter.h
@@ -70,6 +70,16 @@ public:
     void SetCFlag(const U1& value);
     void SetOFlag(const U1& value);
 
+    [[nodiscard]] U1 GetFCSMFlag();
+    [[nodiscard]] U1 GetTAFlag();
+    [[nodiscard]] U1 GetTRFlag();
+    [[nodiscard]] U1 GetMXFlag();
+
+    void SetFCSMFlag(const U1& value);
+    void SetTAFlag(const U1& value);
+    void SetTRFlag(const U1& value);
+    void SetMXFlag(const U1& value);
+
     [[nodiscard]] U1 Condition(IR::Condition cond);
     [[nodiscard]] U1 GetFlowTestResult(FlowTest test);
 
diff --git a/src/shader_recompiler/frontend/ir/opcodes.inc b/src/shader_recompiler/frontend/ir/opcodes.inc
index ffd0cc690b..7023727752 100644
--- a/src/shader_recompiler/frontend/ir/opcodes.inc
+++ b/src/shader_recompiler/frontend/ir/opcodes.inc
@@ -46,10 +46,18 @@ OPCODE(GetZFlag,                                            U1,             Void
 OPCODE(GetSFlag,                                            U1,             Void,                                                                           )
 OPCODE(GetCFlag,                                            U1,             Void,                                                                           )
 OPCODE(GetOFlag,                                            U1,             Void,                                                                           )
+OPCODE(GetFCSMFlag,                                         U1,             Void,                                                                           )
+OPCODE(GetTAFlag,                                           U1,             Void,                                                                           )
+OPCODE(GetTRFlag,                                           U1,             Void,                                                                           )
+OPCODE(GetMXFlag,                                           U1,             Void,                                                                           )
 OPCODE(SetZFlag,                                            Void,           U1,                                                                             )
 OPCODE(SetSFlag,                                            Void,           U1,                                                                             )
 OPCODE(SetCFlag,                                            Void,           U1,                                                                             )
 OPCODE(SetOFlag,                                            Void,           U1,                                                                             )
+OPCODE(SetFCSMFlag,                                         Void,           U1,                                                                             )
+OPCODE(SetTAFlag,                                           Void,           U1,                                                                             )
+OPCODE(SetTRFlag,                                           Void,           U1,                                                                             )
+OPCODE(SetMXFlag,                                           Void,           U1,                                                                             )
 OPCODE(WorkgroupId,                                         U32x3,                                                                                          )
 OPCODE(LocalInvocationId,                                   U32x3,                                                                                          )
 
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/vote.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/vote.cpp
index 391520a186..2acabb6629 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/vote.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/vote.cpp
@@ -50,7 +50,10 @@ void TranslatorVisitor::VOTE(u64 insn) {
 }
 
 void TranslatorVisitor::VOTE_vtg(u64) {
-    // Stub
+    // LOG_WARNING("VOTE.VTG: Stubbed!");
+    auto imm = ir.Imm1(false);
+    ir.SetFCSMFlag(imm);
+    ir.SetTRFlag(imm);
 }
 
 } // namespace Shader::Maxwell
diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index 2592337461..7dab330345 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -38,6 +38,10 @@ struct ZeroFlagTag : FlagTag {};
 struct SignFlagTag : FlagTag {};
 struct CarryFlagTag : FlagTag {};
 struct OverflowFlagTag : FlagTag {};
+struct FCSMFlagTag : FlagTag {};
+struct TAFlagTag : FlagTag {};
+struct TRFlagTag : FlagTag {};
+struct MXFlagTag : FlagTag {};
 
 struct GotoVariable : FlagTag {
     GotoVariable() = default;
@@ -53,7 +57,8 @@ struct IndirectBranchVariable {
 };
 
 using Variant = std::variant<IR::Reg, IR::Pred, ZeroFlagTag, SignFlagTag, CarryFlagTag,
-                             OverflowFlagTag, GotoVariable, IndirectBranchVariable>;
+                             OverflowFlagTag, FCSMFlagTag, TAFlagTag, TRFlagTag, MXFlagTag,
+                             GotoVariable, IndirectBranchVariable>;
 using ValueMap = boost::container::flat_map<IR::Block*, IR::Value, std::less<IR::Block*>>;
 
 struct DefTable {
@@ -89,6 +94,22 @@ struct DefTable {
         return overflow_flag;
     }
 
+    [[nodiscard]] ValueMap& operator[](FCSMFlagTag) noexcept {
+        return fcsm_flag;
+    }
+
+    [[nodiscard]] ValueMap& operator[](TAFlagTag) noexcept {
+        return ta_flag;
+    }
+
+    [[nodiscard]] ValueMap& operator[](TRFlagTag) noexcept {
+        return tr_flag;
+    }
+
+    [[nodiscard]] ValueMap& operator[](MXFlagTag) noexcept {
+        return mr_flag;
+    }
+
     std::array<ValueMap, IR::NUM_USER_REGS> regs;
     std::array<ValueMap, IR::NUM_USER_PREDS> preds;
     boost::container::flat_map<u32, ValueMap> goto_vars;
@@ -97,6 +118,10 @@ struct DefTable {
     ValueMap sign_flag;
     ValueMap carry_flag;
     ValueMap overflow_flag;
+    ValueMap fcsm_flag;
+    ValueMap ta_flag;
+    ValueMap tr_flag;
+    ValueMap mr_flag;
 };
 
 IR::Opcode UndefOpcode(IR::Reg) noexcept {
@@ -247,6 +272,18 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
     case IR::Opcode::SetOFlag:
         pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0));
         break;
+    case IR::Opcode::SetFCSMFlag:
+        pass.WriteVariable(FCSMFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetTAFlag:
+        pass.WriteVariable(TAFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetTRFlag:
+        pass.WriteVariable(TRFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetMXFlag:
+        pass.WriteVariable(MXFlagTag{}, block, inst.Arg(0));
+        break;
     case IR::Opcode::GetRegister:
         if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
             inst.ReplaceUsesWith(pass.ReadVariable(reg, block));
@@ -275,6 +312,18 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
     case IR::Opcode::GetOFlag:
         inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block));
         break;
+    case IR::Opcode::GetFCSMFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(FCSMFlagTag{}, block));
+        break;
+    case IR::Opcode::GetTAFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(TAFlagTag{}, block));
+        break;
+    case IR::Opcode::GetTRFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(TRFlagTag{}, block));
+        break;
+    case IR::Opcode::GetMXFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(MXFlagTag{}, block));
+        break;
     default:
         break;
     }