diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 7ac9f3cd..fd2babc3 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -1760,6 +1760,14 @@ static uint32_t vkd3d_spirv_build_op_group_nonuniform_ballot(struct vkd3d_spirv_ result_type, vkd3d_spirv_get_op_scope_subgroup(builder), val_id); } +static uint32_t vkd3d_spirv_build_op_group_nonuniform_ballot_bit_count(struct vkd3d_spirv_builder *builder, + uint32_t result_type, SpvGroupOperation group_op, uint32_t val_id) +{ + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformBallot); + return vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, SpvOpGroupNonUniformBallotBitCount, + result_type, vkd3d_spirv_get_op_scope_subgroup(builder), group_op, val_id); +} + static uint32_t vkd3d_spirv_build_op_glsl_std450_tr1(struct vkd3d_spirv_builder *builder, enum GLSLstd450 op, uint32_t result_type, uint32_t operand) { @@ -9801,18 +9809,26 @@ static void spirv_compiler_emit_wave_bool_op(struct spirv_compiler *compiler, spirv_compiler_emit_store_dst(compiler, dst, val_id); } -static void spirv_compiler_emit_wave_active_ballot(struct spirv_compiler *compiler, - const struct vkd3d_shader_instruction *instruction) +static uint32_t spirv_compiler_emit_group_nonuniform_ballot(struct spirv_compiler *compiler, + const struct vkd3d_shader_src_param *src) { struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; - const struct vkd3d_shader_dst_param *dst = instruction->dst; - const struct vkd3d_shader_src_param *src = instruction->src; uint32_t type_id, val_id; type_id = vkd3d_spirv_get_type_id(builder, VKD3D_SHADER_COMPONENT_UINT, VKD3D_VEC4_SIZE); val_id = spirv_compiler_emit_load_src(compiler, src, VKD3DSP_WRITEMASK_0); val_id = vkd3d_spirv_build_op_group_nonuniform_ballot(builder, type_id, val_id); + return val_id; +} + +static void spirv_compiler_emit_wave_active_ballot(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + const struct vkd3d_shader_dst_param *dst = instruction->dst; + uint32_t val_id; + + val_id = spirv_compiler_emit_group_nonuniform_ballot(compiler, instruction->src); spirv_compiler_emit_store_dst(compiler, dst, val_id); } @@ -9871,6 +9887,23 @@ static void spirv_compiler_emit_wave_alu_op(struct spirv_compiler *compiler, spirv_compiler_emit_store_dst(compiler, dst, val_id); } +static void spirv_compiler_emit_wave_bit_count(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_dst_param *dst = instruction->dst; + SpvGroupOperation group_op; + uint32_t type_id, val_id; + + group_op = SpvGroupOperationReduce; + + val_id = spirv_compiler_emit_group_nonuniform_ballot(compiler, instruction->src); + type_id = vkd3d_spirv_get_type_id(builder, VKD3D_SHADER_COMPONENT_UINT, 1); + val_id = vkd3d_spirv_build_op_group_nonuniform_ballot_bit_count(builder, type_id, group_op, val_id); + + spirv_compiler_emit_store_dst(compiler, dst, val_id); +} + /* This function is called after declarations are processed. */ static void spirv_compiler_emit_main_prolog(struct spirv_compiler *compiler) { @@ -10236,6 +10269,9 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler, case VKD3DSIH_WAVE_OP_UMIN: spirv_compiler_emit_wave_alu_op(compiler, instruction); break; + case VKD3DSIH_WAVE_ALL_BIT_COUNT: + spirv_compiler_emit_wave_bit_count(compiler, instruction); + break; case VKD3DSIH_DCL: case VKD3DSIH_DCL_HS_MAX_TESSFACTOR: case VKD3DSIH_DCL_INPUT_CONTROL_POINT_COUNT: diff --git a/tests/hlsl/wave-ops-uint.shader_test b/tests/hlsl/wave-ops-uint.shader_test index c16adab4..2c817366 100644 --- a/tests/hlsl/wave-ops-uint.shader_test +++ b/tests/hlsl/wave-ops-uint.shader_test @@ -177,7 +177,7 @@ void main(uint id : SV_GroupIndex) } [test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (2) probe uav 1 (1) rui (2) probe uav 1 (2) rui (2)