diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index c1079166..6b75cfb1 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -1775,6 +1775,22 @@ static uint32_t vkd3d_spirv_build_op_group_nonuniform_elect(struct vkd3d_spirv_b vkd3d_spirv_get_op_type_bool(builder), vkd3d_spirv_get_op_scope_subgroup(builder)); } +static uint32_t vkd3d_spirv_build_op_group_nonuniform_broadcast(struct vkd3d_spirv_builder *builder, + uint32_t result_type, uint32_t val_id, uint32_t lane_id) +{ + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformBallot); + return vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, SpvOpGroupNonUniformBroadcast, result_type, + vkd3d_spirv_get_op_scope_subgroup(builder), val_id, lane_id); +} + +static uint32_t vkd3d_spirv_build_op_group_nonuniform_shuffle(struct vkd3d_spirv_builder *builder, + uint32_t result_type, uint32_t val_id, uint32_t lane_id) +{ + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformShuffle); + return vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, SpvOpGroupNonUniformShuffle, result_type, + vkd3d_spirv_get_op_scope_subgroup(builder), val_id, lane_id); +} + static uint32_t vkd3d_spirv_build_op_glsl_std450_tr1(struct vkd3d_spirv_builder *builder, enum GLSLstd450 op, uint32_t result_type, uint32_t operand) { @@ -9925,6 +9941,34 @@ static void spirv_compiler_emit_wave_is_first_lane(struct spirv_compiler *compil spirv_compiler_emit_store_dst(compiler, dst, val_id); } +static void spirv_compiler_emit_wave_read_lane_at(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_dst_param *dst = instruction->dst; + const struct vkd3d_shader_src_param *src = instruction->src; + uint32_t type_id, lane_id, val_id; + + type_id = vkd3d_spirv_get_type_id_for_data_type(builder, dst->reg.data_type, + vsir_write_mask_component_count(dst->write_mask)); + val_id = spirv_compiler_emit_load_src(compiler, &src[0], dst->write_mask); + lane_id = spirv_compiler_emit_load_src(compiler, &src[1], VKD3DSP_WRITEMASK_0); + + /* TODO: detect values loaded from a const buffer? */ + if (register_is_constant_or_undef(&src[1].reg)) + { + /* Uniform lane_id only. */ + val_id = vkd3d_spirv_build_op_group_nonuniform_broadcast(builder, type_id, val_id, lane_id); + } + else + { + /* WaveReadLaneAt supports non-uniform lane ids, so if lane_id is not constant it may not be uniform. */ + val_id = vkd3d_spirv_build_op_group_nonuniform_shuffle(builder, type_id, val_id, lane_id); + } + + spirv_compiler_emit_store_dst(compiler, dst, val_id); +} + /* This function is called after declarations are processed. */ static void spirv_compiler_emit_main_prolog(struct spirv_compiler *compiler) { @@ -10297,6 +10341,9 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler, case VKD3DSIH_WAVE_IS_FIRST_LANE: spirv_compiler_emit_wave_is_first_lane(compiler, instruction); break; + case VKD3DSIH_WAVE_READ_LANE_AT: + spirv_compiler_emit_wave_read_lane_at(compiler, instruction); + break; case VKD3DSIH_DCL: case VKD3DSIH_DCL_HS_MAX_TESSFACTOR: case VKD3DSIH_DCL_INPUT_CONTROL_POINT_COUNT: diff --git a/tests/hlsl/wave-ops-uint.shader_test b/tests/hlsl/wave-ops-uint.shader_test index 2e0bff40..49dbd1c2 100644 --- a/tests/hlsl/wave-ops-uint.shader_test +++ b/tests/hlsl/wave-ops-uint.shader_test @@ -158,7 +158,7 @@ void main(uint id : SV_GroupIndex) } [test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (18) probe uav 1 (1) rui (18) probe uav 1 (2) rui (18) @@ -178,7 +178,7 @@ void main(uint id : SV_GroupIndex) } [test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (18) probe uav 1 (1) rui (18) probe uav 1 (2) rui (23)