diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 1da883b3..b7b4c3ec 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -361,6 +361,7 @@ struct vkd3d_spirv_builder uint32_t type_sampler_id; uint32_t type_bool_id; uint32_t type_void_id; + uint32_t scope_subgroup_id; struct vkd3d_spirv_stream debug_stream; /* debug instructions */ struct vkd3d_spirv_stream annotation_stream; /* decoration instructions */ @@ -1741,6 +1742,16 @@ static void vkd3d_spirv_build_op_memory_barrier(struct vkd3d_spirv_builder *buil SpvOpMemoryBarrier, memory_id, memory_semantics_id); } +static uint32_t vkd3d_spirv_build_op_scope_subgroup(struct vkd3d_spirv_builder *builder) +{ + return vkd3d_spirv_get_op_constant(builder, vkd3d_spirv_get_op_type_int(builder, 32, 0), SpvScopeSubgroup); +} + +static uint32_t vkd3d_spirv_get_op_scope_subgroup(struct vkd3d_spirv_builder *builder) +{ + return vkd3d_spirv_build_once(builder, &builder->scope_subgroup_id, vkd3d_spirv_build_op_scope_subgroup); +} + static uint32_t vkd3d_spirv_build_op_glsl_std450_tr1(struct vkd3d_spirv_builder *builder, enum GLSLstd450 op, uint32_t result_type, uint32_t operand) { @@ -9747,6 +9758,37 @@ static void spirv_compiler_emit_cut_stream(struct spirv_compiler *compiler, vkd3d_spirv_build_op_end_primitive(builder); } +static SpvOp map_wave_bool_op(enum vkd3d_shader_opcode handler_idx) +{ + switch (handler_idx) + { + case VKD3DSIH_WAVE_ACTIVE_ALL_EQUAL: + return SpvOpGroupNonUniformAllEqual; + default: + vkd3d_unreachable(); + } +} + +static void spirv_compiler_emit_wave_bool_op(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_dst_param *dst = instruction->dst; + const struct vkd3d_shader_src_param *src = instruction->src; + uint32_t type_id, val_id; + SpvOp op; + + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformVote); + + op = map_wave_bool_op(instruction->handler_idx); + type_id = vkd3d_spirv_get_op_type_bool(builder); + val_id = spirv_compiler_emit_load_src(compiler, src, dst->write_mask); + val_id = vkd3d_spirv_build_op_tr2(builder, &builder->function_stream, op, + type_id, vkd3d_spirv_get_op_scope_subgroup(builder), val_id); + + spirv_compiler_emit_store_dst(compiler, dst, val_id); +} + /* This function is called after declarations are processed. */ static void spirv_compiler_emit_main_prolog(struct spirv_compiler *compiler) { @@ -10091,6 +10133,9 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler, case VKD3DSIH_CUT_STREAM: spirv_compiler_emit_cut_stream(compiler, instruction); break; + case VKD3DSIH_WAVE_ACTIVE_ALL_EQUAL: + spirv_compiler_emit_wave_bool_op(compiler, instruction); + break; case VKD3DSIH_DCL: case VKD3DSIH_DCL_HS_MAX_TESSFACTOR: case VKD3DSIH_DCL_INPUT_CONTROL_POINT_COUNT: diff --git a/tests/hlsl/wave-ops-float.shader_test b/tests/hlsl/wave-ops-float.shader_test index 74254179..09e7937e 100644 --- a/tests/hlsl/wave-ops-float.shader_test +++ b/tests/hlsl/wave-ops-float.shader_test @@ -36,7 +36,7 @@ void main(uint id : SV_GroupIndex) } [test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rgbaui (1, 0, 0, 1) probe uav 1 (1) rgbaui (1, 0, 0, 1) probe uav 1 (2) rgbaui (1, 0, 0, 1)