shader: Add subgroup masks
This commit is contained in:
		| @@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) { | ||||
|     if (info.uses_local_invocation_id) { | ||||
|         local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); | ||||
|     } | ||||
|     if (info.uses_subgroup_mask) { | ||||
|         subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); | ||||
|         subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); | ||||
|         subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); | ||||
|         subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); | ||||
|         subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); | ||||
|     } | ||||
|     if (info.uses_subgroup_invocation_id || | ||||
|         (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) { | ||||
|         (profile.warp_size_potentially_larger_than_guest && | ||||
|          (info.uses_subgroup_vote || info.uses_subgroup_mask))) { | ||||
|         subgroup_local_invocation_id = | ||||
|             DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); | ||||
|     } | ||||
|   | ||||
| @@ -97,6 +97,11 @@ public: | ||||
|     Id workgroup_id{}; | ||||
|     Id local_invocation_id{}; | ||||
|     Id subgroup_local_invocation_id{}; | ||||
|     Id subgroup_mask_eq{}; | ||||
|     Id subgroup_mask_lt{}; | ||||
|     Id subgroup_mask_le{}; | ||||
|     Id subgroup_mask_gt{}; | ||||
|     Id subgroup_mask_ge{}; | ||||
|     Id instance_id{}; | ||||
|     Id instance_index{}; | ||||
|     Id base_instance{}; | ||||
|   | ||||
| @@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred); | ||||
| Id EmitVoteAny(EmitContext& ctx, Id pred); | ||||
| Id EmitVoteEqual(EmitContext& ctx, Id pred); | ||||
| Id EmitSubgroupBallot(EmitContext& ctx, Id pred); | ||||
| Id EmitSubgroupEqMask(EmitContext& ctx); | ||||
| Id EmitSubgroupLtMask(EmitContext& ctx); | ||||
| Id EmitSubgroupLeMask(EmitContext& ctx); | ||||
| Id EmitSubgroupGtMask(EmitContext& ctx); | ||||
| Id EmitSubgroupGeMask(EmitContext& ctx); | ||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                     Id segmentation_mask); | ||||
| Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|   | ||||
| @@ -6,10 +6,18 @@ | ||||
|  | ||||
| namespace Shader::Backend::SPIRV { | ||||
| namespace { | ||||
| Id LargeWarpBallot(EmitContext& ctx, Id ballot) { | ||||
| Id WarpExtract(EmitContext& ctx, Id value) { | ||||
|     const Id shift{ctx.Constant(ctx.U32[1], 5)}; | ||||
|     const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index); | ||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); | ||||
| } | ||||
|  | ||||
| Id LoadMask(EmitContext& ctx, Id mask) { | ||||
|     const Id value{ctx.OpLoad(ctx.U32[4], mask)}; | ||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||||
|         return ctx.OpCompositeExtract(ctx.U32[1], value, 0U); | ||||
|     } | ||||
|     return WarpExtract(ctx, value); | ||||
| } | ||||
|  | ||||
| void SetInBoundsFlag(IR::Inst* inst, Id result) { | ||||
| @@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAllKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpIEqual(ctx.U1, lhs, active_mask); | ||||
| } | ||||
| @@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAnyKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); | ||||
| } | ||||
| @@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), | ||||
|                            ctx.OpIEqual(ctx.U1, lhs, active_mask)); | ||||
| @@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) { | ||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||||
|         return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); | ||||
|     } | ||||
|     return LargeWarpBallot(ctx, ballot); | ||||
|     return WarpExtract(ctx, ballot); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupEqMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_eq); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupLtMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_lt); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupLeMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_le); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupGtMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_gt); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupGeMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_ge); | ||||
| } | ||||
|  | ||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|   | ||||
| @@ -1628,6 +1628,26 @@ U32 IREmitter::SubgroupBallot(const U1& value) { | ||||
|     return Inst<U32>(Opcode::SubgroupBallot, value); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::SubgroupEqMask() { | ||||
|     return Inst<U32>(Opcode::SubgroupEqMask); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::SubgroupLtMask() { | ||||
|     return Inst<U32>(Opcode::SubgroupLtMask); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::SubgroupLeMask() { | ||||
|     return Inst<U32>(Opcode::SubgroupLeMask); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::SubgroupGtMask() { | ||||
|     return Inst<U32>(Opcode::SubgroupGtMask); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::SubgroupGeMask() { | ||||
|     return Inst<U32>(Opcode::SubgroupGeMask); | ||||
| } | ||||
|  | ||||
| U32 IREmitter::ShuffleIndex(const IR::U32& value, const IR::U32& index, const IR::U32& clamp, | ||||
|                             const IR::U32& seg_mask) { | ||||
|     return Inst<U32>(Opcode::ShuffleIndex, value, index, clamp, seg_mask); | ||||
|   | ||||
| @@ -281,6 +281,11 @@ public: | ||||
|     [[nodiscard]] U1 VoteAny(const U1& value); | ||||
|     [[nodiscard]] U1 VoteEqual(const U1& value); | ||||
|     [[nodiscard]] U32 SubgroupBallot(const U1& value); | ||||
|     [[nodiscard]] U32 SubgroupEqMask(); | ||||
|     [[nodiscard]] U32 SubgroupLtMask(); | ||||
|     [[nodiscard]] U32 SubgroupLeMask(); | ||||
|     [[nodiscard]] U32 SubgroupGtMask(); | ||||
|     [[nodiscard]] U32 SubgroupGeMask(); | ||||
|     [[nodiscard]] U32 ShuffleIndex(const IR::U32& value, const IR::U32& index, const IR::U32& clamp, | ||||
|                                    const IR::U32& seg_mask); | ||||
|     [[nodiscard]] U32 ShuffleUp(const IR::U32& value, const IR::U32& index, const IR::U32& clamp, | ||||
|   | ||||
| @@ -417,6 +417,11 @@ OPCODE(VoteAll,                                             U1,             U1, | ||||
| OPCODE(VoteAny,                                             U1,             U1,                                                                             ) | ||||
| OPCODE(VoteEqual,                                           U1,             U1,                                                                             ) | ||||
| OPCODE(SubgroupBallot,                                      U32,            U1,                                                                             ) | ||||
| OPCODE(SubgroupEqMask,                                      U32,                                                                                            ) | ||||
| OPCODE(SubgroupLtMask,                                      U32,                                                                                            ) | ||||
| OPCODE(SubgroupLeMask,                                      U32,                                                                                            ) | ||||
| OPCODE(SubgroupGtMask,                                      U32,                                                                                            ) | ||||
| OPCODE(SubgroupGeMask,                                      U32,                                                                                            ) | ||||
| OPCODE(ShuffleIndex,                                        U32,            U32,            U32,            U32,            U32,                            ) | ||||
| OPCODE(ShuffleUp,                                           U32,            U32,            U32,            U32,            U32,                            ) | ||||
| OPCODE(ShuffleDown,                                         U32,            U32,            U32,            U32,            U32,                            ) | ||||
|   | ||||
| @@ -10,6 +10,7 @@ namespace Shader::Maxwell { | ||||
| namespace { | ||||
| enum class SpecialRegister : u64 { | ||||
|     SR_LANEID = 0, | ||||
|     SR_CLOCK = 1, | ||||
|     SR_VIRTCFG = 2, | ||||
|     SR_VIRTID = 3, | ||||
|     SR_PM0 = 4, | ||||
| @@ -20,6 +21,9 @@ enum class SpecialRegister : u64 { | ||||
|     SR_PM5 = 9, | ||||
|     SR_PM6 = 10, | ||||
|     SR_PM7 = 11, | ||||
|     SR12 = 12, | ||||
|     SR13 = 13, | ||||
|     SR14 = 14, | ||||
|     SR_ORDERING_TICKET = 15, | ||||
|     SR_PRIM_TYPE = 16, | ||||
|     SR_INVOCATION_ID = 17, | ||||
| @@ -41,44 +45,70 @@ enum class SpecialRegister : u64 { | ||||
|     SR_TID_X = 33, | ||||
|     SR_TID_Y = 34, | ||||
|     SR_TID_Z = 35, | ||||
|     SR_CTA_PARAM = 36, | ||||
|     SR_CTAID_X = 37, | ||||
|     SR_CTAID_Y = 38, | ||||
|     SR_CTAID_Z = 39, | ||||
|     SR_NTID = 49, | ||||
|     SR_CirQueueIncrMinusOne = 50, | ||||
|     SR_NLATC = 51, | ||||
|     SR_SWINLO = 57, | ||||
|     SR_SWINSZ = 58, | ||||
|     SR_SMEMSZ = 59, | ||||
|     SR_SMEMBANKS = 60, | ||||
|     SR_LWINLO = 61, | ||||
|     SR_LWINSZ = 62, | ||||
|     SR_LMEMLOSZ = 63, | ||||
|     SR_LMEMHIOFF = 64, | ||||
|     SR_EQMASK = 65, | ||||
|     SR_LTMASK = 66, | ||||
|     SR_LEMASK = 67, | ||||
|     SR_GTMASK = 68, | ||||
|     SR_GEMASK = 69, | ||||
|     SR_REGALLOC = 70, | ||||
|     SR_GLOBALERRORSTATUS = 73, | ||||
|     SR_WARPERRORSTATUS = 75, | ||||
|     SR_PM_HI0 = 81, | ||||
|     SR_PM_HI1 = 82, | ||||
|     SR_PM_HI2 = 83, | ||||
|     SR_PM_HI3 = 84, | ||||
|     SR_PM_HI4 = 85, | ||||
|     SR_PM_HI5 = 86, | ||||
|     SR_PM_HI6 = 87, | ||||
|     SR_PM_HI7 = 88, | ||||
|     SR_CLOCKLO = 89, | ||||
|     SR_CLOCKHI = 90, | ||||
|     SR_GLOBALTIMERLO = 91, | ||||
|     SR_GLOBALTIMERHI = 92, | ||||
|     SR_HWTASKID = 105, | ||||
|     SR_CIRCULARQUEUEENTRYINDEX = 106, | ||||
|     SR_CIRCULARQUEUEENTRYADDRESSLOW = 107, | ||||
|     SR_CIRCULARQUEUEENTRYADDRESSHIGH = 108, | ||||
|     SR_NTID = 40, | ||||
|     SR_CirQueueIncrMinusOne = 41, | ||||
|     SR_NLATC = 42, | ||||
|     SR43 = 43, | ||||
|     SR_SM_SPA_VERSION = 44, | ||||
|     SR_MULTIPASSSHADERINFO = 45, | ||||
|     SR_LWINHI = 46, | ||||
|     SR_SWINHI = 47, | ||||
|     SR_SWINLO = 48, | ||||
|     SR_SWINSZ = 49, | ||||
|     SR_SMEMSZ = 50, | ||||
|     SR_SMEMBANKS = 51, | ||||
|     SR_LWINLO = 52, | ||||
|     SR_LWINSZ = 53, | ||||
|     SR_LMEMLOSZ = 54, | ||||
|     SR_LMEMHIOFF = 55, | ||||
|     SR_EQMASK = 56, | ||||
|     SR_LTMASK = 57, | ||||
|     SR_LEMASK = 58, | ||||
|     SR_GTMASK = 59, | ||||
|     SR_GEMASK = 60, | ||||
|     SR_REGALLOC = 61, | ||||
|     SR_BARRIERALLOC = 62, | ||||
|     SR63 = 63, | ||||
|     SR_GLOBALERRORSTATUS = 64, | ||||
|     SR65 = 65, | ||||
|     SR_WARPERRORSTATUS = 66, | ||||
|     SR_WARPERRORSTATUSCLEAR = 67, | ||||
|     SR68 = 68, | ||||
|     SR69 = 69, | ||||
|     SR70 = 70, | ||||
|     SR71 = 71, | ||||
|     SR_PM_HI0 = 72, | ||||
|     SR_PM_HI1 = 73, | ||||
|     SR_PM_HI2 = 74, | ||||
|     SR_PM_HI3 = 75, | ||||
|     SR_PM_HI4 = 76, | ||||
|     SR_PM_HI5 = 77, | ||||
|     SR_PM_HI6 = 78, | ||||
|     SR_PM_HI7 = 79, | ||||
|     SR_CLOCKLO = 80, | ||||
|     SR_CLOCKHI = 81, | ||||
|     SR_GLOBALTIMERLO = 82, | ||||
|     SR_GLOBALTIMERHI = 83, | ||||
|     SR84 = 84, | ||||
|     SR85 = 85, | ||||
|     SR86 = 86, | ||||
|     SR87 = 87, | ||||
|     SR88 = 88, | ||||
|     SR89 = 89, | ||||
|     SR90 = 90, | ||||
|     SR91 = 91, | ||||
|     SR92 = 92, | ||||
|     SR93 = 93, | ||||
|     SR94 = 94, | ||||
|     SR95 = 95, | ||||
|     SR_HWTASKID = 96, | ||||
|     SR_CIRCULARQUEUEENTRYINDEX = 97, | ||||
|     SR_CIRCULARQUEUEENTRYADDRESSLOW = 98, | ||||
|     SR_CIRCULARQUEUEENTRYADDRESSHIGH = 99, | ||||
| }; | ||||
|  | ||||
| [[nodiscard]] IR::U32 Read(IR::IREmitter& ir, SpecialRegister special_register) { | ||||
| @@ -103,6 +133,16 @@ enum class SpecialRegister : u64 { | ||||
|         return ir.Imm32(Common::BitCast<u32>(1.0f)); | ||||
|     case SpecialRegister::SR_LANEID: | ||||
|         return ir.LaneId(); | ||||
|     case SpecialRegister::SR_EQMASK: | ||||
|         return ir.SubgroupEqMask(); | ||||
|     case SpecialRegister::SR_LTMASK: | ||||
|         return ir.SubgroupLtMask(); | ||||
|     case SpecialRegister::SR_LEMASK: | ||||
|         return ir.SubgroupLeMask(); | ||||
|     case SpecialRegister::SR_GTMASK: | ||||
|         return ir.SubgroupGtMask(); | ||||
|     case SpecialRegister::SR_GEMASK: | ||||
|         return ir.SubgroupGeMask(); | ||||
|     default: | ||||
|         throw NotImplementedException("S2R special register {}", special_register); | ||||
|     } | ||||
|   | ||||
| @@ -414,6 +414,13 @@ void VisitUsages(Info& info, IR::Inst& inst) { | ||||
|             inst.GetAssociatedPseudoOperation(IR::Opcode::GetSparseFromOp) != nullptr; | ||||
|         break; | ||||
|     } | ||||
|     case IR::Opcode::SubgroupEqMask: | ||||
|     case IR::Opcode::SubgroupLtMask: | ||||
|     case IR::Opcode::SubgroupLeMask: | ||||
|     case IR::Opcode::SubgroupGtMask: | ||||
|     case IR::Opcode::SubgroupGeMask: | ||||
|         info.uses_subgroup_mask = true; | ||||
|         break; | ||||
|     case IR::Opcode::VoteAll: | ||||
|     case IR::Opcode::VoteAny: | ||||
|     case IR::Opcode::VoteEqual: | ||||
|   | ||||
| @@ -99,6 +99,7 @@ struct Info { | ||||
|     bool uses_sparse_residency{}; | ||||
|     bool uses_demote_to_helper_invocation{}; | ||||
|     bool uses_subgroup_vote{}; | ||||
|     bool uses_subgroup_mask{}; | ||||
|     bool uses_fswzadd{}; | ||||
|  | ||||
|     IR::Type used_constant_buffer_types{}; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user