mirror of
				https://git.suyu.dev/suyu/suyu
				synced 2025-11-04 00:49:02 -06:00 
			
		
		
		
	shader: Constant propagation and global memory to storage buffer
This commit is contained in:
		@@ -59,7 +59,9 @@ add_executable(shader_recompiler
 | 
			
		||||
    frontend/maxwell/translate/impl/move_special_register.cpp
 | 
			
		||||
    frontend/maxwell/translate/translate.cpp
 | 
			
		||||
    frontend/maxwell/translate/translate.h
 | 
			
		||||
    ir_opt/constant_propagation_pass.cpp
 | 
			
		||||
    ir_opt/dead_code_elimination_pass.cpp
 | 
			
		||||
    ir_opt/global_memory_to_storage_buffer_pass.cpp
 | 
			
		||||
    ir_opt/identity_removal_pass.cpp
 | 
			
		||||
    ir_opt/passes.h
 | 
			
		||||
    ir_opt/ssa_rewrite_pass.cpp
 | 
			
		||||
 
 | 
			
		||||
@@ -504,6 +504,20 @@ U32U64 IREmitter::IAdd(const U32U64& a, const U32U64& b) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
U32U64 IREmitter::ISub(const U32U64& a, const U32U64& b) {
 | 
			
		||||
    if (a.Type() != b.Type()) {
 | 
			
		||||
        throw InvalidArgument("Mismatching types {} and {}", a.Type(), b.Type());
 | 
			
		||||
    }
 | 
			
		||||
    switch (a.Type()) {
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
        return Inst<U32>(Opcode::ISub32, a, b);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
        return Inst<U64>(Opcode::ISub64, a, b);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(a.Type());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
U32 IREmitter::IMul(const U32& a, const U32& b) {
 | 
			
		||||
    return Inst<U32>(Opcode::IMul32, a, b);
 | 
			
		||||
}
 | 
			
		||||
@@ -679,8 +693,8 @@ U32U64 IREmitter::ConvertFToI(size_t bitsize, bool is_signed, const U16U32U64& v
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
U32U64 IREmitter::ConvertU(size_t bitsize, const U32U64& value) {
 | 
			
		||||
    switch (bitsize) {
 | 
			
		||||
U32U64 IREmitter::ConvertU(size_t result_bitsize, const U32U64& value) {
 | 
			
		||||
    switch (result_bitsize) {
 | 
			
		||||
    case 32:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
@@ -703,7 +717,7 @@ U32U64 IREmitter::ConvertU(size_t bitsize, const U32U64& value) {
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    throw NotImplementedException("Conversion from {} to {} bits", value.Type(), bitsize);
 | 
			
		||||
    throw NotImplementedException("Conversion from {} to {} bits", value.Type(), result_bitsize);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::IR
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,8 @@ namespace Shader::IR {
 | 
			
		||||
class IREmitter {
 | 
			
		||||
public:
 | 
			
		||||
    explicit IREmitter(Block& block_) : block{block_}, insertion_point{block.end()} {}
 | 
			
		||||
    explicit IREmitter(Block& block_, Block::iterator insertion_point_)
 | 
			
		||||
        : block{block_}, insertion_point{insertion_point_} {}
 | 
			
		||||
 | 
			
		||||
    Block& block;
 | 
			
		||||
 | 
			
		||||
@@ -125,6 +127,7 @@ public:
 | 
			
		||||
    [[nodiscard]] U16U32U64 FPTrunc(const U16U32U64& value);
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] U32U64 IAdd(const U32U64& a, const U32U64& b);
 | 
			
		||||
    [[nodiscard]] U32U64 ISub(const U32U64& a, const U32U64& b);
 | 
			
		||||
    [[nodiscard]] U32 IMul(const U32& a, const U32& b);
 | 
			
		||||
    [[nodiscard]] U32 INeg(const U32& value);
 | 
			
		||||
    [[nodiscard]] U32 IAbs(const U32& value);
 | 
			
		||||
@@ -155,7 +158,7 @@ public:
 | 
			
		||||
    [[nodiscard]] U32U64 ConvertFToU(size_t bitsize, const U16U32U64& value);
 | 
			
		||||
    [[nodiscard]] U32U64 ConvertFToI(size_t bitsize, bool is_signed, const U16U32U64& value);
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] U32U64 ConvertU(size_t bitsize, const U32U64& value);
 | 
			
		||||
    [[nodiscard]] U32U64 ConvertU(size_t result_bitsize, const U32U64& value);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    IR::Block::iterator insertion_point;
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,8 @@
 | 
			
		||||
// Licensed under GPLv2 or any later version
 | 
			
		||||
// Refer to the license.txt file included.
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
 | 
			
		||||
#include "shader_recompiler/exception.h"
 | 
			
		||||
#include "shader_recompiler/frontend/ir/microinstruction.h"
 | 
			
		||||
#include "shader_recompiler/frontend/ir/type.h"
 | 
			
		||||
@@ -44,6 +46,13 @@ bool Inst::MayHaveSideEffects() const noexcept {
 | 
			
		||||
    case Opcode::WriteGlobal32:
 | 
			
		||||
    case Opcode::WriteGlobal64:
 | 
			
		||||
    case Opcode::WriteGlobal128:
 | 
			
		||||
    case Opcode::WriteStorageU8:
 | 
			
		||||
    case Opcode::WriteStorageS8:
 | 
			
		||||
    case Opcode::WriteStorageU16:
 | 
			
		||||
    case Opcode::WriteStorageS16:
 | 
			
		||||
    case Opcode::WriteStorage32:
 | 
			
		||||
    case Opcode::WriteStorage64:
 | 
			
		||||
    case Opcode::WriteStorage128:
 | 
			
		||||
        return true;
 | 
			
		||||
    default:
 | 
			
		||||
        return false;
 | 
			
		||||
@@ -56,15 +65,19 @@ bool Inst::IsPseudoInstruction() const noexcept {
 | 
			
		||||
    case Opcode::GetSignFromOp:
 | 
			
		||||
    case Opcode::GetCarryFromOp:
 | 
			
		||||
    case Opcode::GetOverflowFromOp:
 | 
			
		||||
    case Opcode::GetZSCOFromOp:
 | 
			
		||||
        return true;
 | 
			
		||||
    default:
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Inst::AreAllArgsImmediates() const noexcept {
 | 
			
		||||
    return std::all_of(args.begin(), args.begin() + NumArgs(),
 | 
			
		||||
                       [](const IR::Value& value) { return value.IsImmediate(); });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Inst::HasAssociatedPseudoOperation() const noexcept {
 | 
			
		||||
    return zero_inst || sign_inst || carry_inst || overflow_inst || zsco_inst;
 | 
			
		||||
    return zero_inst || sign_inst || carry_inst || overflow_inst;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Inst* Inst::GetAssociatedPseudoOperation(IR::Opcode opcode) {
 | 
			
		||||
@@ -82,9 +95,6 @@ Inst* Inst::GetAssociatedPseudoOperation(IR::Opcode opcode) {
 | 
			
		||||
    case Opcode::GetOverflowFromOp:
 | 
			
		||||
        CheckPseudoInstruction(overflow_inst, Opcode::GetOverflowFromOp);
 | 
			
		||||
        return overflow_inst;
 | 
			
		||||
    case Opcode::GetZSCOFromOp:
 | 
			
		||||
        CheckPseudoInstruction(zsco_inst, Opcode::GetZSCOFromOp);
 | 
			
		||||
        return zsco_inst;
 | 
			
		||||
    default:
 | 
			
		||||
        throw InvalidArgument("{} is not a pseudo-instruction", opcode);
 | 
			
		||||
    }
 | 
			
		||||
@@ -176,9 +186,6 @@ void Inst::Use(const Value& value) {
 | 
			
		||||
    case Opcode::GetOverflowFromOp:
 | 
			
		||||
        SetPseudoInstruction(value.Inst()->overflow_inst, this);
 | 
			
		||||
        break;
 | 
			
		||||
    case Opcode::GetZSCOFromOp:
 | 
			
		||||
        SetPseudoInstruction(value.Inst()->zsco_inst, this);
 | 
			
		||||
        break;
 | 
			
		||||
    default:
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
@@ -200,9 +207,6 @@ void Inst::UndoUse(const Value& value) {
 | 
			
		||||
    case Opcode::GetOverflowFromOp:
 | 
			
		||||
        RemovePseudoInstruction(value.Inst()->overflow_inst, Opcode::GetOverflowFromOp);
 | 
			
		||||
        break;
 | 
			
		||||
    case Opcode::GetZSCOFromOp:
 | 
			
		||||
        RemovePseudoInstruction(value.Inst()->zsco_inst, Opcode::GetZSCOFromOp);
 | 
			
		||||
        break;
 | 
			
		||||
    default:
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
@@ -49,6 +49,9 @@ public:
 | 
			
		||||
    /// Pseudo-instructions depend on their parent instructions for their semantics.
 | 
			
		||||
    [[nodiscard]] bool IsPseudoInstruction() const noexcept;
 | 
			
		||||
 | 
			
		||||
    /// Determines if all arguments of this instruction are immediates.
 | 
			
		||||
    [[nodiscard]] bool AreAllArgsImmediates() const noexcept;
 | 
			
		||||
 | 
			
		||||
    /// Determines if there is a pseudo-operation associated with this instruction.
 | 
			
		||||
    [[nodiscard]] bool HasAssociatedPseudoOperation() const noexcept;
 | 
			
		||||
    /// Gets a pseudo-operation associated with this instruction
 | 
			
		||||
@@ -94,7 +97,6 @@ private:
 | 
			
		||||
    Inst* sign_inst{};
 | 
			
		||||
    Inst* carry_inst{};
 | 
			
		||||
    Inst* overflow_inst{};
 | 
			
		||||
    Inst* zsco_inst{};
 | 
			
		||||
    std::vector<std::pair<Block*, Value>> phi_operands;
 | 
			
		||||
    u64 flags{};
 | 
			
		||||
};
 | 
			
		||||
 
 | 
			
		||||
@@ -24,9 +24,6 @@ OPCODE(GetAttribute,                                        U32,            Attr
 | 
			
		||||
OPCODE(SetAttribute,                                        U32,            Attribute,                                                      )
 | 
			
		||||
OPCODE(GetAttributeIndexed,                                 U32,            U32,                                                            )
 | 
			
		||||
OPCODE(SetAttributeIndexed,                                 U32,            U32,                                                            )
 | 
			
		||||
OPCODE(GetZSCORaw,                                          U32,                                                                            )
 | 
			
		||||
OPCODE(SetZSCORaw,                                          Void,           U32,                                                            )
 | 
			
		||||
OPCODE(SetZSCO,                                             Void,           ZSCO,                                                           )
 | 
			
		||||
OPCODE(GetZFlag,                                            U1,             Void,                                                           )
 | 
			
		||||
OPCODE(GetSFlag,                                            U1,             Void,                                                           )
 | 
			
		||||
OPCODE(GetCFlag,                                            U1,             Void,                                                           )
 | 
			
		||||
@@ -65,6 +62,22 @@ OPCODE(WriteGlobal32,                                       Void,           U64,
 | 
			
		||||
OPCODE(WriteGlobal64,                                       Void,           U64,            Opaque,                                         )
 | 
			
		||||
OPCODE(WriteGlobal128,                                      Void,           U64,            Opaque,                                         )
 | 
			
		||||
 | 
			
		||||
// Storage buffer operations
 | 
			
		||||
OPCODE(LoadStorageU8,                                       U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorageS8,                                       U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorageU16,                                      U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorageS16,                                      U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorage32,                                       U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorage64,                                       Opaque,         U32,            U32,                                            )
 | 
			
		||||
OPCODE(LoadStorage128,                                      Opaque,         U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorageU8,                                      Void,           U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorageS8,                                      Void,           U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorageU16,                                     Void,           U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorageS16,                                     Void,           U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorage32,                                      Void,           U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(WriteStorage64,                                      Void,           U32,            U32,            Opaque,                                         )
 | 
			
		||||
OPCODE(WriteStorage128,                                     Void,           U32,            U32,            Opaque,                                         )
 | 
			
		||||
 | 
			
		||||
// Vector utility
 | 
			
		||||
OPCODE(CompositeConstruct2,                                 Opaque,         Opaque,         Opaque,                                         )
 | 
			
		||||
OPCODE(CompositeConstruct3,                                 Opaque,         Opaque,         Opaque,         Opaque,                         )
 | 
			
		||||
@@ -90,7 +103,6 @@ OPCODE(GetZeroFromOp,                                       U1,             Opaq
 | 
			
		||||
OPCODE(GetSignFromOp,                                       U1,             Opaque,                                                         )
 | 
			
		||||
OPCODE(GetCarryFromOp,                                      U1,             Opaque,                                                         )
 | 
			
		||||
OPCODE(GetOverflowFromOp,                                   U1,             Opaque,                                                         )
 | 
			
		||||
OPCODE(GetZSCOFromOp,                                       ZSCO,           Opaque,                                                         )
 | 
			
		||||
 | 
			
		||||
// Floating-point operations
 | 
			
		||||
OPCODE(FPAbs16,                                             U16,            U16,                                                            )
 | 
			
		||||
@@ -143,6 +155,8 @@ OPCODE(FPTrunc64,                                           U64,            U64,
 | 
			
		||||
// Integer operations
 | 
			
		||||
OPCODE(IAdd32,                                              U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(IAdd64,                                              U64,            U64,            U64,                                            )
 | 
			
		||||
OPCODE(ISub32,                                              U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(ISub64,                                              U64,            U64,            U64,                                            )
 | 
			
		||||
OPCODE(IMul32,                                              U32,            U32,            U32,                                            )
 | 
			
		||||
OPCODE(INeg32,                                              U32,            U32,                                                            )
 | 
			
		||||
OPCODE(IAbs32,                                              U32,            U32,                                                            )
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ namespace Shader::IR {
 | 
			
		||||
 | 
			
		||||
std::string NameOf(Type type) {
 | 
			
		||||
    static constexpr std::array names{
 | 
			
		||||
        "Opaque", "Label", "Reg", "Pred", "Attribute", "U1", "U8", "U16", "U32", "U64", "ZSCO",
 | 
			
		||||
        "Opaque", "Label", "Reg", "Pred", "Attribute", "U1", "U8", "U16", "U32", "U64",
 | 
			
		||||
    };
 | 
			
		||||
    const size_t bits{static_cast<size_t>(type)};
 | 
			
		||||
    if (bits == 0) {
 | 
			
		||||
 
 | 
			
		||||
@@ -25,7 +25,6 @@ enum class Type {
 | 
			
		||||
    U16 = 1 << 7,
 | 
			
		||||
    U32 = 1 << 8,
 | 
			
		||||
    U64 = 1 << 9,
 | 
			
		||||
    ZSCO = 1 << 10,
 | 
			
		||||
};
 | 
			
		||||
DECLARE_ENUM_FLAG_OPERATORS(Type)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -91,26 +91,41 @@ IR::Attribute Value::Attribute() const {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Value::U1() const {
 | 
			
		||||
    if (IsIdentity()) {
 | 
			
		||||
        return inst->Arg(0).U1();
 | 
			
		||||
    }
 | 
			
		||||
    ValidateAccess(Type::U1);
 | 
			
		||||
    return imm_u1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
u8 Value::U8() const {
 | 
			
		||||
    if (IsIdentity()) {
 | 
			
		||||
        return inst->Arg(0).U8();
 | 
			
		||||
    }
 | 
			
		||||
    ValidateAccess(Type::U8);
 | 
			
		||||
    return imm_u8;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
u16 Value::U16() const {
 | 
			
		||||
    if (IsIdentity()) {
 | 
			
		||||
        return inst->Arg(0).U16();
 | 
			
		||||
    }
 | 
			
		||||
    ValidateAccess(Type::U16);
 | 
			
		||||
    return imm_u16;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
u32 Value::U32() const {
 | 
			
		||||
    if (IsIdentity()) {
 | 
			
		||||
        return inst->Arg(0).U32();
 | 
			
		||||
    }
 | 
			
		||||
    ValidateAccess(Type::U32);
 | 
			
		||||
    return imm_u32;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
u64 Value::U64() const {
 | 
			
		||||
    if (IsIdentity()) {
 | 
			
		||||
        return inst->Arg(0).U64();
 | 
			
		||||
    }
 | 
			
		||||
    ValidateAccess(Type::U64);
 | 
			
		||||
    return imm_u64;
 | 
			
		||||
}
 | 
			
		||||
@@ -142,8 +157,6 @@ bool Value::operator==(const Value& other) const {
 | 
			
		||||
        return imm_u32 == other.imm_u32;
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
        return imm_u64 == other.imm_u64;
 | 
			
		||||
    case Type::ZSCO:
 | 
			
		||||
        throw NotImplementedException("ZSCO comparison");
 | 
			
		||||
    }
 | 
			
		||||
    throw LogicError("Invalid type {}", type);
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -96,6 +96,5 @@ using U64 = TypedValue<Type::U64>;
 | 
			
		||||
using U32U64 = TypedValue<Type::U32 | Type::U64>;
 | 
			
		||||
using U16U32U64 = TypedValue<Type::U16 | Type::U32 | Type::U64>;
 | 
			
		||||
using UAny = TypedValue<Type::U8 | Type::U16 | Type::U32 | Type::U64>;
 | 
			
		||||
using ZSCO = TypedValue<Type::ZSCO>;
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::IR
 | 
			
		||||
 
 | 
			
		||||
@@ -52,9 +52,11 @@ Program::Program(Environment& env, const Flow::CFG& cfg) {
 | 
			
		||||
    }
 | 
			
		||||
    std::ranges::for_each(functions, Optimization::SsaRewritePass);
 | 
			
		||||
    for (IR::Function& function : functions) {
 | 
			
		||||
        Optimization::Invoke(Optimization::GlobalMemoryToStorageBufferPass, function);
 | 
			
		||||
        Optimization::Invoke(Optimization::ConstantPropagationPass, function);
 | 
			
		||||
        Optimization::Invoke(Optimization::DeadCodeEliminationPass, function);
 | 
			
		||||
        Optimization::Invoke(Optimization::IdentityRemovalPass, function);
 | 
			
		||||
        // Optimization::Invoke(Optimization::VerificationPass, function);
 | 
			
		||||
        Optimization::IdentityRemovalPass(function);
 | 
			
		||||
        Optimization::VerificationPass(function);
 | 
			
		||||
    }
 | 
			
		||||
    //*/
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										146
									
								
								src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,146 @@
 | 
			
		||||
// Copyright 2021 yuzu Emulator Project
 | 
			
		||||
// Licensed under GPLv2 or any later version
 | 
			
		||||
// Refer to the license.txt file included.
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
 | 
			
		||||
#include "common/bit_util.h"
 | 
			
		||||
#include "shader_recompiler/exception.h"
 | 
			
		||||
#include "shader_recompiler/frontend/ir/microinstruction.h"
 | 
			
		||||
#include "shader_recompiler/ir_opt/passes.h"
 | 
			
		||||
 | 
			
		||||
namespace Shader::Optimization {
 | 
			
		||||
namespace {
 | 
			
		||||
[[nodiscard]] u32 BitFieldUExtract(u32 base, u32 shift, u32 count) {
 | 
			
		||||
    if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
 | 
			
		||||
        throw LogicError("Undefined result in BitFieldUExtract({}, {}, {})", base, shift, count);
 | 
			
		||||
    }
 | 
			
		||||
    return (base >> shift) & ((1U << count) - 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
[[nodiscard]] T Arg(const IR::Value& value) {
 | 
			
		||||
    if constexpr (std::is_same_v<T, bool>) {
 | 
			
		||||
        return value.U1();
 | 
			
		||||
    } else if constexpr (std::is_same_v<T, u32>) {
 | 
			
		||||
        return value.U32();
 | 
			
		||||
    } else if constexpr (std::is_same_v<T, u64>) {
 | 
			
		||||
        return value.U64();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename ImmFn>
 | 
			
		||||
bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
 | 
			
		||||
    const auto arg = [](const IR::Value& value) {
 | 
			
		||||
        if constexpr (std::is_invocable_r_v<bool, ImmFn, bool, bool>) {
 | 
			
		||||
            return value.U1();
 | 
			
		||||
        } else if constexpr (std::is_invocable_r_v<u32, ImmFn, u32, u32>) {
 | 
			
		||||
            return value.U32();
 | 
			
		||||
        } else if constexpr (std::is_invocable_r_v<u64, ImmFn, u64, u64>) {
 | 
			
		||||
            return value.U64();
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    const IR::Value lhs{inst.Arg(0)};
 | 
			
		||||
    const IR::Value rhs{inst.Arg(1)};
 | 
			
		||||
 | 
			
		||||
    const bool is_lhs_immediate{lhs.IsImmediate()};
 | 
			
		||||
    const bool is_rhs_immediate{rhs.IsImmediate()};
 | 
			
		||||
 | 
			
		||||
    if (is_lhs_immediate && is_rhs_immediate) {
 | 
			
		||||
        const auto result{imm_fn(arg(lhs), arg(rhs))};
 | 
			
		||||
        inst.ReplaceUsesWith(IR::Value{result});
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    if (is_lhs_immediate && !is_rhs_immediate) {
 | 
			
		||||
        IR::Inst* const rhs_inst{rhs.InstRecursive()};
 | 
			
		||||
        if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) {
 | 
			
		||||
            const auto combined{imm_fn(arg(lhs), arg(rhs_inst->Arg(1)))};
 | 
			
		||||
            inst.SetArg(0, rhs_inst->Arg(0));
 | 
			
		||||
            inst.SetArg(1, IR::Value{combined});
 | 
			
		||||
        } else {
 | 
			
		||||
            // Normalize
 | 
			
		||||
            inst.SetArg(0, rhs);
 | 
			
		||||
            inst.SetArg(1, lhs);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    if (!is_lhs_immediate && is_rhs_immediate) {
 | 
			
		||||
        const IR::Inst* const lhs_inst{lhs.InstRecursive()};
 | 
			
		||||
        if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) {
 | 
			
		||||
            const auto combined{imm_fn(arg(rhs), arg(lhs_inst->Arg(1)))};
 | 
			
		||||
            inst.SetArg(0, lhs_inst->Arg(0));
 | 
			
		||||
            inst.SetArg(1, IR::Value{combined});
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void FoldGetRegister(IR::Inst& inst) {
 | 
			
		||||
    if (inst.Arg(0).Reg() == IR::Reg::RZ) {
 | 
			
		||||
        inst.ReplaceUsesWith(IR::Value{u32{0}});
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void FoldGetPred(IR::Inst& inst) {
 | 
			
		||||
    if (inst.Arg(0).Pred() == IR::Pred::PT) {
 | 
			
		||||
        inst.ReplaceUsesWith(IR::Value{true});
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void FoldAdd(IR::Inst& inst) {
 | 
			
		||||
    if (inst.HasAssociatedPseudoOperation()) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    if (!FoldCommutative(inst, [](T a, T b) { return a + b; })) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    const IR::Value rhs{inst.Arg(1)};
 | 
			
		||||
    if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
 | 
			
		||||
        inst.ReplaceUsesWith(inst.Arg(0));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void FoldLogicalAnd(IR::Inst& inst) {
 | 
			
		||||
    if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    const IR::Value rhs{inst.Arg(1)};
 | 
			
		||||
    if (rhs.IsImmediate()) {
 | 
			
		||||
        if (rhs.U1()) {
 | 
			
		||||
            inst.ReplaceUsesWith(inst.Arg(0));
 | 
			
		||||
        } else {
 | 
			
		||||
            inst.ReplaceUsesWith(IR::Value{false});
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ConstantPropagation(IR::Inst& inst) {
 | 
			
		||||
    switch (inst.Opcode()) {
 | 
			
		||||
    case IR::Opcode::GetRegister:
 | 
			
		||||
        return FoldGetRegister(inst);
 | 
			
		||||
    case IR::Opcode::GetPred:
 | 
			
		||||
        return FoldGetPred(inst);
 | 
			
		||||
    case IR::Opcode::IAdd32:
 | 
			
		||||
        return FoldAdd<u32>(inst);
 | 
			
		||||
    case IR::Opcode::IAdd64:
 | 
			
		||||
        return FoldAdd<u64>(inst);
 | 
			
		||||
    case IR::Opcode::BitFieldUExtract:
 | 
			
		||||
        if (inst.AreAllArgsImmediates() && !inst.HasAssociatedPseudoOperation()) {
 | 
			
		||||
            inst.ReplaceUsesWith(IR::Value{
 | 
			
		||||
                BitFieldUExtract(inst.Arg(0).U32(), inst.Arg(1).U32(), inst.Arg(2).U32())});
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
    case IR::Opcode::LogicalAnd:
 | 
			
		||||
        return FoldLogicalAnd(inst);
 | 
			
		||||
    default:
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
} // Anonymous namespace
 | 
			
		||||
 | 
			
		||||
void ConstantPropagationPass(IR::Block& block) {
 | 
			
		||||
    std::ranges::for_each(block, ConstantPropagation);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::Optimization
 | 
			
		||||
@@ -0,0 +1,331 @@
 | 
			
		||||
// Copyright 2021 yuzu Emulator Project
 | 
			
		||||
// Licensed under GPLv2 or any later version
 | 
			
		||||
// Refer to the license.txt file included.
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <compare>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <ranges>
 | 
			
		||||
 | 
			
		||||
#include <boost/container/flat_set.hpp>
 | 
			
		||||
#include <boost/container/small_vector.hpp>
 | 
			
		||||
 | 
			
		||||
#include "shader_recompiler/frontend/ir/basic_block.h"
 | 
			
		||||
#include "shader_recompiler/frontend/ir/ir_emitter.h"
 | 
			
		||||
#include "shader_recompiler/frontend/ir/microinstruction.h"
 | 
			
		||||
#include "shader_recompiler/ir_opt/passes.h"
 | 
			
		||||
 | 
			
		||||
namespace Shader::Optimization {
 | 
			
		||||
namespace {
 | 
			
		||||
/// Address in constant buffers to the storage buffer descriptor
 | 
			
		||||
struct StorageBufferAddr {
 | 
			
		||||
    auto operator<=>(const StorageBufferAddr&) const noexcept = default;
 | 
			
		||||
 | 
			
		||||
    u32 index;
 | 
			
		||||
    u32 offset;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// Block iterator to a global memory instruction and the storage buffer it uses
 | 
			
		||||
struct StorageInst {
 | 
			
		||||
    StorageBufferAddr storage_buffer;
 | 
			
		||||
    IR::Block::iterator inst;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// Bias towards a certain range of constant buffers when looking for storage buffers
 | 
			
		||||
struct Bias {
 | 
			
		||||
    u32 index;
 | 
			
		||||
    u32 offset_begin;
 | 
			
		||||
    u32 offset_end;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
using StorageBufferSet =
 | 
			
		||||
    boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>,
 | 
			
		||||
                               boost::container::small_vector<StorageBufferAddr, 16>>;
 | 
			
		||||
using StorageInstVector = boost::container::small_vector<StorageInst, 32>;
 | 
			
		||||
 | 
			
		||||
/// Returns true when the instruction is a global memory instruction
 | 
			
		||||
bool IsGlobalMemory(const IR::Inst& inst) {
 | 
			
		||||
    switch (inst.Opcode()) {
 | 
			
		||||
    case IR::Opcode::LoadGlobalS8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalS16:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU16:
 | 
			
		||||
    case IR::Opcode::LoadGlobal32:
 | 
			
		||||
    case IR::Opcode::LoadGlobal64:
 | 
			
		||||
    case IR::Opcode::LoadGlobal128:
 | 
			
		||||
    case IR::Opcode::WriteGlobalS8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalS16:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU16:
 | 
			
		||||
    case IR::Opcode::WriteGlobal32:
 | 
			
		||||
    case IR::Opcode::WriteGlobal64:
 | 
			
		||||
    case IR::Opcode::WriteGlobal128:
 | 
			
		||||
        return true;
 | 
			
		||||
    default:
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Converts a global memory opcode to its storage buffer equivalent
 | 
			
		||||
IR::Opcode GlobalToStorage(IR::Opcode opcode) {
 | 
			
		||||
    switch (opcode) {
 | 
			
		||||
    case IR::Opcode::LoadGlobalS8:
 | 
			
		||||
        return IR::Opcode::LoadStorageS8;
 | 
			
		||||
    case IR::Opcode::LoadGlobalU8:
 | 
			
		||||
        return IR::Opcode::LoadStorageU8;
 | 
			
		||||
    case IR::Opcode::LoadGlobalS16:
 | 
			
		||||
        return IR::Opcode::LoadStorageS16;
 | 
			
		||||
    case IR::Opcode::LoadGlobalU16:
 | 
			
		||||
        return IR::Opcode::LoadStorageU16;
 | 
			
		||||
    case IR::Opcode::LoadGlobal32:
 | 
			
		||||
        return IR::Opcode::LoadStorage32;
 | 
			
		||||
    case IR::Opcode::LoadGlobal64:
 | 
			
		||||
        return IR::Opcode::LoadStorage64;
 | 
			
		||||
    case IR::Opcode::LoadGlobal128:
 | 
			
		||||
        return IR::Opcode::LoadStorage128;
 | 
			
		||||
    case IR::Opcode::WriteGlobalS8:
 | 
			
		||||
        return IR::Opcode::WriteStorageS8;
 | 
			
		||||
    case IR::Opcode::WriteGlobalU8:
 | 
			
		||||
        return IR::Opcode::WriteStorageU8;
 | 
			
		||||
    case IR::Opcode::WriteGlobalS16:
 | 
			
		||||
        return IR::Opcode::WriteStorageS16;
 | 
			
		||||
    case IR::Opcode::WriteGlobalU16:
 | 
			
		||||
        return IR::Opcode::WriteStorageU16;
 | 
			
		||||
    case IR::Opcode::WriteGlobal32:
 | 
			
		||||
        return IR::Opcode::WriteStorage32;
 | 
			
		||||
    case IR::Opcode::WriteGlobal64:
 | 
			
		||||
        return IR::Opcode::WriteStorage64;
 | 
			
		||||
    case IR::Opcode::WriteGlobal128:
 | 
			
		||||
        return IR::Opcode::WriteStorage128;
 | 
			
		||||
    default:
 | 
			
		||||
        throw InvalidArgument("Invalid global memory opcode {}", opcode);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Returns true when a storage buffer address satisfies a bias
 | 
			
		||||
bool MeetsBias(const StorageBufferAddr& storage_buffer, const Bias& bias) noexcept {
 | 
			
		||||
    return storage_buffer.index == bias.index && storage_buffer.offset >= bias.offset_begin &&
 | 
			
		||||
           storage_buffer.offset < bias.offset_end;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Ignores a global memory operation, reads return zero and writes are ignored
 | 
			
		||||
void IgnoreGlobalMemory(IR::Block& block, IR::Block::iterator inst) {
 | 
			
		||||
    const IR::Value zero{u32{0}};
 | 
			
		||||
    switch (inst->Opcode()) {
 | 
			
		||||
    case IR::Opcode::LoadGlobalS8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalS16:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU16:
 | 
			
		||||
    case IR::Opcode::LoadGlobal32:
 | 
			
		||||
        inst->ReplaceUsesWith(zero);
 | 
			
		||||
        break;
 | 
			
		||||
    case IR::Opcode::LoadGlobal64:
 | 
			
		||||
        inst->ReplaceUsesWith(
 | 
			
		||||
            IR::Value{&*block.PrependNewInst(inst, IR::Opcode::CompositeConstruct2, {zero, zero})});
 | 
			
		||||
        break;
 | 
			
		||||
    case IR::Opcode::LoadGlobal128:
 | 
			
		||||
        inst->ReplaceUsesWith(IR::Value{&*block.PrependNewInst(
 | 
			
		||||
            inst, IR::Opcode::CompositeConstruct4, {zero, zero, zero, zero})});
 | 
			
		||||
        break;
 | 
			
		||||
    case IR::Opcode::WriteGlobalS8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalS16:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU16:
 | 
			
		||||
    case IR::Opcode::WriteGlobal32:
 | 
			
		||||
    case IR::Opcode::WriteGlobal64:
 | 
			
		||||
    case IR::Opcode::WriteGlobal128:
 | 
			
		||||
        inst->Invalidate();
 | 
			
		||||
        break;
 | 
			
		||||
    default:
 | 
			
		||||
        throw LogicError("Invalid opcode to ignore its global memory operation {}", inst->Opcode());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Recursively tries to track the storage buffer address used by a global memory instruction
 | 
			
		||||
std::optional<StorageBufferAddr> Track(const IR::Value& value, const Bias* bias) {
 | 
			
		||||
    if (value.IsImmediate()) {
 | 
			
		||||
        // Immediates can't be a storage buffer
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    const IR::Inst* const inst{value.InstRecursive()};
 | 
			
		||||
    if (inst->Opcode() == IR::Opcode::GetCbuf) {
 | 
			
		||||
        const IR::Value index{inst->Arg(0)};
 | 
			
		||||
        const IR::Value offset{inst->Arg(1)};
 | 
			
		||||
        if (!index.IsImmediate()) {
 | 
			
		||||
            // Definitely not a storage buffer if it's read from a non-immediate index
 | 
			
		||||
            return std::nullopt;
 | 
			
		||||
        }
 | 
			
		||||
        if (!offset.IsImmediate()) {
 | 
			
		||||
            // TODO: Support SSBO arrays
 | 
			
		||||
            return std::nullopt;
 | 
			
		||||
        }
 | 
			
		||||
        const StorageBufferAddr storage_buffer{
 | 
			
		||||
            .index = index.U32(),
 | 
			
		||||
            .offset = offset.U32(),
 | 
			
		||||
        };
 | 
			
		||||
        if (bias && !MeetsBias(storage_buffer, *bias)) {
 | 
			
		||||
            // We have to blacklist some addresses in case we wrongly point to them
 | 
			
		||||
            return std::nullopt;
 | 
			
		||||
        }
 | 
			
		||||
        return storage_buffer;
 | 
			
		||||
    }
 | 
			
		||||
    // Reversed loops are more likely to find the right result
 | 
			
		||||
    for (size_t arg = inst->NumArgs(); arg--;) {
 | 
			
		||||
        if (const std::optional storage_buffer{Track(inst->Arg(arg), bias)}) {
 | 
			
		||||
            return *storage_buffer;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    return std::nullopt;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Collects the storage buffer used by a global memory instruction and the instruction itself
 | 
			
		||||
void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst,
 | 
			
		||||
                           StorageBufferSet& storage_buffer_set, StorageInstVector& to_replace) {
 | 
			
		||||
    // NVN puts storage buffers in a specific range, we have to bias towards these addresses to
 | 
			
		||||
    // avoid getting false positives
 | 
			
		||||
    static constexpr Bias nvn_bias{
 | 
			
		||||
        .index{0},
 | 
			
		||||
        .offset_begin{0x110},
 | 
			
		||||
        .offset_end{0x610},
 | 
			
		||||
    };
 | 
			
		||||
    // First try to find storage buffers in the NVN address
 | 
			
		||||
    const IR::U64 addr{inst->Arg(0)};
 | 
			
		||||
    std::optional<StorageBufferAddr> storage_buffer{Track(addr, &nvn_bias)};
 | 
			
		||||
    if (!storage_buffer) {
 | 
			
		||||
        // If it fails, track without a bias
 | 
			
		||||
        storage_buffer = Track(addr, nullptr);
 | 
			
		||||
        if (!storage_buffer) {
 | 
			
		||||
            // If that also failed, drop the global memory usage
 | 
			
		||||
            IgnoreGlobalMemory(block, inst);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    // Collect storage buffer and the instruction
 | 
			
		||||
    storage_buffer_set.insert(*storage_buffer);
 | 
			
		||||
    to_replace.push_back(StorageInst{
 | 
			
		||||
        .storage_buffer{*storage_buffer},
 | 
			
		||||
        .inst{inst},
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Tries to track the first 32-bits of a global memory instruction
 | 
			
		||||
std::optional<IR::U32> TrackLowAddress(IR::IREmitter& ir, IR::Inst* inst) {
 | 
			
		||||
    // The first argument is the low level GPU pointer to the global memory instruction
 | 
			
		||||
    const IR::U64 addr{inst->Arg(0)};
 | 
			
		||||
    if (addr.IsImmediate()) {
 | 
			
		||||
        // Not much we can do if it's an immediate
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    // This address is expected to either be a PackUint2x32 or a IAdd64
 | 
			
		||||
    IR::Inst* addr_inst{addr.InstRecursive()};
 | 
			
		||||
    s32 imm_offset{0};
 | 
			
		||||
    if (addr_inst->Opcode() == IR::Opcode::IAdd64) {
 | 
			
		||||
        // If it's an IAdd64, get the immediate offset it is applying and grab the address
 | 
			
		||||
        // instruction. This expects for the instruction to be canonicalized having the address on
 | 
			
		||||
        // the first argument and the immediate offset on the second one.
 | 
			
		||||
        const IR::U64 imm_offset_value{addr_inst->Arg(1)};
 | 
			
		||||
        if (!imm_offset_value.IsImmediate()) {
 | 
			
		||||
            return std::nullopt;
 | 
			
		||||
        }
 | 
			
		||||
        imm_offset = static_cast<s32>(static_cast<s64>(imm_offset_value.U64()));
 | 
			
		||||
        const IR::U64 iadd_addr{addr_inst->Arg(0)};
 | 
			
		||||
        if (iadd_addr.IsImmediate()) {
 | 
			
		||||
            return std::nullopt;
 | 
			
		||||
        }
 | 
			
		||||
        addr_inst = iadd_addr.Inst();
 | 
			
		||||
    }
 | 
			
		||||
    // With IAdd64 handled, now PackUint2x32 is expected without exceptions
 | 
			
		||||
    if (addr_inst->Opcode() != IR::Opcode::PackUint2x32) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    // PackUint2x32 is expected to be generated from a vector
 | 
			
		||||
    const IR::Value vector{addr_inst->Arg(0)};
 | 
			
		||||
    if (vector.IsImmediate()) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    // This vector is expected to be a CompositeConstruct2
 | 
			
		||||
    IR::Inst* const vector_inst{vector.InstRecursive()};
 | 
			
		||||
    if (vector_inst->Opcode() != IR::Opcode::CompositeConstruct2) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    // Grab the first argument from the CompositeConstruct2, this is the low address.
 | 
			
		||||
    // Re-apply the offset in case we found one.
 | 
			
		||||
    const IR::U32 low_addr{vector_inst->Arg(0)};
 | 
			
		||||
    return imm_offset != 0 ? IR::U32{ir.IAdd(low_addr, ir.Imm32(imm_offset))} : low_addr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Returns the offset in indices (not bytes) for an equivalent storage instruction
 | 
			
		||||
IR::U32 StorageOffset(IR::Block& block, IR::Block::iterator inst, StorageBufferAddr buffer) {
 | 
			
		||||
    IR::IREmitter ir{block, inst};
 | 
			
		||||
    IR::U32 offset;
 | 
			
		||||
    if (const std::optional<IR::U32> low_addr{TrackLowAddress(ir, &*inst)}) {
 | 
			
		||||
        offset = *low_addr;
 | 
			
		||||
    } else {
 | 
			
		||||
        offset = ir.ConvertU(32, IR::U64{inst->Arg(0)});
 | 
			
		||||
    }
 | 
			
		||||
    // Subtract the least significant 32 bits from the guest offset. The result is the storage
 | 
			
		||||
    // buffer offset in bytes.
 | 
			
		||||
    const IR::U32 low_cbuf{ir.GetCbuf(ir.Imm32(buffer.index), ir.Imm32(buffer.offset))};
 | 
			
		||||
    return ir.ISub(offset, low_cbuf);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Replace a global memory load instruction with its storage buffer equivalent
 | 
			
		||||
void ReplaceLoad(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index,
 | 
			
		||||
                 const IR::U32& offset) {
 | 
			
		||||
    const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())};
 | 
			
		||||
    const IR::Value value{&*block.PrependNewInst(inst, new_opcode, {storage_index, offset})};
 | 
			
		||||
    inst->ReplaceUsesWith(value);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Replace a global memory write instruction with its storage buffer equivalent
 | 
			
		||||
void ReplaceWrite(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index,
 | 
			
		||||
                  const IR::U32& offset) {
 | 
			
		||||
    const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())};
 | 
			
		||||
    block.PrependNewInst(inst, new_opcode, {storage_index, offset, inst->Arg(1)});
 | 
			
		||||
    inst->Invalidate();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Replace a global memory instruction with its storage buffer equivalent
 | 
			
		||||
void Replace(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index,
 | 
			
		||||
             const IR::U32& offset) {
 | 
			
		||||
    switch (inst->Opcode()) {
 | 
			
		||||
    case IR::Opcode::LoadGlobalS8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU8:
 | 
			
		||||
    case IR::Opcode::LoadGlobalS16:
 | 
			
		||||
    case IR::Opcode::LoadGlobalU16:
 | 
			
		||||
    case IR::Opcode::LoadGlobal32:
 | 
			
		||||
    case IR::Opcode::LoadGlobal64:
 | 
			
		||||
    case IR::Opcode::LoadGlobal128:
 | 
			
		||||
        return ReplaceLoad(block, inst, storage_index, offset);
 | 
			
		||||
    case IR::Opcode::WriteGlobalS8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU8:
 | 
			
		||||
    case IR::Opcode::WriteGlobalS16:
 | 
			
		||||
    case IR::Opcode::WriteGlobalU16:
 | 
			
		||||
    case IR::Opcode::WriteGlobal32:
 | 
			
		||||
    case IR::Opcode::WriteGlobal64:
 | 
			
		||||
    case IR::Opcode::WriteGlobal128:
 | 
			
		||||
        return ReplaceWrite(block, inst, storage_index, offset);
 | 
			
		||||
    default:
 | 
			
		||||
        throw InvalidArgument("Invalid global memory opcode {}", inst->Opcode());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
} // Anonymous namespace
 | 
			
		||||
 | 
			
		||||
void GlobalMemoryToStorageBufferPass(IR::Block& block) {
 | 
			
		||||
    StorageBufferSet storage_buffers;
 | 
			
		||||
    StorageInstVector to_replace;
 | 
			
		||||
 | 
			
		||||
    for (IR::Block::iterator inst{block.begin()}; inst != block.end(); ++inst) {
 | 
			
		||||
        if (!IsGlobalMemory(*inst)) {
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        CollectStorageBuffers(block, inst, storage_buffers, to_replace);
 | 
			
		||||
    }
 | 
			
		||||
    for (const auto [storage_buffer, inst] : to_replace) {
 | 
			
		||||
        const auto it{storage_buffers.find(storage_buffer)};
 | 
			
		||||
        const IR::U32 storage_index{IR::Value{static_cast<u32>(storage_buffers.index_of(it))}};
 | 
			
		||||
        const IR::U32 offset{StorageOffset(block, inst, storage_buffer)};
 | 
			
		||||
        Replace(block, inst, storage_index, offset);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::Optimization
 | 
			
		||||
@@ -10,22 +10,24 @@
 | 
			
		||||
 | 
			
		||||
namespace Shader::Optimization {
 | 
			
		||||
 | 
			
		||||
void IdentityRemovalPass(IR::Block& block) {
 | 
			
		||||
void IdentityRemovalPass(IR::Function& function) {
 | 
			
		||||
    std::vector<IR::Inst*> to_invalidate;
 | 
			
		||||
 | 
			
		||||
    for (auto inst = block.begin(); inst != block.end();) {
 | 
			
		||||
        const size_t num_args{inst->NumArgs()};
 | 
			
		||||
        for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
            IR::Value arg;
 | 
			
		||||
            while ((arg = inst->Arg(i)).IsIdentity()) {
 | 
			
		||||
                inst->SetArg(i, arg.Inst()->Arg(0));
 | 
			
		||||
    for (auto& block : function.blocks) {
 | 
			
		||||
        for (auto inst = block->begin(); inst != block->end();) {
 | 
			
		||||
            const size_t num_args{inst->NumArgs()};
 | 
			
		||||
            for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
                IR::Value arg;
 | 
			
		||||
                while ((arg = inst->Arg(i)).IsIdentity()) {
 | 
			
		||||
                    inst->SetArg(i, arg.Inst()->Arg(0));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            if (inst->Opcode() == IR::Opcode::Identity || inst->Opcode() == IR::Opcode::Void) {
 | 
			
		||||
                to_invalidate.push_back(&*inst);
 | 
			
		||||
                inst = block->Instructions().erase(inst);
 | 
			
		||||
            } else {
 | 
			
		||||
                ++inst;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if (inst->Opcode() == IR::Opcode::Identity || inst->Opcode() == IR::Opcode::Void) {
 | 
			
		||||
            to_invalidate.push_back(&*inst);
 | 
			
		||||
            inst = block.Instructions().erase(inst);
 | 
			
		||||
        } else {
 | 
			
		||||
            ++inst;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    for (IR::Inst* const inst : to_invalidate) {
 | 
			
		||||
 
 | 
			
		||||
@@ -16,9 +16,11 @@ void Invoke(Func&& func, IR::Function& function) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ConstantPropagationPass(IR::Block& block);
 | 
			
		||||
void DeadCodeEliminationPass(IR::Block& block);
 | 
			
		||||
void IdentityRemovalPass(IR::Block& block);
 | 
			
		||||
void GlobalMemoryToStorageBufferPass(IR::Block& block);
 | 
			
		||||
void IdentityRemovalPass(IR::Function& function);
 | 
			
		||||
void SsaRewritePass(IR::Function& function);
 | 
			
		||||
void VerificationPass(const IR::Block& block);
 | 
			
		||||
void VerificationPass(const IR::Function& function);
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::Optimization
 | 
			
		||||
 
 | 
			
		||||
@@ -14,8 +14,6 @@
 | 
			
		||||
//      https://link.springer.com/chapter/10.1007/978-3-642-37051-9_6
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
#include <map>
 | 
			
		||||
 | 
			
		||||
#include <boost/container/flat_map.hpp>
 | 
			
		||||
 | 
			
		||||
#include "shader_recompiler/frontend/ir/basic_block.h"
 | 
			
		||||
@@ -30,6 +28,12 @@ namespace Shader::Optimization {
 | 
			
		||||
namespace {
 | 
			
		||||
using ValueMap = boost::container::flat_map<IR::Block*, IR::Value, std::less<IR::Block*>>;
 | 
			
		||||
 | 
			
		||||
struct FlagTag {};
 | 
			
		||||
struct ZeroFlagTag : FlagTag {};
 | 
			
		||||
struct SignFlagTag : FlagTag {};
 | 
			
		||||
struct CarryFlagTag : FlagTag {};
 | 
			
		||||
struct OverflowFlagTag : FlagTag {};
 | 
			
		||||
 | 
			
		||||
struct DefTable {
 | 
			
		||||
    [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept {
 | 
			
		||||
        return regs[IR::RegIndex(variable)];
 | 
			
		||||
@@ -39,8 +43,28 @@ struct DefTable {
 | 
			
		||||
        return preds[IR::PredIndex(variable)];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept {
 | 
			
		||||
        return zero_flag;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] ValueMap& operator[](SignFlagTag) noexcept {
 | 
			
		||||
        return sign_flag;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] ValueMap& operator[](CarryFlagTag) noexcept {
 | 
			
		||||
        return carry_flag;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[nodiscard]] ValueMap& operator[](OverflowFlagTag) noexcept {
 | 
			
		||||
        return overflow_flag;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::array<ValueMap, IR::NUM_USER_REGS> regs;
 | 
			
		||||
    std::array<ValueMap, IR::NUM_USER_PREDS> preds;
 | 
			
		||||
    ValueMap zero_flag;
 | 
			
		||||
    ValueMap sign_flag;
 | 
			
		||||
    ValueMap carry_flag;
 | 
			
		||||
    ValueMap overflow_flag;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
IR::Opcode UndefOpcode(IR::Reg) noexcept {
 | 
			
		||||
@@ -51,6 +75,10 @@ IR::Opcode UndefOpcode(IR::Pred) noexcept {
 | 
			
		||||
    return IR::Opcode::Undef1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
IR::Opcode UndefOpcode(const FlagTag&) noexcept {
 | 
			
		||||
    return IR::Opcode::Undef1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
[[nodiscard]] bool IsPhi(const IR::Inst& inst) noexcept {
 | 
			
		||||
    return inst.Opcode() == IR::Opcode::Phi;
 | 
			
		||||
}
 | 
			
		||||
@@ -135,6 +163,18 @@ void SsaRewritePass(IR::Function& function) {
 | 
			
		||||
                    pass.WriteVariable(pred, block.get(), inst.Arg(1));
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::SetZFlag:
 | 
			
		||||
                pass.WriteVariable(ZeroFlagTag{}, block.get(), inst.Arg(0));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::SetSFlag:
 | 
			
		||||
                pass.WriteVariable(SignFlagTag{}, block.get(), inst.Arg(0));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::SetCFlag:
 | 
			
		||||
                pass.WriteVariable(CarryFlagTag{}, block.get(), inst.Arg(0));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::SetOFlag:
 | 
			
		||||
                pass.WriteVariable(OverflowFlagTag{}, block.get(), inst.Arg(0));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::GetRegister:
 | 
			
		||||
                if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
 | 
			
		||||
                    inst.ReplaceUsesWith(pass.ReadVariable(reg, block.get()));
 | 
			
		||||
@@ -145,6 +185,18 @@ void SsaRewritePass(IR::Function& function) {
 | 
			
		||||
                    inst.ReplaceUsesWith(pass.ReadVariable(pred, block.get()));
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::GetZFlag:
 | 
			
		||||
                inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block.get()));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::GetSFlag:
 | 
			
		||||
                inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block.get()));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::GetCFlag:
 | 
			
		||||
                inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block.get()));
 | 
			
		||||
                break;
 | 
			
		||||
            case IR::Opcode::GetOFlag:
 | 
			
		||||
                inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block.get()));
 | 
			
		||||
                break;
 | 
			
		||||
            default:
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 
 | 
			
		||||
@@ -11,40 +11,44 @@
 | 
			
		||||
 | 
			
		||||
namespace Shader::Optimization {
 | 
			
		||||
 | 
			
		||||
static void ValidateTypes(const IR::Block& block) {
 | 
			
		||||
    for (const IR::Inst& inst : block) {
 | 
			
		||||
        const size_t num_args{inst.NumArgs()};
 | 
			
		||||
        for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
            const IR::Type t1{inst.Arg(i).Type()};
 | 
			
		||||
            const IR::Type t2{IR::ArgTypeOf(inst.Opcode(), i)};
 | 
			
		||||
            if (!IR::AreTypesCompatible(t1, t2)) {
 | 
			
		||||
                throw LogicError("Invalid types in block:\n{}", IR::DumpBlock(block));
 | 
			
		||||
static void ValidateTypes(const IR::Function& function) {
 | 
			
		||||
    for (const auto& block : function.blocks) {
 | 
			
		||||
        for (const IR::Inst& inst : *block) {
 | 
			
		||||
            const size_t num_args{inst.NumArgs()};
 | 
			
		||||
            for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
                const IR::Type t1{inst.Arg(i).Type()};
 | 
			
		||||
                const IR::Type t2{IR::ArgTypeOf(inst.Opcode(), i)};
 | 
			
		||||
                if (!IR::AreTypesCompatible(t1, t2)) {
 | 
			
		||||
                    throw LogicError("Invalid types in block:\n{}", IR::DumpBlock(*block));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void ValidateUses(const IR::Block& block) {
 | 
			
		||||
static void ValidateUses(const IR::Function& function) {
 | 
			
		||||
    std::map<IR::Inst*, int> actual_uses;
 | 
			
		||||
    for (const IR::Inst& inst : block) {
 | 
			
		||||
        const size_t num_args{inst.NumArgs()};
 | 
			
		||||
        for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
            const IR::Value arg{inst.Arg(i)};
 | 
			
		||||
            if (!arg.IsImmediate()) {
 | 
			
		||||
                ++actual_uses[arg.Inst()];
 | 
			
		||||
    for (const auto& block : function.blocks) {
 | 
			
		||||
        for (const IR::Inst& inst : *block) {
 | 
			
		||||
            const size_t num_args{inst.NumArgs()};
 | 
			
		||||
            for (size_t i = 0; i < num_args; ++i) {
 | 
			
		||||
                const IR::Value arg{inst.Arg(i)};
 | 
			
		||||
                if (!arg.IsImmediate()) {
 | 
			
		||||
                    ++actual_uses[arg.Inst()];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    for (const auto [inst, uses] : actual_uses) {
 | 
			
		||||
        if (inst->UseCount() != uses) {
 | 
			
		||||
            throw LogicError("Invalid uses in block:\n{}", IR::DumpBlock(block));
 | 
			
		||||
            throw LogicError("Invalid uses in block:" /*, IR::DumpFunction(function)*/);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void VerificationPass(const IR::Block& block) {
 | 
			
		||||
    ValidateTypes(block);
 | 
			
		||||
    ValidateUses(block);
 | 
			
		||||
void VerificationPass(const IR::Function& function) {
 | 
			
		||||
    ValidateTypes(function);
 | 
			
		||||
    ValidateUses(function);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::Optimization
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user