	shader: Rework global memory tracking to use breadth-first search
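This change replaces the recursive, block-walking storage buffer tracking with a breadth-first search over the operands of each global memory instruction. Track() no longer follows phi nodes across blocks with a VisitedBlocks set; instead it pushes instructions onto a std::queue, tests each one with the new TryGetStorageBuffer() helper, and remembers visited instructions in a small_vector. The pass-local state (the storage buffer set, the instructions to replace, and the set of written buffers) is also bundled into a new StorageInfo struct that CollectStorageBuffers() and GlobalMemoryToStorageBufferPass() pass around instead of three separate arguments.

For reference, a minimal, self-contained sketch of the same breadth-first pattern is shown below. It is not project code: Node, operands, and is_target are hypothetical stand-ins for IR::Inst, its arguments, and the TryGetStorageBuffer() check.

    #include <algorithm>
    #include <optional>
    #include <queue>
    #include <vector>

    struct Node {
        bool is_target{};
        std::vector<const Node*> operands;
    };

    // Breadth-first search for the first node that satisfies the predicate,
    // visiting operands right-to-left like the reworked Track() does.
    std::optional<const Node*> FindTarget(const Node* root) {
        std::vector<const Node*> visited; // linear scan is fine for small searches
        std::queue<const Node*> queue;
        queue.push(root);
        while (!queue.empty()) {
            const Node* const node{queue.front()};
            queue.pop();
            if (node->is_target) {
                return node; // first match in breadth-first order
            }
            for (size_t arg = node->operands.size(); arg--;) {
                const Node* const operand{node->operands[arg]};
                if (std::ranges::find(visited, operand) == visited.end()) {
                    visited.push_back(operand);
                    queue.push(operand);
                }
            }
        }
        return std::nullopt; // graph exhausted without a match
    }

    int main() {
        Node target{.is_target = true};
        Node mid{.operands = {&target}};
        Node root{.operands = {&mid, &target}};
        return FindTarget(&root).has_value() ? 0 : 1;
    }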
@@ -4,9 +4,9 @@
 
 #include <algorithm>
 #include <compare>
 #include <map>
 #include <optional>
-#include <ranges>
+#include <queue>
 
 #include <boost/container/flat_set.hpp>
 #include <boost/container/small_vector.hpp>
@@ -40,15 +40,19 @@ struct Bias {
     u32 offset_end;
 };
 
+using boost::container::flat_set;
+using boost::container::small_vector;
 using StorageBufferSet =
-    boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>,
-                               boost::container::small_vector<StorageBufferAddr, 16>>;
-using StorageInstVector = boost::container::small_vector<StorageInst, 24>;
-using VisitedBlocks = boost::container::flat_set<IR::Block*, std::less<IR::Block*>,
-                                                 boost::container::small_vector<IR::Block*, 4>>;
+    flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, small_vector<StorageBufferAddr, 16>>;
+using StorageInstVector = small_vector<StorageInst, 24>;
 using StorageWritesSet =
-    boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>,
-                               boost::container::small_vector<StorageBufferAddr, 16>>;
+    flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, small_vector<StorageBufferAddr, 16>>;
+
+struct StorageInfo {
+    StorageBufferSet set;
+    StorageInstVector to_replace;
+    StorageWritesSet writes;
+};
 
 /// Returns true when the instruction is a global memory instruction
 bool IsGlobalMemory(const IR::Inst& inst) {
@@ -215,60 +219,72 @@ std::optional<LowAddrInfo> TrackLowAddress(IR::Inst* inst) {
     };
 }
 
-/// Recursively tries to track the storage buffer address used by a global memory instruction
-std::optional<StorageBufferAddr> Track(IR::Block* block, const IR::Value& value, const Bias* bias,
-                                       VisitedBlocks& visited) {
-    if (value.IsImmediate()) {
-        // Immediates can't be a storage buffer
-        return std::nullopt;
-    }
-    const IR::Inst* const inst{value.InstRecursive()};
-    if (inst->Opcode() == IR::Opcode::GetCbufU32) {
-        const IR::Value index{inst->Arg(0)};
-        const IR::Value offset{inst->Arg(1)};
-        if (!index.IsImmediate()) {
-            // Definitely not a storage buffer if it's read from a non-immediate index
-            return std::nullopt;
-        }
-        if (!offset.IsImmediate()) {
-            // TODO: Support SSBO arrays
-            return std::nullopt;
-        }
-        const StorageBufferAddr storage_buffer{
-            .index{index.U32()},
-            .offset{offset.U32()},
-        };
-        if (bias && !MeetsBias(storage_buffer, *bias)) {
-            // We have to blacklist some addresses in case we wrongly point to them
-            return std::nullopt;
-        }
-        return storage_buffer;
-    }
-    // Reversed loops are more likely to find the right result
-    for (size_t arg = inst->NumArgs(); arg--;) {
-        IR::Block* inst_block{block};
-        if (inst->Opcode() == IR::Opcode::Phi) {
-            // If we are going through a phi node, mark the current block as visited
-            visited.insert(block);
-            // and skip already visited blocks to avoid looping forever
-            IR::Block* const phi_block{inst->PhiBlock(arg)};
-            if (visited.contains(phi_block)) {
-                // Already visited, skip
-                continue;
-            }
-            inst_block = phi_block;
-        }
-        const std::optional storage_buffer{Track(inst_block, inst->Arg(arg), bias, visited)};
-        if (storage_buffer) {
-            return *storage_buffer;
-        }
-    }
-    return std::nullopt;
-}
+/// Tries to get the storage buffer out of a constant buffer read instruction
+std::optional<StorageBufferAddr> TryGetStorageBuffer(const IR::Inst* inst, const Bias* bias) {
+    if (inst->Opcode() != IR::Opcode::GetCbufU32) {
+        return std::nullopt;
+    }
+    const IR::Value index{inst->Arg(0)};
+    const IR::Value offset{inst->Arg(1)};
+    if (!index.IsImmediate()) {
+        // Definitely not a storage buffer if it's read from a non-immediate index
+        return std::nullopt;
+    }
+    if (!offset.IsImmediate()) {
+        // TODO: Support SSBO arrays
+        return std::nullopt;
+    }
+    const StorageBufferAddr storage_buffer{
+        .index{index.U32()},
+        .offset{offset.U32()},
+    };
+    if (bias && !MeetsBias(storage_buffer, *bias)) {
+        // We have to blacklist some addresses in case we wrongly point to them
+        return std::nullopt;
+    }
+    return storage_buffer;
+}
+
+/// Tries to track the storage buffer address used by a global memory instruction
+std::optional<StorageBufferAddr> Track(const IR::Value& value, const Bias* bias) {
+    if (value.IsImmediate()) {
+        // Nothing to do with immediates
+        return std::nullopt;
+    }
+    // Breadth-first search visiting the right most arguments first
+    // Small vector has been determined from shaders in Super Smash Bros. Ultimate
+    small_vector<const IR::Inst*, 2> visited;
+    std::queue<const IR::Inst*> queue;
+    queue.push(value.InstRecursive());
+
+    while (!queue.empty()) {
+        // Pop one instruction from the queue
+        const IR::Inst* const inst{queue.front()};
+        queue.pop();
+        if (const std::optional<StorageBufferAddr> result = TryGetStorageBuffer(inst, bias)) {
+            // This is the instruction we were looking for
+            return result;
+        }
+        // Visit the right most arguments first
+        for (size_t arg = inst->NumArgs(); arg--;) {
+            const IR::Value arg_value{inst->Arg(arg)};
+            if (arg_value.IsImmediate()) {
+                continue;
+            }
+            // Queue instruction if it hasn't been visited
+            const IR::Inst* const arg_inst{arg_value.InstRecursive()};
+            if (std::ranges::find(visited, arg_inst) == visited.end()) {
+                visited.push_back(arg_inst);
+                queue.push(arg_inst);
+            }
+        }
+    }
+    // SSA tree has been traversed and the origin hasn't been found
+    return std::nullopt;
+}
 
 /// Collects the storage buffer used by a global memory instruction and the instruction itself
-void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& storage_buffer_set,
-                           StorageInstVector& to_replace, StorageWritesSet& writes_set) {
+void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageInfo& info) {
     // NVN puts storage buffers in a specific range, we have to bias towards these addresses to
     // avoid getting false positives
     static constexpr Bias nvn_bias{
@@ -284,24 +300,23 @@ void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& s
     }
     // First try to find storage buffers in the NVN address
     const IR::U32 low_addr{low_addr_info->value};
-    VisitedBlocks visited_blocks;
-    std::optional storage_buffer{Track(&block, low_addr, &nvn_bias, visited_blocks)};
+    std::optional storage_buffer{Track(low_addr, &nvn_bias)};
     if (!storage_buffer) {
         // If it fails, track without a bias
-        visited_blocks.clear();
-        storage_buffer = Track(&block, low_addr, nullptr, visited_blocks);
+        storage_buffer = Track(low_addr, nullptr);
         if (!storage_buffer) {
             // If that also failed, drop the global memory usage
+            // LOG_ERROR
             DiscardGlobalMemory(block, inst);
             return;
         }
     }
     // Collect storage buffer and the instruction
     if (IsGlobalMemoryWrite(inst)) {
-        writes_set.insert(*storage_buffer);
+        info.writes.insert(*storage_buffer);
     }
-    storage_buffer_set.insert(*storage_buffer);
-    to_replace.push_back(StorageInst{
+    info.set.insert(*storage_buffer);
+    info.to_replace.push_back(StorageInst{
         .storage_buffer{*storage_buffer},
         .inst{&inst},
         .block{&block},
@@ -371,33 +386,29 @@ void Replace(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index,
 } // Anonymous namespace
 
 void GlobalMemoryToStorageBufferPass(IR::Program& program) {
-    StorageBufferSet storage_buffers;
-    StorageInstVector to_replace;
-    StorageWritesSet writes_set;
-
+    StorageInfo info;
     for (IR::Block* const block : program.post_order_blocks) {
         for (IR::Inst& inst : block->Instructions()) {
             if (!IsGlobalMemory(inst)) {
                 continue;
             }
-            CollectStorageBuffers(*block, inst, storage_buffers, to_replace, writes_set);
+            CollectStorageBuffers(*block, inst, info);
         }
     }
-    Info& info{program.info};
     u32 storage_index{};
-    for (const StorageBufferAddr& storage_buffer : storage_buffers) {
-        info.storage_buffers_descriptors.push_back({
+    for (const StorageBufferAddr& storage_buffer : info.set) {
+        program.info.storage_buffers_descriptors.push_back({
             .cbuf_index{storage_buffer.index},
             .cbuf_offset{storage_buffer.offset},
             .count{1},
-            .is_written{writes_set.contains(storage_buffer)},
+            .is_written{info.writes.contains(storage_buffer)},
         });
         ++storage_index;
     }
-    for (const StorageInst& storage_inst : to_replace) {
+    for (const StorageInst& storage_inst : info.to_replace) {
         const StorageBufferAddr storage_buffer{storage_inst.storage_buffer};
-        const auto it{storage_buffers.find(storage_inst.storage_buffer)};
-        const IR::U32 index{IR::Value{static_cast<u32>(storage_buffers.index_of(it))}};
+        const auto it{info.set.find(storage_inst.storage_buffer)};
+        const IR::U32 index{IR::Value{static_cast<u32>(info.set.index_of(it))}};
         IR::Block* const block{storage_inst.block};
         IR::Inst* const inst{storage_inst.inst};
         const IR::U32 offset{StorageOffset(*block, *inst, storage_buffer)};