diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index 906392870..8d724c30c 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -490,6 +490,10 @@ static enum msl_data_type msl_print_register_name(struct vkd3d_string_buffer *bu vkd3d_string_buffer_printf(buffer, "v_thread_id"); return MSL_DATA_UNION; + case VKD3DSPR_LOCALTHREADINDEX: + vkd3d_string_buffer_printf(buffer, "v_local_thread_index"); + return MSL_DATA_UNION; + default: msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled register type %#x.", reg->type); @@ -2022,6 +2026,12 @@ static void msl_generate_entrypoint_prologue(struct msl_generator *gen) msl_print_indent(gen->buffer, 1); vkd3d_string_buffer_printf(buffer, "v_thread_id.u = uint4(thread_id, 0u);\n"); } + + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_LOCALTHREADINDEX)) + { + msl_print_indent(gen->buffer, 1); + vkd3d_string_buffer_printf(buffer, "v_local_thread_index.u = uint4(local_thread_index, 0u, 0u, 0u);\n"); + } } static void msl_generate_entrypoint_epilogue(struct msl_generator *gen) @@ -2114,6 +2124,12 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, "uint3 thread_id [[thread_position_in_grid]],\n"); } + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_LOCALTHREADINDEX)) + { + msl_print_indent(gen->buffer, 2); + vkd3d_string_buffer_printf(gen->buffer, "uint local_thread_index [[thread_index_in_threadgroup]],\n"); + } + msl_print_indent(gen->buffer, 2); vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_in input [[stage_in]])\n{\n", gen->prefix); @@ -2125,6 +2141,8 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, " vkd3d_scalar o_mask;\n"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) vkd3d_string_buffer_printf(gen->buffer, " vkd3d_vec4 v_thread_id;\n"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_LOCALTHREADINDEX)) + vkd3d_string_buffer_printf(gen->buffer, " vkd3d_vec4 v_local_thread_index;\n"); vkd3d_string_buffer_printf(gen->buffer, "\n"); msl_generate_entrypoint_prologue(gen); @@ -2138,6 +2156,8 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, ", o_mask"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) vkd3d_string_buffer_printf(gen->buffer, ", v_thread_id"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_LOCALTHREADINDEX)) + vkd3d_string_buffer_printf(gen->buffer, ", v_local_thread_index"); if (gen->program->descriptors.descriptor_count) vkd3d_string_buffer_printf(gen->buffer, ", descriptors"); vkd3d_string_buffer_printf(gen->buffer, ");\n\n"); @@ -2213,6 +2233,8 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_scalar &o_mask"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_THREADID)) vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_vec4 &v_thread_id"); + if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_LOCALTHREADINDEX)) + vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_vec4 &v_local_thread_index"); if (gen->program->descriptors.descriptor_count) vkd3d_string_buffer_printf(gen->buffer, ", constant descriptor *descriptors"); vkd3d_string_buffer_printf(gen->buffer, ")\n{\n"); diff --git a/tests/hlsl/compute.shader_test b/tests/hlsl/compute.shader_test index fa78639c4..438785cea 100644 --- a/tests/hlsl/compute.shader_test +++ b/tests/hlsl/compute.shader_test @@ -56,7 +56,7 @@ void main(uint3 thread_id : SV_DispatchThreadId, uint group_index : SV_GroupInde } [test] -todo(glsl | msl) dispatch 13 13 1 +todo(glsl) dispatch 13 13 1 probe uav 0 (0, 0) u32(0, 0, 0, 1) probe uav 0 (14, 38) u32(10, 0, 0, 1) probe uav 0 (49, 49) u32(5, 0, 0, 1)