mirror of
				https://git.suyu.dev/suyu/suyu
				synced 2025-11-04 00:49:02 -06:00 
			
		
		
		
	shader/half_set_predicate: Fix HSETP2 implementation
This commit is contained in:
		@@ -254,10 +254,6 @@ public:
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    using OperationDecompilerFn = std::string (GLSLDecompiler::*)(Operation);
 | 
			
		||||
    using OperationDecompilersArray =
 | 
			
		||||
        std::array<OperationDecompilerFn, static_cast<std::size_t>(OperationCode::Amount)>;
 | 
			
		||||
 | 
			
		||||
    void DeclareVertex() {
 | 
			
		||||
        if (stage != ShaderStage::Vertex)
 | 
			
		||||
            return;
 | 
			
		||||
@@ -1400,14 +1396,10 @@ private:
 | 
			
		||||
        return fmt::format("{}[{}]", pair, VisitOperand(operation, 1, Type::Uint));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::string LogicalAll2(Operation operation) {
 | 
			
		||||
    std::string LogicalAnd2(Operation operation) {
 | 
			
		||||
        return GenerateUnary(operation, "all", Type::Bool, Type::Bool2);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::string LogicalAny2(Operation operation) {
 | 
			
		||||
        return GenerateUnary(operation, "any", Type::Bool, Type::Bool2);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <bool with_nan>
 | 
			
		||||
    std::string GenerateHalfComparison(Operation operation, const std::string& compare_op) {
 | 
			
		||||
        const std::string comparison{GenerateBinaryCall(operation, compare_op, Type::Bool2,
 | 
			
		||||
@@ -1714,7 +1706,7 @@ private:
 | 
			
		||||
        return "utof(gl_WorkGroupID"s + GetSwizzle(element) + ')';
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static constexpr OperationDecompilersArray operation_decompilers = {
 | 
			
		||||
    static constexpr std::array operation_decompilers = {
 | 
			
		||||
        &GLSLDecompiler::Assign,
 | 
			
		||||
 | 
			
		||||
        &GLSLDecompiler::Select,
 | 
			
		||||
@@ -1798,8 +1790,7 @@ private:
 | 
			
		||||
        &GLSLDecompiler::LogicalXor,
 | 
			
		||||
        &GLSLDecompiler::LogicalNegate,
 | 
			
		||||
        &GLSLDecompiler::LogicalPick2,
 | 
			
		||||
        &GLSLDecompiler::LogicalAll2,
 | 
			
		||||
        &GLSLDecompiler::LogicalAny2,
 | 
			
		||||
        &GLSLDecompiler::LogicalAnd2,
 | 
			
		||||
 | 
			
		||||
        &GLSLDecompiler::LogicalLessThan<Type::Float>,
 | 
			
		||||
        &GLSLDecompiler::LogicalEqual<Type::Float>,
 | 
			
		||||
@@ -1863,6 +1854,7 @@ private:
 | 
			
		||||
        &GLSLDecompiler::WorkGroupId<1>,
 | 
			
		||||
        &GLSLDecompiler::WorkGroupId<2>,
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(operation_decompilers.size() == static_cast<std::size_t>(OperationCode::Amount));
 | 
			
		||||
 | 
			
		||||
    std::string GetRegister(u32 index) const {
 | 
			
		||||
        return GetDeclarationWithSuffix(index, "gpr");
 | 
			
		||||
 
 | 
			
		||||
@@ -205,10 +205,6 @@ public:
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation);
 | 
			
		||||
    using OperationDecompilersArray =
 | 
			
		||||
        std::array<OperationDecompilerFn, static_cast<std::size_t>(OperationCode::Amount)>;
 | 
			
		||||
 | 
			
		||||
    static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
 | 
			
		||||
 | 
			
		||||
    void AllocateBindings() {
 | 
			
		||||
@@ -804,12 +800,7 @@ private:
 | 
			
		||||
        return {};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Id LogicalAll2(Operation operation) {
 | 
			
		||||
        UNIMPLEMENTED();
 | 
			
		||||
        return {};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Id LogicalAny2(Operation operation) {
 | 
			
		||||
    Id LogicalAnd2(Operation operation) {
 | 
			
		||||
        UNIMPLEMENTED();
 | 
			
		||||
        return {};
 | 
			
		||||
    }
 | 
			
		||||
@@ -1206,7 +1197,7 @@ private:
 | 
			
		||||
        return {};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static constexpr OperationDecompilersArray operation_decompilers = {
 | 
			
		||||
    static constexpr std::array operation_decompilers = {
 | 
			
		||||
        &SPIRVDecompiler::Assign,
 | 
			
		||||
 | 
			
		||||
        &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float,
 | 
			
		||||
@@ -1291,8 +1282,7 @@ private:
 | 
			
		||||
        &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
 | 
			
		||||
        &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
 | 
			
		||||
        &SPIRVDecompiler::LogicalPick2,
 | 
			
		||||
        &SPIRVDecompiler::LogicalAll2,
 | 
			
		||||
        &SPIRVDecompiler::LogicalAny2,
 | 
			
		||||
        &SPIRVDecompiler::LogicalAnd2,
 | 
			
		||||
 | 
			
		||||
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
 | 
			
		||||
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
 | 
			
		||||
@@ -1357,6 +1347,7 @@ private:
 | 
			
		||||
        &SPIRVDecompiler::WorkGroupId<1>,
 | 
			
		||||
        &SPIRVDecompiler::WorkGroupId<2>,
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(operation_decompilers.size() == static_cast<std::size_t>(OperationCode::Amount));
 | 
			
		||||
 | 
			
		||||
    const VKDevice& device;
 | 
			
		||||
    const ShaderIR& ir;
 | 
			
		||||
 
 | 
			
		||||
@@ -51,26 +51,23 @@ u32 ShaderIR::DecodeHalfSetPredicate(NodeBlock& bb, u32 pc) {
 | 
			
		||||
        op_b = Immediate(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // We can't use the constant predicate as destination.
 | 
			
		||||
    ASSERT(instr.hsetp2.pred3 != static_cast<u64>(Pred::UnusedIndex));
 | 
			
		||||
 | 
			
		||||
    const Node second_pred = GetPredicate(instr.hsetp2.pred39, instr.hsetp2.neg_pred != 0);
 | 
			
		||||
 | 
			
		||||
    const OperationCode combiner = GetPredicateCombiner(instr.hsetp2.op);
 | 
			
		||||
    const OperationCode pair_combiner =
 | 
			
		||||
        h_and ? OperationCode::LogicalAll2 : OperationCode::LogicalAny2;
 | 
			
		||||
    const Node pred39 = GetPredicate(instr.hsetp2.pred39, instr.hsetp2.neg_pred);
 | 
			
		||||
 | 
			
		||||
    const auto Write = [&](u64 dest, Node src) {
 | 
			
		||||
        SetPredicate(bb, dest, Operation(combiner, std::move(src), pred39));
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    const Node comparison = GetPredicateComparisonHalf(cond, op_a, op_b);
 | 
			
		||||
    const Node first_pred = Operation(pair_combiner, comparison);
 | 
			
		||||
 | 
			
		||||
    // Set the primary predicate to the result of Predicate OP SecondPredicate
 | 
			
		||||
    const Node value = Operation(combiner, first_pred, second_pred);
 | 
			
		||||
    SetPredicate(bb, instr.hsetp2.pred3, value);
 | 
			
		||||
 | 
			
		||||
    if (instr.hsetp2.pred0 != static_cast<u64>(Pred::UnusedIndex)) {
 | 
			
		||||
        // Set the secondary predicate to the result of !Predicate OP SecondPredicate, if enabled
 | 
			
		||||
        const Node negated_pred = Operation(OperationCode::LogicalNegate, first_pred);
 | 
			
		||||
        SetPredicate(bb, instr.hsetp2.pred0, Operation(combiner, negated_pred, second_pred));
 | 
			
		||||
    const u64 first = instr.hsetp2.pred0;
 | 
			
		||||
    const u64 second = instr.hsetp2.pred3;
 | 
			
		||||
    if (h_and) {
 | 
			
		||||
        const Node joined = Operation(OperationCode::LogicalAnd2, comparison);
 | 
			
		||||
        Write(first, joined);
 | 
			
		||||
        Write(second, Operation(OperationCode::LogicalNegate, joined));
 | 
			
		||||
    } else {
 | 
			
		||||
        Write(first, Operation(OperationCode::LogicalPick2, comparison, Immediate(0u)));
 | 
			
		||||
        Write(second, Operation(OperationCode::LogicalPick2, comparison, Immediate(1u)));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return pc;
 | 
			
		||||
 
 | 
			
		||||
@@ -101,8 +101,7 @@ enum class OperationCode {
 | 
			
		||||
    LogicalXor,    /// (bool a, bool b) -> bool
 | 
			
		||||
    LogicalNegate, /// (bool a) -> bool
 | 
			
		||||
    LogicalPick2,  /// (bool2 pair, uint index) -> bool
 | 
			
		||||
    LogicalAll2,   /// (bool2 a) -> bool
 | 
			
		||||
    LogicalAny2,   /// (bool2 a) -> bool
 | 
			
		||||
    LogicalAnd2,   /// (bool2 a) -> bool
 | 
			
		||||
 | 
			
		||||
    LogicalFLessThan,     /// (float a, float b) -> bool
 | 
			
		||||
    LogicalFEqual,        /// (float a, float b) -> bool
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user