shader: Fold integer FMA from Nvidia's pattern
Fold integer "a * b + c" operations matching the pattern emitted by Nvidia's GL compiler. On a somewhat complex compute shader, two matches of this pattern reduce the generated code size by 16 instructions on Turing GPUs.

On Intel, as extracted with VK_KHR_pipeline_executable_properties:

Before the optimization:

```
Instruction Count: 2057
Basic Block Count: 45
Scratch Memory Size: 14752
Spill Count: 232
Fill Count: 261
SEND Count: 610
Cycle Count: 11325
```

After the optimization:

```
Instruction Count: 2046
Basic Block Count: 44
Scratch Memory Size: 13728
Spill Count: 219
Fill Count: 268
SEND Count: 604
Cycle Count: 11367
```
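For context on why the fold below is value-preserving: splitting each 32-bit factor into 16-bit halves decomposes the multiply-add into partial products, and the `a_hi * b_hi` term is shifted by 32 bits, so it vanishes under the modulo-2^32 semantics of `IMul32`/`IAdd32`. A minimal standalone C++ sketch of that identity (illustration only, not part of the commit; the function name is invented here), with comments mapping each term to the numbered values in the matched pattern:

```cpp
#include <cassert>
#include <cstdint>

// Computes a * b + c through the same 16-bit partial products the matched
// XMAD pattern uses; all arithmetic wraps modulo 2^32 like IMul32/IAdd32.
std::uint32_t XmadMultiplyAdd(std::uint32_t a, std::uint32_t b, std::uint32_t c) {
    const std::uint32_t a_lo = a & 0xffff;
    const std::uint32_t a_hi = a >> 16;
    const std::uint32_t b_lo = b & 0xffff;
    const std::uint32_t b_hi = b >> 16;
    // a_hi * b_hi is never computed: shifted by 32 bits, it cannot affect
    // the low 32 bits of the result.
    // %26 shifts %11 left by 16, which keeps only %8's low half; that equals
    // (b_lo * a_hi) << 16 because the shift discards the upper bits anyway.
    const std::uint32_t mid_a = (b_lo * a_hi) << 16;
    const std::uint32_t mid_b = (b_hi * a_lo) << 16; // %25 = %24 << 16
    const std::uint32_t lo = b_lo * a_lo + c;        // %18 = %17 + op_c
    return mid_b + (mid_a + lo);                     // %result = %25 + %27
}

int main() {
    for (std::uint32_t a : {0u, 1u, 0x1234u, 0xdeadbeefu, 0xffffffffu}) {
        for (std::uint32_t b : {0u, 3u, 0x8000u, 0xabcdu, 0xffffffffu}) {
            assert(XmadMultiplyAdd(a, b, 7u) == a * b + 7u);
        }
    }
    return 0;
}
```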
```diff
@@ -3,6 +3,7 @@
 // Refer to the license.txt file included.
 
 #include <algorithm>
+#include <functional>
 #include <tuple>
 #include <type_traits>
 
```
```diff
@@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
     return true;
 }
 
+/// Return true when all values in a range are equal
+template <typename Range>
+bool AreEqual(const Range& range) {
+    auto resolver{[](const auto& value) { return value.Resolve(); }};
+    auto equal{[](const IR::Value& lhs, const IR::Value& rhs) {
+        if (lhs == rhs) {
+            return true;
+        }
+        // Not equal, but try to match if they read the same constant buffer
+        if (!lhs.IsImmediate() && !rhs.IsImmediate() &&
+            lhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            rhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            lhs.Inst()->Arg(0) == rhs.Inst()->Arg(0) && lhs.Inst()->Arg(1) == rhs.Inst()->Arg(1)) {
+            return true;
+        }
+        return false;
+    }};
+    return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
+}
+
 void FoldGetRegister(IR::Inst& inst) {
     if (inst.Arg(0).Reg() == IR::Reg::RZ) {
         inst.ReplaceUsesWith(IR::Value{u32{0}});
```
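A note on the `AreEqual` helper just added: it uses the idiom that a range is all-equal exactly when `std::ranges::adjacent_find` with a negated equality predicate finds no mismatching neighbor pair, with the projection resolving each value before comparison. A stripped-down sketch of the same idiom (hypothetical toy `Value` type, not yuzu's `IR::Value`):

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>

struct Value {
    int id;
    int Resolve() const { return id; } // stand-in for IR::Value::Resolve()
};

// All elements are equal iff adjacent_find with a negated equality predicate
// (applied after the projection) finds no mismatching neighbor pair.
template <typename Range>
bool AllEqual(const Range& range) {
    auto resolver = [](const Value& value) { return value.Resolve(); };
    auto equal = [](int lhs, int rhs) { return lhs == rhs; };
    return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
}

int main() {
    assert(AllEqual(std::array{Value{4}, Value{4}, Value{4}}));
    assert(!AllEqual(std::array{Value{4}, Value{5}, Value{4}}));
}
```

The constant-buffer special case in the real helper exists because two separate `GetCbufU32` instructions with identical arguments are distinct IR values even though they always load the same constant.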
```diff
@@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the XMAD pattern generated by an integer FMA
+bool FoldXmadMultiplyAdd(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this specific pattern:
+     *   %6      = BitFieldUExtract %op_b, #0, #16
+     *   %7      = BitFieldUExtract %op_a', #16, #16
+     *   %8      = IMul32 %6, %7
+     *   %10     = BitFieldUExtract %op_a', #0, #16
+     *   %11     = BitFieldInsert %8, %10, #16, #16
+     *   %15     = BitFieldUExtract %op_b, #0, #16
+     *   %16     = BitFieldUExtract %op_a, #0, #16
+     *   %17     = IMul32 %15, %16
+     *   %18     = IAdd32 %17, %op_c
+     *   %22     = BitFieldUExtract %op_b, #16, #16
+     *   %23     = BitFieldUExtract %11, #16, #16
+     *   %24     = IMul32 %22, %23
+     *   %25     = ShiftLeftLogical32 %24, #16
+     *   %26     = ShiftLeftLogical32 %11, #16
+     *   %27     = IAdd32 %26, %18
+     *   %result = IAdd32 %25, %27
+     *
+     * And replace it with:
+     *   %temp   = IMul32 %op_a, %op_b
+     *   %result = IAdd32 %temp, %op_c
+     *
+     * This fold is known to be safe because it reverses Nvidia's own compiler
+     * logic: if Nvidia generates this code from an integer 'fma(a, b, c)', we
+     * can apply the same transformation in the opposite direction.
+     */
+    const IR::Value zero{0u};
+    const IR::Value sixteen{16u};
+    IR::Inst* const _25{inst.Arg(0).TryInstRecursive()};
+    IR::Inst* const _27{inst.Arg(1).TryInstRecursive()};
+    if (!_25 || !_27) {
+        return false;
+    }
+    if (_27->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    if (_25->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _25->Arg(1) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _24{_25->Arg(0).TryInstRecursive()};
+    if (!_24 || _24->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _22{_24->Arg(0).TryInstRecursive()};
+    IR::Inst* const _23{_24->Arg(1).TryInstRecursive()};
+    if (!_22 || !_23) {
+        return false;
+    }
+    if (_22->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_23->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_22->Arg(1) != sixteen || _22->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_23->Arg(1) != sixteen || _23->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _11{_23->Arg(0).TryInstRecursive()};
+    if (!_11 || _11->GetOpcode() != IR::Opcode::BitFieldInsert) {
+        return false;
+    }
+    if (_11->Arg(2) != sixteen || _11->Arg(3) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _8{_11->Arg(0).TryInstRecursive()};
+    IR::Inst* const _10{_11->Arg(1).TryInstRecursive()};
+    if (!_8 || !_10) {
+        return false;
+    }
+    if (_8->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (_10->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    IR::Inst* const _6{_8->Arg(0).TryInstRecursive()};
+    IR::Inst* const _7{_8->Arg(1).TryInstRecursive()};
+    if (!_6 || !_7) {
+        return false;
+    }
+    if (_6->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_7->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_6->Arg(1) != zero || _6->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_7->Arg(1) != sixteen || _7->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _26{_27->Arg(0).TryInstRecursive()};
+    IR::Inst* const _18{_27->Arg(1).TryInstRecursive()};
+    if (!_26 || !_18) {
+        return false;
+    }
+    if (_26->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _26->Arg(1) != sixteen) {
+        return false;
+    }
+    if (_26->Arg(0).InstRecursive() != _11) {
+        return false;
+    }
+    if (_18->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    IR::Inst* const _17{_18->Arg(0).TryInstRecursive()};
+    if (!_17 || _17->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _15{_17->Arg(0).TryInstRecursive()};
+    IR::Inst* const _16{_17->Arg(1).TryInstRecursive()};
+    if (!_15 || !_16) {
+        return false;
+    }
+    if (_15->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_16->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_15->Arg(1) != zero || _16->Arg(1) != zero || _10->Arg(1) != zero) {
+        return false;
+    }
+    if (_15->Arg(2) != sixteen || _16->Arg(2) != sixteen || _10->Arg(2) != sixteen) {
+        return false;
+    }
+    const std::array<IR::Value, 3> op_as{
+        _7->Arg(0).Resolve(),
+        _16->Arg(0).Resolve(),
+        _10->Arg(0).Resolve(),
+    };
+    const std::array<IR::Value, 3> op_bs{
+        _22->Arg(0).Resolve(),
+        _6->Arg(0).Resolve(),
+        _15->Arg(0).Resolve(),
+    };
+    const IR::U32 op_c{_18->Arg(1)};
+    if (!AreEqual(op_as) || !AreEqual(op_bs)) {
+        return false;
+    }
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IAdd(ir.IMul(IR::U32{op_as[0]}, IR::U32{op_bs[1]}), op_c));
+    return true;
+}
+
 /// Replaces the pattern generated by two XMAD multiplications
 bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
     /*
```
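`FoldXmadMultiplyAdd` follows a common def-use matching style: each operand is resolved to its defining instruction with `TryInstRecursive`, and the function bails out with `false` at the first opcode or immediate that deviates from the expected shape. A minimal sketch of that style over a toy IR (all types here are hypothetical, for illustration only):

```cpp
#include <cstdint>

// Toy stand-ins for an SSA IR; these are hypothetical, not yuzu's types.
enum class Opcode { Immediate, ShiftLeftLogical32, IAdd32 };

struct Inst {
    Opcode opcode{};
    std::uint32_t imm{};   // Payload when opcode == Opcode::Immediate
    const Inst* args[2]{}; // Defining instructions of the operands
};

// Matches 'inst = (a << 16) + b', mirroring the bail-out style above:
// chase each operand to its definition and return false on any mismatch.
bool MatchShiftAdd(const Inst& inst, const Inst*& a, const Inst*& b) {
    if (inst.opcode != Opcode::IAdd32) {
        return false;
    }
    const Inst* const shift{inst.args[0]};
    if (!shift || shift->opcode != Opcode::ShiftLeftLogical32) {
        return false;
    }
    const Inst* const amount{shift->args[1]};
    if (!amount || amount->opcode != Opcode::Immediate || amount->imm != 16) {
        return false;
    }
    a = shift->args[0];
    b = inst.args[1];
    return a != nullptr && b != nullptr;
}

int main() {
    const Inst a{Opcode::Immediate, 3};
    const Inst sixteen{Opcode::Immediate, 16};
    const Inst shift{Opcode::ShiftLeftLogical32, 0, {&a, &sixteen}};
    const Inst b{Opcode::Immediate, 5};
    const Inst add{Opcode::IAdd32, 0, {&shift, &b}};
    const Inst *matched_a{}, *matched_b{};
    return MatchShiftAdd(add, matched_a, matched_b) ? 0 : 1;
}
```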
```diff
@@ -179,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
         if (FoldXmadMultiply(block, inst)) {
             return;
         }
+        if (FoldXmadMultiplyAdd(block, inst)) {
+            return;
+        }
     }
 }
 
```