From ad8e41f8f21602291d7ececaaa37769b1f1763e0 Mon Sep 17 00:00:00 2001 From: Giovanni Mascellani Date: Sat, 27 Sep 2025 16:43:16 +0200 Subject: [PATCH] vkd3d-shader/msl: Implement VKD3DSPR_THREADID. --- libs/vkd3d-shader/msl.c | 22 ++++++++++++++++++++++ tests/hlsl/compute.shader_test | 2 +- tests/hlsl/numthreads.shader_test | 2 +- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index 2acb31ed6..f7d703e1b 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -486,6 +486,10 @@ static enum msl_data_type msl_print_register_name(struct vkd3d_string_buffer *bu vkd3d_string_buffer_printf(buffer, "o_mask"); return MSL_DATA_UNION; + case VKD3DSPR_THREADID: + vkd3d_string_buffer_printf(buffer, "v_thread_id"); + return MSL_DATA_UNION; + default: msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled register type %#x.", reg->type); @@ -1997,6 +2001,12 @@ static void msl_generate_entrypoint_prologue(struct msl_generator *gen) msl_print_write_mask(buffer, e->mask); vkd3d_string_buffer_printf(buffer, ";\n"); } + + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) + { + msl_print_indent(gen->buffer, 1); + vkd3d_string_buffer_printf(buffer, "v_thread_id.u = uint4(thread_id, 0u);\n"); + } } static void msl_generate_entrypoint_epilogue(struct msl_generator *gen) @@ -2083,6 +2093,12 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, "uint vertex_id [[vertex_id]],\n"); } + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) + { + msl_print_indent(gen->buffer, 2); + vkd3d_string_buffer_printf(gen->buffer, "uint3 thread_id [[thread_position_in_grid]],\n"); + } + msl_print_indent(gen->buffer, 2); vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_in input [[stage_in]])\n{\n", gen->prefix); @@ -2092,6 +2108,8 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, " vkd3d_%s_out output;\n", gen->prefix); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) vkd3d_string_buffer_printf(gen->buffer, " vkd3d_scalar o_mask;\n"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) + vkd3d_string_buffer_printf(gen->buffer, " vkd3d_vec4 v_thread_id;\n"); vkd3d_string_buffer_printf(gen->buffer, "\n"); msl_generate_entrypoint_prologue(gen); @@ -2103,6 +2121,8 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, ", output.shader_out_depth"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) vkd3d_string_buffer_printf(gen->buffer, ", o_mask"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) + vkd3d_string_buffer_printf(gen->buffer, ", v_thread_id"); if (gen->program->descriptors.descriptor_count) vkd3d_string_buffer_printf(gen->buffer, ", descriptors"); vkd3d_string_buffer_printf(gen->buffer, ");\n\n"); @@ -2176,6 +2196,8 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader vkd3d_string_buffer_printf(gen->buffer, ", thread float &o_depth"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_scalar &o_mask"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) + vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_vec4 &v_thread_id"); if (gen->program->descriptors.descriptor_count) vkd3d_string_buffer_printf(gen->buffer, ", constant descriptor *descriptors"); vkd3d_string_buffer_printf(gen->buffer, ")\n{\n"); diff --git a/tests/hlsl/compute.shader_test b/tests/hlsl/compute.shader_test index 8b28c4a44..aa3a69ab6 100644 --- a/tests/hlsl/compute.shader_test +++ b/tests/hlsl/compute.shader_test @@ -40,7 +40,7 @@ void main(uint3 thread_id : SV_DispatchThreadId) } [test] -todo(msl) dispatch 13 13 1 +dispatch 13 13 1 probe uav 0 (0, 0) u32(0, 0, 0, 1) probe uav 0 (14, 38) u32(14, 38, 0, 1) probe uav 0 (49, 49) u32(49, 49, 0, 1) diff --git a/tests/hlsl/numthreads.shader_test b/tests/hlsl/numthreads.shader_test index 942a7a5e5..5e906b953 100644 --- a/tests/hlsl/numthreads.shader_test +++ b/tests/hlsl/numthreads.shader_test @@ -219,7 +219,7 @@ void main(uint2 id : sv_dispatchthreadid) } [test] -todo(msl) dispatch 1 1 1 +dispatch 1 1 1 probe uav 0 (0, 0) f32(2.0) if(sm<6) probe uav 0 (0, 1) f32(1.0) if(sm<6) probe uav 0 (1, 0) f32(2.0)