spirv: Add lower fp16 to fp32 pass
This commit is contained in:
		@@ -4,8 +4,8 @@
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <compare>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -547,11 +547,11 @@ F32 IREmitter::FPSqrt(const F32& value) {
 | 
			
		||||
 | 
			
		||||
F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) {
 | 
			
		||||
    switch (value.Type()) {
 | 
			
		||||
    case Type::U16:
 | 
			
		||||
    case Type::F16:
 | 
			
		||||
        return Inst<F16>(Opcode::FPSaturate16, value);
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
    case Type::F32:
 | 
			
		||||
        return Inst<F32>(Opcode::FPSaturate32, value);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
    case Type::F64:
 | 
			
		||||
        return Inst<F64>(Opcode::FPSaturate64, value);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -560,11 +560,11 @@ F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) {
 | 
			
		||||
 | 
			
		||||
F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) {
 | 
			
		||||
    switch (value.Type()) {
 | 
			
		||||
    case Type::U16:
 | 
			
		||||
    case Type::F16:
 | 
			
		||||
        return Inst<F16>(Opcode::FPRoundEven16, value);
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
    case Type::F32:
 | 
			
		||||
        return Inst<F32>(Opcode::FPRoundEven32, value);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
    case Type::F64:
 | 
			
		||||
        return Inst<F64>(Opcode::FPRoundEven64, value);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -573,11 +573,11 @@ F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) {
 | 
			
		||||
 | 
			
		||||
F16F32F64 IREmitter::FPFloor(const F16F32F64& value) {
 | 
			
		||||
    switch (value.Type()) {
 | 
			
		||||
    case Type::U16:
 | 
			
		||||
    case Type::F16:
 | 
			
		||||
        return Inst<F16>(Opcode::FPFloor16, value);
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
    case Type::F32:
 | 
			
		||||
        return Inst<F32>(Opcode::FPFloor32, value);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
    case Type::F64:
 | 
			
		||||
        return Inst<F64>(Opcode::FPFloor64, value);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -586,11 +586,11 @@ F16F32F64 IREmitter::FPFloor(const F16F32F64& value) {
 | 
			
		||||
 | 
			
		||||
F16F32F64 IREmitter::FPCeil(const F16F32F64& value) {
 | 
			
		||||
    switch (value.Type()) {
 | 
			
		||||
    case Type::U16:
 | 
			
		||||
    case Type::F16:
 | 
			
		||||
        return Inst<F16>(Opcode::FPCeil16, value);
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
    case Type::F32:
 | 
			
		||||
        return Inst<F32>(Opcode::FPCeil32, value);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
    case Type::F64:
 | 
			
		||||
        return Inst<F64>(Opcode::FPCeil64, value);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -599,11 +599,11 @@ F16F32F64 IREmitter::FPCeil(const F16F32F64& value) {
 | 
			
		||||
 | 
			
		||||
F16F32F64 IREmitter::FPTrunc(const F16F32F64& value) {
 | 
			
		||||
    switch (value.Type()) {
 | 
			
		||||
    case Type::U16:
 | 
			
		||||
    case Type::F16:
 | 
			
		||||
        return Inst<F16>(Opcode::FPTrunc16, value);
 | 
			
		||||
    case Type::U32:
 | 
			
		||||
    case Type::F32:
 | 
			
		||||
        return Inst<F32>(Opcode::FPTrunc32, value);
 | 
			
		||||
    case Type::U64:
 | 
			
		||||
    case Type::F64:
 | 
			
		||||
        return Inst<F64>(Opcode::FPTrunc64, value);
 | 
			
		||||
    default:
 | 
			
		||||
        ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -729,33 +729,33 @@ U32U64 IREmitter::ConvertFToS(size_t bitsize, const F16F32F64& value) {
 | 
			
		||||
    switch (bitsize) {
 | 
			
		||||
    case 16:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS16F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS16F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS16F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
        }
 | 
			
		||||
    case 32:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS32F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS32F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertS32F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
        }
 | 
			
		||||
    case 64:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertS64F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertS64F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertS64F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -769,33 +769,33 @@ U32U64 IREmitter::ConvertFToU(size_t bitsize, const F16F32F64& value) {
 | 
			
		||||
    switch (bitsize) {
 | 
			
		||||
    case 16:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU16F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU16F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU16F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
        }
 | 
			
		||||
    case 32:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU32F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU32F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U32>(Opcode::ConvertU32F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
        }
 | 
			
		||||
    case 64:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U16:
 | 
			
		||||
        case Type::F16:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertU64F16, value);
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
        case Type::F32:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertU64F32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
        case Type::F64:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertU64F64, value);
 | 
			
		||||
        default:
 | 
			
		||||
            ThrowInvalidType(value.Type());
 | 
			
		||||
@@ -829,10 +829,10 @@ U32U64 IREmitter::ConvertU(size_t result_bitsize, const U32U64& value) {
 | 
			
		||||
    case 64:
 | 
			
		||||
        switch (value.Type()) {
 | 
			
		||||
        case Type::U32:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertU64U32, value);
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
            // Nothing to do
 | 
			
		||||
            return value;
 | 
			
		||||
        case Type::U64:
 | 
			
		||||
            return Inst<U64>(Opcode::ConvertU64U32, value);
 | 
			
		||||
        default:
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -216,6 +216,10 @@ void Inst::ReplaceUsesWith(Value replacement) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Inst::ReplaceOpcode(IR::Opcode opcode) {
 | 
			
		||||
    op = opcode;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Inst::Use(const Value& value) {
 | 
			
		||||
    Inst* const inst{value.Inst()};
 | 
			
		||||
    ++inst->use_count;
 | 
			
		||||
 
 | 
			
		||||
@@ -86,6 +86,8 @@ public:
 | 
			
		||||
 | 
			
		||||
    void ReplaceUsesWith(Value replacement);
 | 
			
		||||
 | 
			
		||||
    void ReplaceOpcode(IR::Opcode opcode);
 | 
			
		||||
 | 
			
		||||
    template <typename FlagsType>
 | 
			
		||||
    requires(sizeof(FlagsType) <= sizeof(u32) && std::is_trivially_copyable_v<FlagsType>)
 | 
			
		||||
        [[nodiscard]] FlagsType Flags() const noexcept {
 | 
			
		||||
 
 | 
			
		||||
@@ -119,8 +119,10 @@ OPCODE(PackUint2x32,                                        U64,            U32x
 | 
			
		||||
OPCODE(UnpackUint2x32,                                      U32x2,          U64,                                                            )
 | 
			
		||||
OPCODE(PackFloat2x16,                                       U32,            F16x2,                                                          )
 | 
			
		||||
OPCODE(UnpackFloat2x16,                                     F16x2,          U32,                                                            )
 | 
			
		||||
OPCODE(PackDouble2x32,                                      U64,            U32x2,                                                          )
 | 
			
		||||
OPCODE(UnpackDouble2x32,                                    U32x2,          U64,                                                            )
 | 
			
		||||
OPCODE(PackHalf2x16,                                        U32,            F32x2,                                                          )
 | 
			
		||||
OPCODE(UnpackHalf2x16,                                      F32x2,          U32,                                                            )
 | 
			
		||||
OPCODE(PackDouble2x32,                                      F64,            U32x2,                                                          )
 | 
			
		||||
OPCODE(UnpackDouble2x32,                                    U32x2,          F64,                                                            )
 | 
			
		||||
 | 
			
		||||
// Pseudo-operation, handled specially at final emit
 | 
			
		||||
OPCODE(GetZeroFromOp,                                       U1,             Opaque,                                                         )
 | 
			
		||||
 
 | 
			
		||||
@@ -35,4 +35,4 @@ std::string DumpProgram(const Program& program) {
 | 
			
		||||
    return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Shader::IR
 | 
			
		||||
} // namespace Shader::IR
 | 
			
		||||
 
 | 
			
		||||
@@ -56,6 +56,7 @@ IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Blo
 | 
			
		||||
            .post_order_blocks{},
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
    Optimization::LowerFp16ToFp32(program);
 | 
			
		||||
    for (IR::Function& function : functions) {
 | 
			
		||||
        function.post_order_blocks = PostOrder(function.blocks);
 | 
			
		||||
        Optimization::SsaRewritePass(function.post_order_blocks);
 | 
			
		||||
@@ -69,6 +70,7 @@ IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Blo
 | 
			
		||||
        Optimization::VerificationPass(function);
 | 
			
		||||
    }
 | 
			
		||||
    Optimization::CollectShaderInfoPass(program);
 | 
			
		||||
 | 
			
		||||
    fmt::print(stdout, "{}\n", IR::DumpProgram(program));
 | 
			
		||||
    return program;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -34,7 +34,7 @@ union F2I {
 | 
			
		||||
    BitField<8, 2, DestFormat> dest_format;
 | 
			
		||||
    BitField<10, 2, SrcFormat> src_format;
 | 
			
		||||
    BitField<12, 1, u64> is_signed;
 | 
			
		||||
    BitField<39, 1, Rounding> rounding;
 | 
			
		||||
    BitField<39, 2, Rounding> rounding;
 | 
			
		||||
    BitField<49, 1, u64> half;
 | 
			
		||||
    BitField<44, 1, u64> ftz;
 | 
			
		||||
    BitField<45, 1, u64> abs;
 | 
			
		||||
@@ -55,6 +55,28 @@ size_t BitSize(DestFormat dest_format) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
IR::F64 UnpackCbuf(TranslatorVisitor& v, u64 insn) {
 | 
			
		||||
    union {
 | 
			
		||||
        u64 raw;
 | 
			
		||||
        BitField<20, 14, s64> offset;
 | 
			
		||||
        BitField<34, 5, u64> binding;
 | 
			
		||||
    } const cbuf{insn};
 | 
			
		||||
    if (cbuf.binding >= 18) {
 | 
			
		||||
        throw NotImplementedException("Out of bounds constant buffer binding {}", cbuf.binding);
 | 
			
		||||
    }
 | 
			
		||||
    if (cbuf.offset >= 0x4'000 || cbuf.offset < 0) {
 | 
			
		||||
        throw NotImplementedException("Out of bounds constant buffer offset {}", cbuf.offset * 4);
 | 
			
		||||
    }
 | 
			
		||||
    if (cbuf.offset % 2 != 0) {
 | 
			
		||||
        throw NotImplementedException("Unaligned F64 constant buffer offset {}", cbuf.offset * 4);
 | 
			
		||||
    }
 | 
			
		||||
    const IR::U32 binding{v.ir.Imm32(static_cast<u32>(cbuf.binding))};
 | 
			
		||||
    const IR::U32 byte_offset{v.ir.Imm32(static_cast<u32>(cbuf.offset) * 4 + 4)};
 | 
			
		||||
    const IR::U32 cbuf_data{v.ir.GetCbuf(binding, byte_offset)};
 | 
			
		||||
    const IR::Value vector{v.ir.CompositeConstruct(v.ir.Imm32(0U), cbuf_data)};
 | 
			
		||||
    return v.ir.PackDouble2x32(vector);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::F16F32F64& src_a) {
 | 
			
		||||
    // F2I is used to convert from a floating point value to an integer
 | 
			
		||||
    const F2I f2i{insn};
 | 
			
		||||
@@ -82,19 +104,16 @@ void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::F16F32F64& src_a) {
 | 
			
		||||
    const size_t bitsize{BitSize(f2i.dest_format)};
 | 
			
		||||
    const IR::U16U32U64 result{v.ir.ConvertFToI(bitsize, is_signed, rounded_value)};
 | 
			
		||||
 | 
			
		||||
    v.X(f2i.dest_reg, result);
 | 
			
		||||
    if (bitsize == 64) {
 | 
			
		||||
        const IR::Value vector{v.ir.UnpackUint2x32(result)};
 | 
			
		||||
        v.X(f2i.dest_reg + 0, IR::U32{v.ir.CompositeExtract(vector, 0)});
 | 
			
		||||
        v.X(f2i.dest_reg + 1, IR::U32{v.ir.CompositeExtract(vector, 1)});
 | 
			
		||||
    } else {
 | 
			
		||||
        v.X(f2i.dest_reg, result);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (f2i.cc != 0) {
 | 
			
		||||
        v.SetZFlag(v.ir.GetZeroFromOp(result));
 | 
			
		||||
        if (is_signed) {
 | 
			
		||||
            v.SetSFlag(v.ir.GetSignFromOp(result));
 | 
			
		||||
        } else {
 | 
			
		||||
            v.ResetSFlag();
 | 
			
		||||
        }
 | 
			
		||||
        v.ResetCFlag();
 | 
			
		||||
 | 
			
		||||
        // TODO: Investigate if out of bound conversions sets the overflow flag
 | 
			
		||||
        v.ResetOFlag();
 | 
			
		||||
        throw NotImplementedException("F2I CC");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
} // Anonymous namespace
 | 
			
		||||
@@ -118,12 +137,25 @@ void TranslatorVisitor::F2I_reg(u64 insn) {
 | 
			
		||||
                                          f2i.base.src_format.Value());
 | 
			
		||||
        }
 | 
			
		||||
    }()};
 | 
			
		||||
 | 
			
		||||
    TranslateF2I(*this, insn, op_a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TranslatorVisitor::F2I_cbuf(u64) {
 | 
			
		||||
    throw NotImplementedException("{}", Opcode::F2I_cbuf);
 | 
			
		||||
void TranslatorVisitor::F2I_cbuf(u64 insn) {
 | 
			
		||||
    const F2I f2i{insn};
 | 
			
		||||
    const IR::F16F32F64 op_a{[&]() -> IR::F16F32F64 {
 | 
			
		||||
        switch (f2i.src_format) {
 | 
			
		||||
        case SrcFormat::F16:
 | 
			
		||||
            return IR::F16{ir.CompositeExtract(ir.UnpackFloat2x16(GetCbuf(insn)), f2i.half)};
 | 
			
		||||
        case SrcFormat::F32:
 | 
			
		||||
            return GetCbufF(insn);
 | 
			
		||||
        case SrcFormat::F64: {
 | 
			
		||||
            return UnpackCbuf(*this, insn);
 | 
			
		||||
        }
 | 
			
		||||
        default:
 | 
			
		||||
            throw NotImplementedException("Invalid F2I source format {}", f2i.src_format.Value());
 | 
			
		||||
        }
 | 
			
		||||
    }()};
 | 
			
		||||
    TranslateF2I(*this, insn, op_a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TranslatorVisitor::F2I_imm(u64) {
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ namespace Shader::Maxwell {
 | 
			
		||||
 | 
			
		||||
class TranslatorVisitor {
 | 
			
		||||
public:
 | 
			
		||||
    explicit TranslatorVisitor(Environment& env_, IR::Block& block) : env{env_} ,ir(block) {}
 | 
			
		||||
    explicit TranslatorVisitor(Environment& env_, IR::Block& block) : env{env_}, ir(block) {}
 | 
			
		||||
 | 
			
		||||
    Environment& env;
 | 
			
		||||
    IR::IREmitter ir;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user