Merge pull request #3032 from ReinUsesLisp/simplify-control-flow-brx
shader/control_flow: Abstract repeated code chunks in BRX tracking
This commit is contained in:
		@@ -16,7 +16,9 @@
 | 
			
		||||
#include "video_core/shader/shader_ir.h"
 | 
			
		||||
 | 
			
		||||
namespace VideoCommon::Shader {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using Tegra::Shader::Instruction;
 | 
			
		||||
using Tegra::Shader::OpCode;
 | 
			
		||||
 | 
			
		||||
@@ -68,15 +70,15 @@ struct CFGRebuildState {
 | 
			
		||||
    const ProgramCode& program_code;
 | 
			
		||||
    ConstBufferLocker& locker;
 | 
			
		||||
    u32 start{};
 | 
			
		||||
    std::vector<BlockInfo> block_info{};
 | 
			
		||||
    std::list<u32> inspect_queries{};
 | 
			
		||||
    std::list<Query> queries{};
 | 
			
		||||
    std::unordered_map<u32, u32> registered{};
 | 
			
		||||
    std::set<u32> labels{};
 | 
			
		||||
    std::map<u32, u32> ssy_labels{};
 | 
			
		||||
    std::map<u32, u32> pbk_labels{};
 | 
			
		||||
    std::unordered_map<u32, BlockStack> stacks{};
 | 
			
		||||
    ASTManager* manager;
 | 
			
		||||
    std::vector<BlockInfo> block_info;
 | 
			
		||||
    std::list<u32> inspect_queries;
 | 
			
		||||
    std::list<Query> queries;
 | 
			
		||||
    std::unordered_map<u32, u32> registered;
 | 
			
		||||
    std::set<u32> labels;
 | 
			
		||||
    std::map<u32, u32> ssy_labels;
 | 
			
		||||
    std::map<u32, u32> pbk_labels;
 | 
			
		||||
    std::unordered_map<u32, BlockStack> stacks;
 | 
			
		||||
    ASTManager* manager{};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class BlockCollision : u32 { None, Found, Inside };
 | 
			
		||||
@@ -109,7 +111,7 @@ BlockInfo& CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 * Builds a predicate enumerator from a predicate index and a negation flag.
 *
 * @param index   Predicate index.
 * @param negated When true, 8 is added to the index before the conversion.
 * @return The resulting value converted to Pred.
 */
Pred GetPredicate(u32 index, bool negated) {
    // Widen to 64 bits before adding so the sum cannot wrap in 32-bit arithmetic.
    const u64 negation_offset = negated ? 8ULL : 0ULL;
    return static_cast<Pred>(static_cast<u64>(index) + negation_offset);
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
@@ -136,15 +138,13 @@ struct BranchIndirectInfo {
 | 
			
		||||
    s32 relative_position{};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state,
 | 
			
		||||
                                                          u32 start_address, u32 current_position) {
 | 
			
		||||
    const u32 shader_start = state.start;
 | 
			
		||||
    u32 pos = current_position;
 | 
			
		||||
    BranchIndirectInfo result{};
 | 
			
		||||
    u64 track_register = 0;
 | 
			
		||||
struct BufferInfo {
 | 
			
		||||
    u32 index;
 | 
			
		||||
    u32 offset;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
    // Step 0 Get BRX Info
 | 
			
		||||
    const Instruction instr = {state.program_code[pos]};
 | 
			
		||||
std::optional<std::pair<s32, u64>> GetBRXInfo(const CFGRebuildState& state, u32& pos) {
 | 
			
		||||
    const Instruction instr = state.program_code[pos];
 | 
			
		||||
    const auto opcode = OpCode::Decode(instr);
 | 
			
		||||
    if (opcode->get().GetId() != OpCode::Id::BRX) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
@@ -152,86 +152,94 @@ std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState&
 | 
			
		||||
    if (instr.brx.constant_buffer != 0) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    track_register = instr.gpr8.Value();
 | 
			
		||||
    result.relative_position = instr.brx.GetBranchExtend();
 | 
			
		||||
    pos--;
 | 
			
		||||
    bool found_track = false;
 | 
			
		||||
    --pos;
 | 
			
		||||
    return std::make_pair(instr.brx.GetBranchExtend(), instr.gpr8.Value());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    // Step 1 Track LDC
 | 
			
		||||
    while (pos >= shader_start) {
 | 
			
		||||
        if (IsSchedInstruction(pos, shader_start)) {
 | 
			
		||||
            pos--;
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        const Instruction instr = {state.program_code[pos]};
 | 
			
		||||
        const auto opcode = OpCode::Decode(instr);
 | 
			
		||||
        if (opcode->get().GetId() == OpCode::Id::LD_C) {
 | 
			
		||||
            if (instr.gpr0.Value() == track_register &&
 | 
			
		||||
                instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single) {
 | 
			
		||||
                result.buffer = instr.cbuf36.index.Value();
 | 
			
		||||
                result.offset = static_cast<u32>(instr.cbuf36.GetOffset());
 | 
			
		||||
                track_register = instr.gpr8.Value();
 | 
			
		||||
                pos--;
 | 
			
		||||
                found_track = true;
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        pos--;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (!found_track) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    found_track = false;
 | 
			
		||||
 | 
			
		||||
    // Step 2 Track SHL
 | 
			
		||||
    while (pos >= shader_start) {
 | 
			
		||||
        if (IsSchedInstruction(pos, shader_start)) {
 | 
			
		||||
            pos--;
 | 
			
		||||
template <typename Result, typename TestCallable, typename PackCallable>
 | 
			
		||||
// requires std::predicate<TestCallable, Instruction, const OpCode::Matcher&>
 | 
			
		||||
// requires std::invocable<PackCallable, Instruction, const OpCode::Matcher&>
 | 
			
		||||
std::optional<Result> TrackInstruction(const CFGRebuildState& state, u32& pos, TestCallable test,
 | 
			
		||||
                                       PackCallable pack) {
 | 
			
		||||
    for (; pos >= state.start; --pos) {
 | 
			
		||||
        if (IsSchedInstruction(pos, state.start)) {
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        const Instruction instr = state.program_code[pos];
 | 
			
		||||
        const auto opcode = OpCode::Decode(instr);
 | 
			
		||||
        if (opcode->get().GetId() == OpCode::Id::SHL_IMM) {
 | 
			
		||||
            if (instr.gpr0.Value() == track_register) {
 | 
			
		||||
                track_register = instr.gpr8.Value();
 | 
			
		||||
                pos--;
 | 
			
		||||
                found_track = true;
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        pos--;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (!found_track) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    found_track = false;
 | 
			
		||||
 | 
			
		||||
    // Step 3 Track IMNMX
 | 
			
		||||
    while (pos >= shader_start) {
 | 
			
		||||
        if (IsSchedInstruction(pos, shader_start)) {
 | 
			
		||||
            pos--;
 | 
			
		||||
        if (!opcode) {
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        const Instruction instr = state.program_code[pos];
 | 
			
		||||
        const auto opcode = OpCode::Decode(instr);
 | 
			
		||||
        if (opcode->get().GetId() == OpCode::Id::IMNMX_IMM) {
 | 
			
		||||
            if (instr.gpr0.Value() == track_register) {
 | 
			
		||||
                track_register = instr.gpr8.Value();
 | 
			
		||||
                result.entries = instr.alu.GetSignedImm20_20() + 1;
 | 
			
		||||
                pos--;
 | 
			
		||||
                found_track = true;
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        if (test(instr, opcode->get())) {
 | 
			
		||||
            --pos;
 | 
			
		||||
            return std::make_optional(pack(instr, opcode->get()));
 | 
			
		||||
        }
 | 
			
		||||
        pos--;
 | 
			
		||||
    }
 | 
			
		||||
    return std::nullopt;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    if (!found_track) {
 | 
			
		||||
std::optional<std::pair<BufferInfo, u64>> TrackLDC(const CFGRebuildState& state, u32& pos,
 | 
			
		||||
                                                   u64 brx_tracked_register) {
 | 
			
		||||
    return TrackInstruction<std::pair<BufferInfo, u64>>(
 | 
			
		||||
        state, pos,
 | 
			
		||||
        [brx_tracked_register](auto instr, const auto& opcode) {
 | 
			
		||||
            return opcode.GetId() == OpCode::Id::LD_C &&
 | 
			
		||||
                   instr.gpr0.Value() == brx_tracked_register &&
 | 
			
		||||
                   instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single;
 | 
			
		||||
        },
 | 
			
		||||
        [](auto instr, const auto& opcode) {
 | 
			
		||||
            const BufferInfo info = {static_cast<u32>(instr.cbuf36.index.Value()),
 | 
			
		||||
                                     static_cast<u32>(instr.cbuf36.GetOffset())};
 | 
			
		||||
            return std::make_pair(info, instr.gpr8.Value());
 | 
			
		||||
        });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<u64> TrackSHLRegister(const CFGRebuildState& state, u32& pos,
 | 
			
		||||
                                    u64 ldc_tracked_register) {
 | 
			
		||||
    return TrackInstruction<u64>(state, pos,
 | 
			
		||||
                                 [ldc_tracked_register](auto instr, const auto& opcode) {
 | 
			
		||||
                                     return opcode.GetId() == OpCode::Id::SHL_IMM &&
 | 
			
		||||
                                            instr.gpr0.Value() == ldc_tracked_register;
 | 
			
		||||
                                 },
 | 
			
		||||
                                 [](auto instr, const auto&) { return instr.gpr8.Value(); });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<u32> TrackIMNMXValue(const CFGRebuildState& state, u32& pos,
 | 
			
		||||
                                   u64 shl_tracked_register) {
 | 
			
		||||
    return TrackInstruction<u32>(state, pos,
 | 
			
		||||
                                 [shl_tracked_register](auto instr, const auto& opcode) {
 | 
			
		||||
                                     return opcode.GetId() == OpCode::Id::IMNMX_IMM &&
 | 
			
		||||
                                            instr.gpr0.Value() == shl_tracked_register;
 | 
			
		||||
                                 },
 | 
			
		||||
                                 [](auto instr, const auto&) {
 | 
			
		||||
                                     return static_cast<u32>(instr.alu.GetSignedImm20_20() + 1);
 | 
			
		||||
                                 });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state, u32 pos) {
 | 
			
		||||
    const auto brx_info = GetBRXInfo(state, pos);
 | 
			
		||||
    if (!brx_info) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    return result;
 | 
			
		||||
    const auto [relative_position, brx_tracked_register] = *brx_info;
 | 
			
		||||
 | 
			
		||||
    const auto ldc_info = TrackLDC(state, pos, brx_tracked_register);
 | 
			
		||||
    if (!ldc_info) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    const auto [buffer_info, ldc_tracked_register] = *ldc_info;
 | 
			
		||||
 | 
			
		||||
    const auto shl_tracked_register = TrackSHLRegister(state, pos, ldc_tracked_register);
 | 
			
		||||
    if (!shl_tracked_register) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const auto entries = TrackIMNMXValue(state, pos, *shl_tracked_register);
 | 
			
		||||
    if (!entries) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return BranchIndirectInfo{buffer_info.index, buffer_info.offset, *entries, relative_position};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address) {
 | 
			
		||||
@@ -420,30 +428,30 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case OpCode::Id::BRX: {
 | 
			
		||||
            auto tmp = TrackBranchIndirectInfo(state, address, offset);
 | 
			
		||||
            if (tmp) {
 | 
			
		||||
                auto result = *tmp;
 | 
			
		||||
                std::vector<CaseBranch> branches{};
 | 
			
		||||
                s32 pc_target = offset + result.relative_position;
 | 
			
		||||
                for (u32 i = 0; i < result.entries; i++) {
 | 
			
		||||
                    auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
 | 
			
		||||
                    if (!k) {
 | 
			
		||||
                        return {ParseResult::AbnormalFlow, parse_info};
 | 
			
		||||
                    }
 | 
			
		||||
                    u32 value = *k;
 | 
			
		||||
                    u32 target = static_cast<u32>((value >> 3) + pc_target);
 | 
			
		||||
                    insert_label(state, target);
 | 
			
		||||
                    branches.emplace_back(value, target);
 | 
			
		||||
                }
 | 
			
		||||
                parse_info.end_address = offset;
 | 
			
		||||
                parse_info.branch_info = MakeBranchInfo<MultiBranch>(
 | 
			
		||||
                    static_cast<u32>(instr.gpr8.Value()), std::move(branches));
 | 
			
		||||
 | 
			
		||||
                return {ParseResult::ControlCaught, parse_info};
 | 
			
		||||
            } else {
 | 
			
		||||
            const auto tmp = TrackBranchIndirectInfo(state, offset);
 | 
			
		||||
            if (!tmp) {
 | 
			
		||||
                LOG_WARNING(HW_GPU, "BRX Track Unsuccesful");
 | 
			
		||||
                return {ParseResult::AbnormalFlow, parse_info};
 | 
			
		||||
            }
 | 
			
		||||
            return {ParseResult::AbnormalFlow, parse_info};
 | 
			
		||||
 | 
			
		||||
            const auto result = *tmp;
 | 
			
		||||
            const s32 pc_target = offset + result.relative_position;
 | 
			
		||||
            std::vector<CaseBranch> branches;
 | 
			
		||||
            for (u32 i = 0; i < result.entries; i++) {
 | 
			
		||||
                auto key = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
 | 
			
		||||
                if (!key) {
 | 
			
		||||
                    return {ParseResult::AbnormalFlow, parse_info};
 | 
			
		||||
                }
 | 
			
		||||
                u32 value = *key;
 | 
			
		||||
                u32 target = static_cast<u32>((value >> 3) + pc_target);
 | 
			
		||||
                insert_label(state, target);
 | 
			
		||||
                branches.emplace_back(value, target);
 | 
			
		||||
            }
 | 
			
		||||
            parse_info.end_address = offset;
 | 
			
		||||
            parse_info.branch_info = MakeBranchInfo<MultiBranch>(
 | 
			
		||||
                static_cast<u32>(instr.gpr8.Value()), std::move(branches));
 | 
			
		||||
 | 
			
		||||
            return {ParseResult::ControlCaught, parse_info};
 | 
			
		||||
        }
 | 
			
		||||
        default:
 | 
			
		||||
            break;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user