shader: Add subgroup masks
This commit is contained in:
		| @@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) { | ||||
|     if (info.uses_local_invocation_id) { | ||||
|         local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); | ||||
|     } | ||||
|     if (info.uses_subgroup_mask) { | ||||
|         subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); | ||||
|         subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); | ||||
|         subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); | ||||
|         subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); | ||||
|         subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); | ||||
|     } | ||||
|     if (info.uses_subgroup_invocation_id || | ||||
|         (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) { | ||||
|         (profile.warp_size_potentially_larger_than_guest && | ||||
|          (info.uses_subgroup_vote || info.uses_subgroup_mask))) { | ||||
|         subgroup_local_invocation_id = | ||||
|             DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); | ||||
|     } | ||||
|   | ||||
| @@ -97,6 +97,11 @@ public: | ||||
|     Id workgroup_id{}; | ||||
|     Id local_invocation_id{}; | ||||
|     Id subgroup_local_invocation_id{}; | ||||
|     Id subgroup_mask_eq{}; | ||||
|     Id subgroup_mask_lt{}; | ||||
|     Id subgroup_mask_le{}; | ||||
|     Id subgroup_mask_gt{}; | ||||
|     Id subgroup_mask_ge{}; | ||||
|     Id instance_id{}; | ||||
|     Id instance_index{}; | ||||
|     Id base_instance{}; | ||||
|   | ||||
| @@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred); | ||||
| Id EmitVoteAny(EmitContext& ctx, Id pred); | ||||
| Id EmitVoteEqual(EmitContext& ctx, Id pred); | ||||
| Id EmitSubgroupBallot(EmitContext& ctx, Id pred); | ||||
| Id EmitSubgroupEqMask(EmitContext& ctx); | ||||
| Id EmitSubgroupLtMask(EmitContext& ctx); | ||||
| Id EmitSubgroupLeMask(EmitContext& ctx); | ||||
| Id EmitSubgroupGtMask(EmitContext& ctx); | ||||
| Id EmitSubgroupGeMask(EmitContext& ctx); | ||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                     Id segmentation_mask); | ||||
| Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|   | ||||
| @@ -6,10 +6,18 @@ | ||||
|  | ||||
| namespace Shader::Backend::SPIRV { | ||||
| namespace { | ||||
| Id LargeWarpBallot(EmitContext& ctx, Id ballot) { | ||||
| Id WarpExtract(EmitContext& ctx, Id value) { | ||||
|     const Id shift{ctx.Constant(ctx.U32[1], 5)}; | ||||
|     const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index); | ||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); | ||||
| } | ||||
|  | ||||
| Id LoadMask(EmitContext& ctx, Id mask) { | ||||
|     const Id value{ctx.OpLoad(ctx.U32[4], mask)}; | ||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||||
|         return ctx.OpCompositeExtract(ctx.U32[1], value, 0U); | ||||
|     } | ||||
|     return WarpExtract(ctx, value); | ||||
| } | ||||
|  | ||||
| void SetInBoundsFlag(IR::Inst* inst, Id result) { | ||||
| @@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAllKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpIEqual(ctx.U1, lhs, active_mask); | ||||
| } | ||||
| @@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAnyKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); | ||||
| } | ||||
| @@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) { | ||||
|         return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); | ||||
|     } | ||||
|     const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | ||||
|     const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | ||||
|     const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id active_mask{WarpExtract(ctx, mask_ballot)}; | ||||
|     const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | ||||
|     const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; | ||||
|     return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), | ||||
|                            ctx.OpIEqual(ctx.U1, lhs, active_mask)); | ||||
| @@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) { | ||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||||
|         return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); | ||||
|     } | ||||
|     return LargeWarpBallot(ctx, ballot); | ||||
|     return WarpExtract(ctx, ballot); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupEqMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_eq); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupLtMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_lt); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupLeMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_le); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupGtMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_gt); | ||||
| } | ||||
|  | ||||
| Id EmitSubgroupGeMask(EmitContext& ctx) { | ||||
|     return LoadMask(ctx, ctx.subgroup_mask_ge); | ||||
| } | ||||
|  | ||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user