shader: Misc fixes
This commit is contained in:
		| @@ -25,6 +25,9 @@ EmitContext::EmitContext(IR::Program& program) { | ||||
|     f16.Define(*this, TypeFloat(16), "f16"); | ||||
|     f64.Define(*this, TypeFloat(64), "f64"); | ||||
|  | ||||
|     true_value = ConstantTrue(u1); | ||||
|     false_value = ConstantFalse(u1); | ||||
|  | ||||
|     for (const IR::Function& function : program.functions) { | ||||
|         for (IR::Block* const block : function.blocks) { | ||||
|             block_label_map.emplace_back(block, OpLabel()); | ||||
| @@ -58,6 +61,7 @@ EmitSPIRV::EmitSPIRV(IR::Program& program) { | ||||
|     std::fclose(file); | ||||
|     std::system("spirv-dis shader.spv"); | ||||
|     std::system("spirv-val shader.spv"); | ||||
|     std::system("spirv-cross shader.spv"); | ||||
| } | ||||
|  | ||||
| template <auto method> | ||||
| @@ -109,6 +113,8 @@ static Id TypeId(const EmitContext& ctx, IR::Type type) { | ||||
|     switch (type) { | ||||
|     case IR::Type::U1: | ||||
|         return ctx.u1; | ||||
|     case IR::Type::U32: | ||||
|         return ctx.u32[1]; | ||||
|     default: | ||||
|         throw NotImplementedException("Phi node type {}", type); | ||||
|     } | ||||
|   | ||||
| @@ -79,6 +79,8 @@ public: | ||||
|             return def_map.Consume(value.Inst()); | ||||
|         } | ||||
|         switch (value.Type()) { | ||||
|         case IR::Type::U1: | ||||
|             return value.U1() ? true_value : false_value; | ||||
|         case IR::Type::U32: | ||||
|             return Constant(u32[1], value.U32()); | ||||
|         case IR::Type::F32: | ||||
| @@ -108,6 +110,9 @@ public: | ||||
|     VectorTypes f16; | ||||
|     VectorTypes f64; | ||||
|  | ||||
|     Id true_value{}; | ||||
|     Id false_value{}; | ||||
|  | ||||
|     Id workgroup_id{}; | ||||
|     Id local_invocation_id{}; | ||||
|  | ||||
|   | ||||
| @@ -113,7 +113,7 @@ static std::string ArgToIndex(const std::map<const Block*, size_t>& block_to_ind | ||||
|     if (arg.IsLabel()) { | ||||
|         return BlockToIndex(block_to_index, arg.Label()); | ||||
|     } | ||||
|     if (!arg.IsImmediate()) { | ||||
|     if (!arg.IsImmediate() || arg.IsIdentity()) { | ||||
|         return fmt::format("%{}", InstIndex(inst_to_index, inst_index, arg.Inst())); | ||||
|     } | ||||
|     switch (arg.Type()) { | ||||
| @@ -166,7 +166,7 @@ std::string DumpBlock(const Block& block, const std::map<const Block*, size_t>& | ||||
|             const std::string arg_str{ArgToIndex(block_to_index, inst_to_index, inst_index, arg)}; | ||||
|             ret += arg_index != 0 ? ", " : " "; | ||||
|             if (op == Opcode::Phi) { | ||||
|                 ret += fmt::format("[ {}, {} ]", arg_index, | ||||
|                 ret += fmt::format("[ {}, {} ]", arg_str, | ||||
|                                    BlockToIndex(block_to_index, inst.PhiBlock(arg_index))); | ||||
|             } else { | ||||
|                 ret += arg_str; | ||||
|   | ||||
| @@ -46,10 +46,12 @@ F64 IREmitter::Imm64(f64 value) const { | ||||
|  | ||||
| void IREmitter::Branch(Block* label) { | ||||
|     label->AddImmediatePredecessor(block); | ||||
|     block->SetBranch(label); | ||||
|     Inst(Opcode::Branch, label); | ||||
| } | ||||
|  | ||||
| void IREmitter::BranchConditional(const U1& condition, Block* true_label, Block* false_label) { | ||||
|     block->SetBranches(IR::Condition{true}, true_label, false_label); | ||||
|     true_label->AddImmediatePredecessor(block); | ||||
|     false_label->AddImmediatePredecessor(block); | ||||
|     Inst(Opcode::BranchConditional, condition, true_label, false_label); | ||||
|   | ||||
| @@ -143,19 +143,21 @@ Value Inst::Arg(size_t index) const { | ||||
| } | ||||
|  | ||||
| void Inst::SetArg(size_t index, Value value) { | ||||
|     if (op == Opcode::Phi) { | ||||
|         throw LogicError("Setting argument on a phi instruction"); | ||||
|     } | ||||
|     if (index >= NumArgsOf(op)) { | ||||
|     if (index >= NumArgs()) { | ||||
|         throw InvalidArgument("Out of bounds argument index {} in opcode {}", index, op); | ||||
|     } | ||||
|     if (!args[index].IsImmediate()) { | ||||
|         UndoUse(args[index]); | ||||
|     const IR::Value arg{Arg(index)}; | ||||
|     if (!arg.IsImmediate()) { | ||||
|         UndoUse(arg); | ||||
|     } | ||||
|     if (!value.IsImmediate()) { | ||||
|         Use(value); | ||||
|     } | ||||
|     args[index] = value; | ||||
|     if (op == Opcode::Phi) { | ||||
|         phi_args[index].second = value; | ||||
|     } else { | ||||
|         args[index] = value; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Block* Inst::PhiBlock(size_t index) const { | ||||
|   | ||||
| @@ -76,8 +76,8 @@ void IADD(TranslatorVisitor& v, u64 insn, IR::U32 op_b) { | ||||
| } | ||||
| } // Anonymous namespace | ||||
|  | ||||
| void TranslatorVisitor::IADD_reg(u64) { | ||||
|     throw NotImplementedException("IADD (reg)"); | ||||
| void TranslatorVisitor::IADD_reg(u64 insn) { | ||||
|     IADD(*this, insn, GetReg20(insn)); | ||||
| } | ||||
|  | ||||
| void TranslatorVisitor::IADD_cbuf(u64 insn) { | ||||
|   | ||||
| @@ -92,8 +92,8 @@ void TranslatorVisitor::ISETP_cbuf(u64 insn) { | ||||
|     ISETP(*this, insn, GetCbuf(insn)); | ||||
| } | ||||
|  | ||||
| void TranslatorVisitor::ISETP_imm(u64) { | ||||
|     throw NotImplementedException("ISETP_imm"); | ||||
| void TranslatorVisitor::ISETP_imm(u64 insn) { | ||||
|     ISETP(*this, insn, GetImm20(insn)); | ||||
| } | ||||
|  | ||||
| } // namespace Shader::Maxwell | ||||
|   | ||||
| @@ -32,6 +32,8 @@ template <typename T> | ||||
|         return value.U1(); | ||||
|     } else if constexpr (std::is_same_v<T, u32>) { | ||||
|         return value.U32(); | ||||
|     } else if constexpr (std::is_same_v<T, s32>) { | ||||
|         return static_cast<s32>(value.U32()); | ||||
|     } else if constexpr (std::is_same_v<T, f32>) { | ||||
|         return value.F32(); | ||||
|     } else if constexpr (std::is_same_v<T, u64>) { | ||||
| @@ -39,17 +41,8 @@ template <typename T> | ||||
|     } | ||||
| } | ||||
|  | ||||
| template <typename ImmFn> | ||||
| template <typename T, typename ImmFn> | ||||
| bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||
|     const auto arg = [](const IR::Value& value) { | ||||
|         if constexpr (std::is_invocable_r_v<bool, ImmFn, bool, bool>) { | ||||
|             return value.U1(); | ||||
|         } else if constexpr (std::is_invocable_r_v<u32, ImmFn, u32, u32>) { | ||||
|             return value.U32(); | ||||
|         } else if constexpr (std::is_invocable_r_v<u64, ImmFn, u64, u64>) { | ||||
|             return value.U64(); | ||||
|         } | ||||
|     }; | ||||
|     const IR::Value lhs{inst.Arg(0)}; | ||||
|     const IR::Value rhs{inst.Arg(1)}; | ||||
|  | ||||
| @@ -57,14 +50,14 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||
|     const bool is_rhs_immediate{rhs.IsImmediate()}; | ||||
|  | ||||
|     if (is_lhs_immediate && is_rhs_immediate) { | ||||
|         const auto result{imm_fn(arg(lhs), arg(rhs))}; | ||||
|         const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))}; | ||||
|         inst.ReplaceUsesWith(IR::Value{result}); | ||||
|         return false; | ||||
|     } | ||||
|     if (is_lhs_immediate && !is_rhs_immediate) { | ||||
|         IR::Inst* const rhs_inst{rhs.InstRecursive()}; | ||||
|         if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { | ||||
|             const auto combined{imm_fn(arg(lhs), arg(rhs_inst->Arg(1)))}; | ||||
|             const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; | ||||
|             inst.SetArg(0, rhs_inst->Arg(0)); | ||||
|             inst.SetArg(1, IR::Value{combined}); | ||||
|         } else { | ||||
| @@ -76,7 +69,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||
|     if (!is_lhs_immediate && is_rhs_immediate) { | ||||
|         const IR::Inst* const lhs_inst{lhs.InstRecursive()}; | ||||
|         if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { | ||||
|             const auto combined{imm_fn(arg(rhs), arg(lhs_inst->Arg(1)))}; | ||||
|             const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; | ||||
|             inst.SetArg(0, lhs_inst->Arg(0)); | ||||
|             inst.SetArg(1, IR::Value{combined}); | ||||
|         } | ||||
| @@ -101,7 +94,7 @@ void FoldAdd(IR::Inst& inst) { | ||||
|     if (inst.HasAssociatedPseudoOperation()) { | ||||
|         return; | ||||
|     } | ||||
|     if (!FoldCommutative(inst, [](T a, T b) { return a + b; })) { | ||||
|     if (!FoldCommutative<T>(inst, [](T a, T b) { return a + b; })) { | ||||
|         return; | ||||
|     } | ||||
|     const IR::Value rhs{inst.Arg(1)}; | ||||
| @@ -119,7 +112,7 @@ void FoldSelect(IR::Inst& inst) { | ||||
| } | ||||
|  | ||||
| void FoldLogicalAnd(IR::Inst& inst) { | ||||
|     if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) { | ||||
|     if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a && b; })) { | ||||
|         return; | ||||
|     } | ||||
|     const IR::Value rhs{inst.Arg(1)}; | ||||
| @@ -133,7 +126,7 @@ void FoldLogicalAnd(IR::Inst& inst) { | ||||
| } | ||||
|  | ||||
| void FoldLogicalOr(IR::Inst& inst) { | ||||
|     if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) { | ||||
|     if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a || b; })) { | ||||
|         return; | ||||
|     } | ||||
|     const IR::Value rhs{inst.Arg(1)}; | ||||
| @@ -226,6 +219,8 @@ void ConstantPropagation(IR::Inst& inst) { | ||||
|         return FoldLogicalOr(inst); | ||||
|     case IR::Opcode::LogicalNot: | ||||
|         return FoldLogicalNot(inst); | ||||
|     case IR::Opcode::SLessThan: | ||||
|         return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); | ||||
|     case IR::Opcode::ULessThan: | ||||
|         return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); | ||||
|     case IR::Opcode::BitFieldUExtract: | ||||
|   | ||||
| @@ -113,6 +113,7 @@ private: | ||||
|     IR::Value ReadVariableRecursive(auto variable, IR::Block* block) { | ||||
|         IR::Value val; | ||||
|         if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) { | ||||
|             // Optimize the common case of one predecessor: no phi needed | ||||
|             val = ReadVariable(variable, preds.front()); | ||||
|         } else { | ||||
|             // Break potential cycles with operandless phi | ||||
| @@ -160,66 +161,70 @@ private: | ||||
|  | ||||
|     DefTable current_def; | ||||
| }; | ||||
|  | ||||
| void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { | ||||
|     switch (inst.Opcode()) { | ||||
|     case IR::Opcode::SetRegister: | ||||
|         if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { | ||||
|             pass.WriteVariable(reg, block, inst.Arg(1)); | ||||
|         } | ||||
|         break; | ||||
|     case IR::Opcode::SetPred: | ||||
|         if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { | ||||
|             pass.WriteVariable(pred, block, inst.Arg(1)); | ||||
|         } | ||||
|         break; | ||||
|     case IR::Opcode::SetGotoVariable: | ||||
|         pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1)); | ||||
|         break; | ||||
|     case IR::Opcode::SetZFlag: | ||||
|         pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); | ||||
|         break; | ||||
|     case IR::Opcode::SetSFlag: | ||||
|         pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0)); | ||||
|         break; | ||||
|     case IR::Opcode::SetCFlag: | ||||
|         pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0)); | ||||
|         break; | ||||
|     case IR::Opcode::SetOFlag: | ||||
|         pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0)); | ||||
|         break; | ||||
|     case IR::Opcode::GetRegister: | ||||
|         if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { | ||||
|             inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); | ||||
|         } | ||||
|         break; | ||||
|     case IR::Opcode::GetPred: | ||||
|         if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { | ||||
|             inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); | ||||
|         } | ||||
|         break; | ||||
|     case IR::Opcode::GetGotoVariable: | ||||
|         inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block)); | ||||
|         break; | ||||
|     case IR::Opcode::GetZFlag: | ||||
|         inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); | ||||
|         break; | ||||
|     case IR::Opcode::GetSFlag: | ||||
|         inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block)); | ||||
|         break; | ||||
|     case IR::Opcode::GetCFlag: | ||||
|         inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block)); | ||||
|         break; | ||||
|     case IR::Opcode::GetOFlag: | ||||
|         inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block)); | ||||
|         break; | ||||
|     default: | ||||
|         break; | ||||
|     } | ||||
| } | ||||
| } // Anonymous namespace | ||||
|  | ||||
| void SsaRewritePass(IR::Function& function) { | ||||
|     Pass pass; | ||||
|     for (IR::Block* const block : function.blocks) { | ||||
|         for (IR::Inst& inst : block->Instructions()) { | ||||
|             switch (inst.Opcode()) { | ||||
|             case IR::Opcode::SetRegister: | ||||
|                 if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { | ||||
|                     pass.WriteVariable(reg, block, inst.Arg(1)); | ||||
|                 } | ||||
|                 break; | ||||
|             case IR::Opcode::SetPred: | ||||
|                 if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { | ||||
|                     pass.WriteVariable(pred, block, inst.Arg(1)); | ||||
|                 } | ||||
|                 break; | ||||
|             case IR::Opcode::SetGotoVariable: | ||||
|                 pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1)); | ||||
|                 break; | ||||
|             case IR::Opcode::SetZFlag: | ||||
|                 pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); | ||||
|                 break; | ||||
|             case IR::Opcode::SetSFlag: | ||||
|                 pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0)); | ||||
|                 break; | ||||
|             case IR::Opcode::SetCFlag: | ||||
|                 pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0)); | ||||
|                 break; | ||||
|             case IR::Opcode::SetOFlag: | ||||
|                 pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0)); | ||||
|                 break; | ||||
|             case IR::Opcode::GetRegister: | ||||
|                 if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { | ||||
|                     inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); | ||||
|                 } | ||||
|                 break; | ||||
|             case IR::Opcode::GetPred: | ||||
|                 if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { | ||||
|                     inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); | ||||
|                 } | ||||
|                 break; | ||||
|             case IR::Opcode::GetGotoVariable: | ||||
|                 inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block)); | ||||
|                 break; | ||||
|             case IR::Opcode::GetZFlag: | ||||
|                 inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); | ||||
|                 break; | ||||
|             case IR::Opcode::GetSFlag: | ||||
|                 inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block)); | ||||
|                 break; | ||||
|             case IR::Opcode::GetCFlag: | ||||
|                 inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block)); | ||||
|                 break; | ||||
|             case IR::Opcode::GetOFlag: | ||||
|                 inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block)); | ||||
|                 break; | ||||
|             default: | ||||
|                 break; | ||||
|             } | ||||
|             VisitInst(pass, block, inst); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -38,7 +38,8 @@ void RunDatabase() { | ||||
|         map.emplace_back(std::make_unique<FileEnvironment>(path.string().c_str())); | ||||
|     }); | ||||
|     auto block_pool{std::make_unique<ObjectPool<Flow::Block>>()}; | ||||
|     auto t0 = std::chrono::high_resolution_clock::now(); | ||||
|     using namespace std::chrono; | ||||
|     auto t0 = high_resolution_clock::now(); | ||||
|     int N = 1; | ||||
|     int n = 0; | ||||
|     for (int i = 0; i < N; ++i) { | ||||
| @@ -55,9 +56,8 @@ void RunDatabase() { | ||||
|             // const std::string code{EmitGLASM(program)}; | ||||
|         } | ||||
|     } | ||||
|     auto t = std::chrono::high_resolution_clock::now(); | ||||
|     fmt::print(stdout, "{} ms", | ||||
|                std::chrono::duration_cast<std::chrono::milliseconds>(t - t0).count() / double(N)); | ||||
|     auto t = high_resolution_clock::now(); | ||||
|     fmt::print(stdout, "{} ms", duration_cast<milliseconds>(t - t0).count() / double(N)); | ||||
| } | ||||
|  | ||||
| int main() { | ||||
| @@ -67,8 +67,8 @@ int main() { | ||||
|     auto inst_pool{std::make_unique<ObjectPool<IR::Inst>>()}; | ||||
|     auto block_pool{std::make_unique<ObjectPool<IR::Block>>()}; | ||||
|  | ||||
|     FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; | ||||
|     // FileEnvironment env{"D:\\Shaders\\shader.bin"}; | ||||
|     // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; | ||||
|     FileEnvironment env{"D:\\Shaders\\shader.bin"}; | ||||
|     for (int i = 0; i < 1; ++i) { | ||||
|         block_pool->ReleaseContents(); | ||||
|         inst_pool->ReleaseContents(); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user