shader: Add support for fp16 comparisons and misc fixes
This commit is contained in:
		| @@ -234,7 +234,9 @@ Id EmitFPOrdGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs); | |||||||
| Id EmitFPUnordGreaterThanEqual16(EmitContext& ctx, Id lhs, Id rhs); | Id EmitFPUnordGreaterThanEqual16(EmitContext& ctx, Id lhs, Id rhs); | ||||||
| Id EmitFPUnordGreaterThanEqual32(EmitContext& ctx, Id lhs, Id rhs); | Id EmitFPUnordGreaterThanEqual32(EmitContext& ctx, Id lhs, Id rhs); | ||||||
| Id EmitFPUnordGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs); | Id EmitFPUnordGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs); | ||||||
|  | Id EmitFPIsNan16(EmitContext& ctx, Id value); | ||||||
| Id EmitFPIsNan32(EmitContext& ctx, Id value); | Id EmitFPIsNan32(EmitContext& ctx, Id value); | ||||||
|  | Id EmitFPIsNan64(EmitContext& ctx, Id value); | ||||||
| Id EmitIAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); | Id EmitIAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); | ||||||
| void EmitIAdd64(EmitContext& ctx); | void EmitIAdd64(EmitContext& ctx); | ||||||
| Id EmitISub32(EmitContext& ctx, Id a, Id b); | Id EmitISub32(EmitContext& ctx, Id a, Id b); | ||||||
|   | |||||||
| @@ -346,8 +346,16 @@ Id EmitFPUnordGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs) { | |||||||
|     return ctx.OpFUnordGreaterThanEqual(ctx.U1, lhs, rhs); |     return ctx.OpFUnordGreaterThanEqual(ctx.U1, lhs, rhs); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | Id EmitFPIsNan16(EmitContext& ctx, Id value) { | ||||||
|  |     return ctx.OpIsNan(ctx.U1, value); | ||||||
|  | } | ||||||
|  |  | ||||||
| Id EmitFPIsNan32(EmitContext& ctx, Id value) { | Id EmitFPIsNan32(EmitContext& ctx, Id value) { | ||||||
|     return ctx.OpIsNan(ctx.U1, value); |     return ctx.OpIsNan(ctx.U1, value); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | Id EmitFPIsNan64(EmitContext& ctx, Id value) { | ||||||
|  |     return ctx.OpIsNan(ctx.U1, value); | ||||||
|  | } | ||||||
|  |  | ||||||
| } // namespace Shader::Backend::SPIRV | } // namespace Shader::Backend::SPIRV | ||||||
|   | |||||||
| @@ -895,15 +895,30 @@ U1 IREmitter::FPGreaterThanEqual(const F16F32F64& lhs, const F16F32F64& rhs, FpC | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| U1 IREmitter::FPIsNan(const F32& value) { | U1 IREmitter::FPIsNan(const F16F32F64& value) { | ||||||
|  |     switch (value.Type()) { | ||||||
|  |     case Type::F16: | ||||||
|  |         return Inst<U1>(Opcode::FPIsNan16, value); | ||||||
|  |     case Type::F32: | ||||||
|         return Inst<U1>(Opcode::FPIsNan32, value); |         return Inst<U1>(Opcode::FPIsNan32, value); | ||||||
|  |     case Type::F64: | ||||||
|  |         return Inst<U1>(Opcode::FPIsNan64, value); | ||||||
|  |     default: | ||||||
|  |         ThrowInvalidType(value.Type()); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| U1 IREmitter::FPOrdered(const F32& lhs, const F32& rhs) { | U1 IREmitter::FPOrdered(const F16F32F64& lhs, const F16F32F64& rhs) { | ||||||
|  |     if (lhs.Type() != rhs.Type()) { | ||||||
|  |         throw InvalidArgument("Mismatching types {} and {}", lhs.Type(), rhs.Type()); | ||||||
|  |     } | ||||||
|     return LogicalAnd(LogicalNot(FPIsNan(lhs)), LogicalNot(FPIsNan(rhs))); |     return LogicalAnd(LogicalNot(FPIsNan(lhs)), LogicalNot(FPIsNan(rhs))); | ||||||
| } | } | ||||||
|  |  | ||||||
| U1 IREmitter::FPUnordered(const F32& lhs, const F32& rhs) { | U1 IREmitter::FPUnordered(const F16F32F64& lhs, const F16F32F64& rhs) { | ||||||
|  |     if (lhs.Type() != rhs.Type()) { | ||||||
|  |         throw InvalidArgument("Mismatching types {} and {}", lhs.Type(), rhs.Type()); | ||||||
|  |     } | ||||||
|     return LogicalOr(FPIsNan(lhs), FPIsNan(rhs)); |     return LogicalOr(FPIsNan(lhs), FPIsNan(rhs)); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -161,9 +161,9 @@ public: | |||||||
|                                      FpControl control = {}, bool ordered = true); |                                      FpControl control = {}, bool ordered = true); | ||||||
|     [[nodiscard]] U1 FPGreaterThanEqual(const F16F32F64& lhs, const F16F32F64& rhs, |     [[nodiscard]] U1 FPGreaterThanEqual(const F16F32F64& lhs, const F16F32F64& rhs, | ||||||
|                                         FpControl control = {}, bool ordered = true); |                                         FpControl control = {}, bool ordered = true); | ||||||
|     [[nodiscard]] U1 FPIsNan(const F32& value); |     [[nodiscard]] U1 FPIsNan(const F16F32F64& value); | ||||||
|     [[nodiscard]] U1 FPOrdered(const F32& lhs, const F32& rhs); |     [[nodiscard]] U1 FPOrdered(const F16F32F64& lhs, const F16F32F64& rhs); | ||||||
|     [[nodiscard]] U1 FPUnordered(const F32& lhs, const F32& rhs); |     [[nodiscard]] U1 FPUnordered(const F16F32F64& lhs, const F16F32F64& rhs); | ||||||
|     [[nodiscard]] F32F64 FPMax(const F32F64& lhs, const F32F64& rhs, FpControl control = {}); |     [[nodiscard]] F32F64 FPMax(const F32F64& lhs, const F32F64& rhs, FpControl control = {}); | ||||||
|     [[nodiscard]] F32F64 FPMin(const F32F64& lhs, const F32F64& rhs, FpControl control = {}); |     [[nodiscard]] F32F64 FPMin(const F32F64& lhs, const F32F64& rhs, FpControl control = {}); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -236,7 +236,9 @@ OPCODE(FPOrdGreaterThanEqual64,                             U1,             F64, | |||||||
| OPCODE(FPUnordGreaterThanEqual16,                           U1,             F16,            F16,                                                            ) | OPCODE(FPUnordGreaterThanEqual16,                           U1,             F16,            F16,                                                            ) | ||||||
| OPCODE(FPUnordGreaterThanEqual32,                           U1,             F32,            F32,                                                            ) | OPCODE(FPUnordGreaterThanEqual32,                           U1,             F32,            F32,                                                            ) | ||||||
| OPCODE(FPUnordGreaterThanEqual64,                           U1,             F64,            F64,                                                            ) | OPCODE(FPUnordGreaterThanEqual64,                           U1,             F64,            F64,                                                            ) | ||||||
|  | OPCODE(FPIsNan16,                                           U1,             F16,                                                                            ) | ||||||
| OPCODE(FPIsNan32,                                           U1,             F32,                                                                            ) | OPCODE(FPIsNan32,                                           U1,             F32,                                                                            ) | ||||||
|  | OPCODE(FPIsNan64,                                           U1,             F64,                                                                            ) | ||||||
|  |  | ||||||
| // Integer operations | // Integer operations | ||||||
| OPCODE(IAdd32,                                              U32,            U32,            U32,                                                            ) | OPCODE(IAdd32,                                              U32,            U32,            U32,                                                            ) | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ | |||||||
|  |  | ||||||
| namespace Shader::Maxwell { | namespace Shader::Maxwell { | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| void HADD2(TranslatorVisitor& v, u64 insn, Merge merge, bool ftz, bool sat, bool abs_a, bool neg_a, | void HADD2(TranslatorVisitor& v, u64 insn, Merge merge, bool ftz, bool sat, bool abs_a, bool neg_a, | ||||||
|            Swizzle swizzle_a, bool abs_b, bool neg_b, Swizzle swizzle_b, const IR::U32& src_b) { |            Swizzle swizzle_a, bool abs_b, bool neg_b, Swizzle swizzle_b, const IR::U32& src_b) { | ||||||
|     union { |     union { | ||||||
| @@ -66,7 +65,7 @@ void HADD2(TranslatorVisitor& v, u64 insn, bool sat, bool abs_b, bool neg_b, Swi | |||||||
|     HADD2(v, insn, hadd2.merge, hadd2.ftz != 0, sat, hadd2.abs_a != 0, hadd2.neg_a != 0, |     HADD2(v, insn, hadd2.merge, hadd2.ftz != 0, sat, hadd2.abs_a != 0, hadd2.neg_a != 0, | ||||||
|           hadd2.swizzle_a, abs_b, neg_b, swizzle_b, src_b); |           hadd2.swizzle_a, abs_b, neg_b, swizzle_b, src_b); | ||||||
| } | } | ||||||
| } // namespace | } // Anonymous namespace | ||||||
|  |  | ||||||
| void TranslatorVisitor::HADD2_reg(u64 insn) { | void TranslatorVisitor::HADD2_reg(u64 insn) { | ||||||
|     union { |     union { | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ | |||||||
|  |  | ||||||
| namespace Shader::Maxwell { | namespace Shader::Maxwell { | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| void HFMA2(TranslatorVisitor& v, u64 insn, Merge merge, Swizzle swizzle_a, bool neg_b, bool neg_c, | void HFMA2(TranslatorVisitor& v, u64 insn, Merge merge, Swizzle swizzle_a, bool neg_b, bool neg_c, | ||||||
|            Swizzle swizzle_b, Swizzle swizzle_c, const IR::U32& src_b, const IR::U32& src_c, |            Swizzle swizzle_b, Swizzle swizzle_c, const IR::U32& src_b, const IR::U32& src_c, | ||||||
|            bool sat, HalfPrecision precision) { |            bool sat, HalfPrecision precision) { | ||||||
| @@ -85,8 +84,7 @@ void HFMA2(TranslatorVisitor& v, u64 insn, bool neg_b, bool neg_c, Swizzle swizz | |||||||
|     HFMA2(v, insn, hfma2.merge, hfma2.swizzle_a, neg_b, neg_c, swizzle_b, swizzle_c, src_b, src_c, |     HFMA2(v, insn, hfma2.merge, hfma2.swizzle_a, neg_b, neg_c, swizzle_b, swizzle_c, src_b, src_c, | ||||||
|           sat, precision); |           sat, precision); | ||||||
| } | } | ||||||
|  | } // Anonymous namespace | ||||||
| } // namespace |  | ||||||
|  |  | ||||||
| void TranslatorVisitor::HFMA2_reg(u64 insn) { | void TranslatorVisitor::HFMA2_reg(u64 insn) { | ||||||
|     union { |     union { | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ | |||||||
|  |  | ||||||
| namespace Shader::Maxwell { | namespace Shader::Maxwell { | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| void HMUL2(TranslatorVisitor& v, u64 insn, Merge merge, bool sat, bool abs_a, bool neg_a, | void HMUL2(TranslatorVisitor& v, u64 insn, Merge merge, bool sat, bool abs_a, bool neg_a, | ||||||
|            Swizzle swizzle_a, bool abs_b, bool neg_b, Swizzle swizzle_b, const IR::U32& src_b, |            Swizzle swizzle_a, bool abs_b, bool neg_b, Swizzle swizzle_b, const IR::U32& src_b, | ||||||
|            HalfPrecision precision) { |            HalfPrecision precision) { | ||||||
| @@ -79,7 +78,7 @@ void HMUL2(TranslatorVisitor& v, u64 insn, bool sat, bool abs_a, bool neg_a, boo | |||||||
|     HMUL2(v, insn, hmul2.merge, sat, abs_a, neg_a, hmul2.swizzle_a, abs_b, neg_b, swizzle_b, src_b, |     HMUL2(v, insn, hmul2.merge, sat, abs_a, neg_a, hmul2.swizzle_a, abs_b, neg_b, swizzle_b, src_b, | ||||||
|           hmul2.precision); |           hmul2.precision); | ||||||
| } | } | ||||||
| } // namespace | } // Anonymous namespace | ||||||
|  |  | ||||||
| void TranslatorVisitor::HMUL2_reg(u64 insn) { | void TranslatorVisitor::HMUL2_reg(u64 insn) { | ||||||
|     union { |     union { | ||||||
|   | |||||||
| @@ -76,6 +76,7 @@ void TranslatorVisitor::HSET2_reg(u64 insn) { | |||||||
|         BitField<35, 4, FPCompareOp> compare_op; |         BitField<35, 4, FPCompareOp> compare_op; | ||||||
|         BitField<28, 2, Swizzle> swizzle_b; |         BitField<28, 2, Swizzle> swizzle_b; | ||||||
|     } const hset2{insn}; |     } const hset2{insn}; | ||||||
|  |  | ||||||
|     HSET2(*this, insn, GetReg20(insn), hset2.bf != 0, hset2.ftz != 0, hset2.neg_b != 0, |     HSET2(*this, insn, GetReg20(insn), hset2.bf != 0, hset2.ftz != 0, hset2.neg_b != 0, | ||||||
|           hset2.abs_b != 0, hset2.compare_op, hset2.swizzle_b); |           hset2.abs_b != 0, hset2.compare_op, hset2.swizzle_b); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -74,6 +74,9 @@ void VisitUsages(Info& info, IR::Inst& inst) { | |||||||
|     case IR::Opcode::CompositeExtractF16x2: |     case IR::Opcode::CompositeExtractF16x2: | ||||||
|     case IR::Opcode::CompositeExtractF16x3: |     case IR::Opcode::CompositeExtractF16x3: | ||||||
|     case IR::Opcode::CompositeExtractF16x4: |     case IR::Opcode::CompositeExtractF16x4: | ||||||
|  |     case IR::Opcode::CompositeInsertF16x2: | ||||||
|  |     case IR::Opcode::CompositeInsertF16x3: | ||||||
|  |     case IR::Opcode::CompositeInsertF16x4: | ||||||
|     case IR::Opcode::SelectF16: |     case IR::Opcode::SelectF16: | ||||||
|     case IR::Opcode::BitCastU16F16: |     case IR::Opcode::BitCastU16F16: | ||||||
|     case IR::Opcode::BitCastF16U16: |     case IR::Opcode::BitCastF16U16: | ||||||
| @@ -103,6 +106,19 @@ void VisitUsages(Info& info, IR::Inst& inst) { | |||||||
|     case IR::Opcode::FPRoundEven16: |     case IR::Opcode::FPRoundEven16: | ||||||
|     case IR::Opcode::FPSaturate16: |     case IR::Opcode::FPSaturate16: | ||||||
|     case IR::Opcode::FPTrunc16: |     case IR::Opcode::FPTrunc16: | ||||||
|  |     case IR::Opcode::FPOrdEqual16: | ||||||
|  |     case IR::Opcode::FPUnordEqual16: | ||||||
|  |     case IR::Opcode::FPOrdNotEqual16: | ||||||
|  |     case IR::Opcode::FPUnordNotEqual16: | ||||||
|  |     case IR::Opcode::FPOrdLessThan16: | ||||||
|  |     case IR::Opcode::FPUnordLessThan16: | ||||||
|  |     case IR::Opcode::FPOrdGreaterThan16: | ||||||
|  |     case IR::Opcode::FPUnordGreaterThan16: | ||||||
|  |     case IR::Opcode::FPOrdLessThanEqual16: | ||||||
|  |     case IR::Opcode::FPUnordLessThanEqual16: | ||||||
|  |     case IR::Opcode::FPOrdGreaterThanEqual16: | ||||||
|  |     case IR::Opcode::FPUnordGreaterThanEqual16: | ||||||
|  |     case IR::Opcode::FPIsNan16: | ||||||
|         info.uses_fp16 = true; |         info.uses_fp16 = true; | ||||||
|         break; |         break; | ||||||
|     case IR::Opcode::FPAbs64: |     case IR::Opcode::FPAbs64: | ||||||
|   | |||||||
| @@ -74,6 +74,8 @@ IR::Opcode Replace(IR::Opcode op) { | |||||||
|         return IR::Opcode::FPOrdGreaterThanEqual32; |         return IR::Opcode::FPOrdGreaterThanEqual32; | ||||||
|     case IR::Opcode::FPUnordGreaterThanEqual16: |     case IR::Opcode::FPUnordGreaterThanEqual16: | ||||||
|         return IR::Opcode::FPUnordGreaterThanEqual32; |         return IR::Opcode::FPUnordGreaterThanEqual32; | ||||||
|  |     case IR::Opcode::FPIsNan16: | ||||||
|  |         return IR::Opcode::FPIsNan32; | ||||||
|     case IR::Opcode::ConvertS16F16: |     case IR::Opcode::ConvertS16F16: | ||||||
|         return IR::Opcode::ConvertS16F32; |         return IR::Opcode::ConvertS16F32; | ||||||
|     case IR::Opcode::ConvertS32F16: |     case IR::Opcode::ConvertS32F16: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user