From c298493e4f7447bdc23752a07c49e34b83f5db65 Mon Sep 17 00:00:00 2001 From: Conor McCarthy Date: Wed, 8 May 2024 13:44:51 +1000 Subject: [PATCH] vkd3d-shader/spirv: Implement the QUAD_READ_ACROSS_* instructions. --- libs/vkd3d-shader/spirv.c | 47 +++++++++++++++++++++++++++ tests/hlsl/wave-ops-float.shader_test | 10 +++--- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 01d5bc3d..2be3a3e1 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -1752,6 +1752,14 @@ static uint32_t vkd3d_spirv_get_op_scope_subgroup(struct vkd3d_spirv_builder *bu return vkd3d_spirv_build_once(builder, &builder->scope_subgroup_id, vkd3d_spirv_build_op_scope_subgroup); } +static uint32_t vkd3d_spirv_build_op_group_nonuniform_quad_swap(struct vkd3d_spirv_builder *builder, + uint32_t result_type, uint32_t val_id, uint32_t op_id) +{ + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformQuad); + return vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, SpvOpGroupNonUniformQuadSwap, result_type, + vkd3d_spirv_get_op_scope_subgroup(builder), val_id, op_id); +} + static uint32_t vkd3d_spirv_build_op_group_nonuniform_ballot(struct vkd3d_spirv_builder *builder, uint32_t result_type, uint32_t val_id) { @@ -9805,6 +9813,40 @@ static void spirv_compiler_emit_cut_stream(struct spirv_compiler *compiler, vkd3d_spirv_build_op_end_primitive(builder); } +static uint32_t map_quad_read_across_direction(enum vkd3d_shader_opcode opcode) +{ + switch (opcode) + { + case VKD3DSIH_QUAD_READ_ACROSS_X: + return 0; + case VKD3DSIH_QUAD_READ_ACROSS_Y: + return 1; + case VKD3DSIH_QUAD_READ_ACROSS_D: + return 2; + default: + vkd3d_unreachable(); + } +} + +static void spirv_compiler_emit_quad_read_across(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_dst_param *dst = instruction->dst; + const struct vkd3d_shader_src_param *src = instruction->src; + uint32_t type_id, direction_type_id, direction_id, val_id; + + type_id = vkd3d_spirv_get_type_id_for_data_type(builder, dst->reg.data_type, + vsir_write_mask_component_count(dst->write_mask)); + direction_type_id = vkd3d_spirv_get_type_id_for_data_type(builder, VKD3D_DATA_UINT, 1); + val_id = spirv_compiler_emit_load_src(compiler, src, dst->write_mask); + direction_id = map_quad_read_across_direction(instruction->opcode); + direction_id = vkd3d_spirv_get_op_constant(builder, direction_type_id, direction_id); + val_id = vkd3d_spirv_build_op_group_nonuniform_quad_swap(builder, type_id, val_id, direction_id); + + spirv_compiler_emit_store_dst(compiler, dst, val_id); +} + static SpvOp map_wave_bool_op(enum vkd3d_shader_opcode opcode) { switch (opcode) @@ -10335,6 +10377,11 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler, case VKD3DSIH_CUT_STREAM: spirv_compiler_emit_cut_stream(compiler, instruction); break; + case VKD3DSIH_QUAD_READ_ACROSS_D: + case VKD3DSIH_QUAD_READ_ACROSS_X: + case VKD3DSIH_QUAD_READ_ACROSS_Y: + spirv_compiler_emit_quad_read_across(compiler, instruction); + break; case VKD3DSIH_WAVE_ACTIVE_ALL_EQUAL: case VKD3DSIH_WAVE_ALL_TRUE: case VKD3DSIH_WAVE_ANY_TRUE: diff --git a/tests/hlsl/wave-ops-float.shader_test b/tests/hlsl/wave-ops-float.shader_test index b54f5ac1..1cee8b5d 100644 --- a/tests/hlsl/wave-ops-float.shader_test +++ b/tests/hlsl/wave-ops-float.shader_test @@ -252,7 +252,7 @@ void main(uint id : SV_GroupIndex) } [test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rgba (0.5, 0.25, 1.0, 0.75) probe uav 1 (1) rgba (0.25, 0.5, 0.75, 1.0) probe uav 1 (2) rgba (1.0, 0.75, 0.25, 0.5) @@ -293,7 +293,7 @@ float4 main(float4 pos : SV_Position) : SV_Target } [test] -todo draw quad +draw quad probe rtv 0 (0, 0) rgba (0.25, 0.5, 0.75, 1.0) probe rtv 0 (1, 0) rgba (0.5, 0.25, 1.0, 0.75) probe rtv 0 (0, 1) rgba (0.75, 1.0, 0.5, 0.25) @@ -317,7 +317,7 @@ float4 main(float4 pos : SV_Position) : SV_Target } [test] -todo draw quad +draw quad probe uav 1 (0) rgba (0.75, 1.0, 0.5, 0.25) probe uav 1 (1) rgba (1.0, 0.75, 0.25, 0.5) probe uav 1 (2) rgba (0.25, 0.5, 0.75, 1.0) @@ -337,7 +337,7 @@ float4 main(float4 pos : SV_Position) : SV_Target } [test] -todo draw quad +draw quad probe uav 1 (0) rgba (1.0, 0.75, 0.25, 0.5) probe uav 1 (1) rgba (0.75, 1.0, 0.5, 0.25) probe uav 1 (2) rgba (0.5, 0.25, 1.0, 0.75) @@ -360,7 +360,7 @@ float4 main(float4 pos : SV_Position) : SV_Target } [test] -todo draw quad +draw quad probe rtv 0 (0, 0) rgba (1.0, 0.0, 0.0, 1.0) probe rtv 0 (1, 0) rgba (1.0, 0.0, 0.0, 1.0) probe rtv 0 (0, 1) rgba (0.75, 1.0, 0.5, 0.25)