emit_spirv_warp: Fix ballot related ops for 64-thread warp sizes
This commit is contained in:
		| @@ -7,8 +7,13 @@ | |||||||
|  |  | ||||||
| namespace Shader::Backend::SPIRV { | namespace Shader::Backend::SPIRV { | ||||||
| namespace { | namespace { | ||||||
|  | Id GetThreadId(EmitContext& ctx) { | ||||||
|  |     return ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id); | ||||||
|  | } | ||||||
|  |  | ||||||
| Id WarpExtract(EmitContext& ctx, Id value) { | Id WarpExtract(EmitContext& ctx, Id value) { | ||||||
|     const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     const Id thread_id{GetThreadId(ctx)}; | ||||||
|  |     const Id local_index{ctx.OpShiftRightArithmetic(ctx.U32[1], thread_id, ctx.Const(5U))}; | ||||||
|     return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); |     return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -51,11 +56,7 @@ Id SelectValue(EmitContext& ctx, Id in_range, Id value, Id src_thread_id) { | |||||||
| } // Anonymous namespace | } // Anonymous namespace | ||||||
|  |  | ||||||
| Id EmitLaneId(EmitContext& ctx) { | Id EmitLaneId(EmitContext& ctx) { | ||||||
|     const Id id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     return GetThreadId(ctx); | ||||||
|     if (!ctx.profile.warp_size_potentially_larger_than_guest) { |  | ||||||
|         return id; |  | ||||||
|     } |  | ||||||
|     return ctx.OpBitwiseAnd(ctx.U32[1], id, ctx.Const(31U)); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| Id EmitVoteAll(EmitContext& ctx, Id pred) { | Id EmitVoteAll(EmitContext& ctx, Id pred) { | ||||||
| @@ -123,7 +124,7 @@ Id EmitSubgroupGeMask(EmitContext& ctx) { | |||||||
| Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||||
|                     Id segmentation_mask) { |                     Id segmentation_mask) { | ||||||
|     const Id not_seg_mask{ctx.OpNot(ctx.U32[1], segmentation_mask)}; |     const Id not_seg_mask{ctx.OpNot(ctx.U32[1], segmentation_mask)}; | ||||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     const Id thread_id{GetThreadId(ctx)}; | ||||||
|     const Id min_thread_id{ComputeMinThreadId(ctx, thread_id, segmentation_mask)}; |     const Id min_thread_id{ComputeMinThreadId(ctx, thread_id, segmentation_mask)}; | ||||||
|     const Id max_thread_id{ComputeMaxThreadId(ctx, min_thread_id, clamp, not_seg_mask)}; |     const Id max_thread_id{ComputeMaxThreadId(ctx, min_thread_id, clamp, not_seg_mask)}; | ||||||
|  |  | ||||||
| @@ -137,7 +138,7 @@ Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id cla | |||||||
|  |  | ||||||
| Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||||
|                  Id segmentation_mask) { |                  Id segmentation_mask) { | ||||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     const Id thread_id{GetThreadId(ctx)}; | ||||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; |     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||||
|     const Id src_thread_id{ctx.OpISub(ctx.U32[1], thread_id, index)}; |     const Id src_thread_id{ctx.OpISub(ctx.U32[1], thread_id, index)}; | ||||||
|     const Id in_range{ctx.OpSGreaterThanEqual(ctx.U1, src_thread_id, max_thread_id)}; |     const Id in_range{ctx.OpSGreaterThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||||
| @@ -148,7 +149,7 @@ Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | |||||||
|  |  | ||||||
| Id EmitShuffleDown(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | Id EmitShuffleDown(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||||
|                    Id segmentation_mask) { |                    Id segmentation_mask) { | ||||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     const Id thread_id{GetThreadId(ctx)}; | ||||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; |     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||||
|     const Id src_thread_id{ctx.OpIAdd(ctx.U32[1], thread_id, index)}; |     const Id src_thread_id{ctx.OpIAdd(ctx.U32[1], thread_id, index)}; | ||||||
|     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; |     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||||
| @@ -159,7 +160,7 @@ Id EmitShuffleDown(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clam | |||||||
|  |  | ||||||
| Id EmitShuffleButterfly(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | Id EmitShuffleButterfly(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | ||||||
|                         Id segmentation_mask) { |                         Id segmentation_mask) { | ||||||
|     const Id thread_id{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |     const Id thread_id{GetThreadId(ctx)}; | ||||||
|     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; |     const Id max_thread_id{GetMaxThreadId(ctx, thread_id, clamp, segmentation_mask)}; | ||||||
|     const Id src_thread_id{ctx.OpBitwiseXor(ctx.U32[1], thread_id, index)}; |     const Id src_thread_id{ctx.OpBitwiseXor(ctx.U32[1], thread_id, index)}; | ||||||
|     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; |     const Id in_range{ctx.OpSLessThanEqual(ctx.U1, src_thread_id, max_thread_id)}; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user