emit_spirv_warp: Fix ballot related ops for 64-thread warp sizes
This commit is contained in:
		| @@ -7,8 +7,13 @@ | ||||
|  | ||||
| namespace Shader::Backend::SPIRV { | ||||
| namespace { | ||||
| Id GetThreadId(EmitContext& ctx) { | ||||
|     return ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id); | ||||
| } | ||||
|  | ||||
| Id WarpExtract(EmitContext& ctx, Id value) { | ||||
|     const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     const Id thread_id{GetThreadId(ctx)}; | ||||
|     const Id local_index{ctx.OpShiftRightArithmetic(ctx.U32[1], thread_id, ctx.Const(5U))}; | ||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); | ||||
| } | ||||
|  | ||||
| @@ -51,11 +56,7 @@ Id SelectValue(EmitContext& ctx, Id in_range, Id value, Id src_thread_id) { | ||||
| } // Anonymous namespace | ||||
|  | ||||
| Id EmitLaneId(EmitContext& ctx) { | ||||
|     const Id id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||||
|         return id; | ||||
|     } | ||||
|     return ctx.OpBitwiseAnd(ctx.U32[1], id, ctx.Const(31U)); | ||||
|     return GetThreadId(ctx); | ||||
| } | ||||
|  | ||||
| Id EmitVoteAll(EmitContext& ctx, Id pred) { | ||||
| @@ -123,7 +124,7 @@ Id EmitSubgroupGeMask(EmitContext& ctx) { | ||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                     Id segmentation_mask) { | ||||
|     const Id not_seg_mask{ctx.OpNot(ctx.U32[1], segmentation_mask)}; | ||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     const Id thread_id{GetThreadId(ctx)}; | ||||
|     const Id min_thread_id{ComputeMinThreadId(ctx, thread_id, segmentation_mask)}; | ||||
|     const Id max_thread_id{ComputeMaxThreadId(ctx, min_thread_id, clamp, not_seg_mask)}; | ||||
|  | ||||
| @@ -137,7 +138,7 @@ Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id cla | ||||
|  | ||||
| Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                  Id segmentation_mask) { | ||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     const Id thread_id{GetThreadId(ctx)}; | ||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||
|     const Id src_thread_id{ctx.OpISub(ctx.U32[1], thread_id, index)}; | ||||
|     const Id in_range{ctx.OpSGreaterThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||
| @@ -148,7 +149,7 @@ Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|  | ||||
| Id EmitShuffleDown(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                    Id segmentation_mask) { | ||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     const Id thread_id{GetThreadId(ctx)}; | ||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||
|     const Id src_thread_id{ctx.OpIAdd(ctx.U32[1], thread_id, index)}; | ||||
|     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||
| @@ -159,7 +160,7 @@ Id EmitShuffleDown(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clam | ||||
|  | ||||
| Id EmitShuffleButterfly(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||
|                         Id segmentation_mask) { | ||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | ||||
|     const Id thread_id{GetThreadId(ctx)}; | ||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||
|     const Id src_thread_id{ctx.OpBitwiseXor(ctx.U32[1], thread_id, index)}; | ||||
|     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user