diff --git a/include/vkd3d_shader.h b/include/vkd3d_shader.h index d82869e79..a1f85dbbd 100644 --- a/include/vkd3d_shader.h +++ b/include/vkd3d_shader.h @@ -120,6 +120,11 @@ enum vkd3d_shader_structure_type * \since 1.15 */ VKD3D_SHADER_STRUCTURE_TYPE_SCAN_HULL_SHADER_TESSELLATION_INFO, + /** + * The structure is a vkd3d_shader_scan_thread_group_size_info structure. + * \since 1.18 + */ + VKD3D_SHADER_STRUCTURE_TYPE_SCAN_THREAD_GROUP_SIZE_INFO, VKD3D_FORCE_32_BIT_ENUM(VKD3D_SHADER_STRUCTURE_TYPE), }; @@ -2282,6 +2287,24 @@ struct vkd3d_shader_scan_hull_shader_tessellation_info enum vkd3d_shader_tessellator_partitioning partitioning; }; +/** + * A chained structure describing the thread group size in a compute shader. + * + * This structure extends vkd3d_shader_compile_info. + * + * \since 1.18 + */ +struct vkd3d_shader_scan_thread_group_size_info +{ + /** Must be set to VKD3D_SHADER_STRUCTURE_TYPE_SCAN_THREAD_GROUP_SIZE_INFO. */ + enum vkd3d_shader_structure_type type; + /** Optional pointer to a structure containing further parameters. */ + const void *next; + + /** The thread group size in the x/y/z direction. */ + unsigned int x, y, z; +}; + /** * Data type of a shader varying, returned as part of struct * vkd3d_shader_signature_element. diff --git a/libs/vkd3d-shader/vkd3d_shader_main.c b/libs/vkd3d-shader/vkd3d_shader_main.c index 5a200f577..54ea24a34 100644 --- a/libs/vkd3d-shader/vkd3d_shader_main.c +++ b/libs/vkd3d-shader/vkd3d_shader_main.c @@ -1685,6 +1685,7 @@ static int vsir_program_scan(struct vsir_program *program, const struct vkd3d_sh struct vsir_program_iterator it = vsir_program_iterator(&program->instructions); struct vkd3d_shader_scan_combined_resource_sampler_info *combined_sampler_info; struct vkd3d_shader_scan_hull_shader_tessellation_info *tessellation_info; + struct vkd3d_shader_scan_thread_group_size_info *thread_group_size_info; struct vkd3d_shader_scan_descriptor_info *descriptor_info; struct vkd3d_shader_scan_signature_info *signature_info; struct vkd3d_shader_scan_context context; @@ -1706,6 +1707,7 @@ static int vsir_program_scan(struct vsir_program *program, const struct vkd3d_sh } tessellation_info = vkd3d_find_struct(compile_info->next, SCAN_HULL_SHADER_TESSELLATION_INFO); + thread_group_size_info = vkd3d_find_struct(compile_info->next, SCAN_THREAD_GROUP_SIZE_INFO); vkd3d_shader_scan_context_init(&context, &program->shader_version, compile_info, add_descriptor_info ? &program->descriptors : NULL, combined_sampler_info, message_context); @@ -1758,6 +1760,13 @@ static int vsir_program_scan(struct vsir_program *program, const struct vkd3d_sh tessellation_info->partitioning = context.partitioning; } + if (!ret && thread_group_size_info) + { + thread_group_size_info->x = program->thread_group_size.x; + thread_group_size_info->y = program->thread_group_size.y; + thread_group_size_info->z = program->thread_group_size.z; + } + if (ret < 0) { if (combined_sampler_info) diff --git a/tests/vkd3d_shader_api.c b/tests/vkd3d_shader_api.c index b379c34d6..dce69b0a1 100644 --- a/tests/vkd3d_shader_api.c +++ b/tests/vkd3d_shader_api.c @@ -2180,6 +2180,78 @@ static void test_parameters(void) #endif } +static void test_scan_thread_group_size(void) +{ + struct vkd3d_shader_hlsl_source_info hlsl_info = {.type = VKD3D_SHADER_STRUCTURE_TYPE_HLSL_SOURCE_INFO}; + struct vkd3d_shader_compile_info info = {.type = VKD3D_SHADER_STRUCTURE_TYPE_COMPILE_INFO}; + struct vkd3d_shader_scan_thread_group_size_info thread_group_size_info; + struct vkd3d_shader_code out; + unsigned int i; + int rc; + + static const char cs1_source[] = + "[numthreads(1, 2, 4)]\n" + "void main() {}\n"; + static const char cs2_source[] = + "[numthreads(2, 4, 8)]\n" + "void main() {}\n"; + static const char ps1_source[] = + "float4 main() : sv_target { return float4(0.0, 1.0, 0.0, 1.0); }\n"; + + static const struct + { + const char *src; + const char *profile; + enum vkd3d_shader_target_type target_type; + unsigned int x, y, z; + } + tests[] = + { + {cs1_source, "cs_4_0", VKD3D_SHADER_TARGET_DXBC_TPF, 1, 2, 4}, + {cs2_source, "cs_5_0", VKD3D_SHADER_TARGET_SPIRV_BINARY, 2, 4, 8}, + {ps1_source, "ps_4_0", VKD3D_SHADER_TARGET_DXBC_TPF, 0, 0, 0}, + {ps1_source, "ps_3_0", VKD3D_SHADER_TARGET_D3D_BYTECODE, 0, 0, 0}, + }; + + memset(&thread_group_size_info, 0, sizeof(thread_group_size_info)); + thread_group_size_info.type = VKD3D_SHADER_STRUCTURE_TYPE_SCAN_THREAD_GROUP_SIZE_INFO; + + for (i = 0; i < ARRAY_SIZE(tests); ++i) + { + vkd3d_test_push_context("Test %u", i); + + hlsl_info.next = &thread_group_size_info; + hlsl_info.profile = tests[i].profile; + + info.next = &hlsl_info; + info.source.code = tests[i].src; + info.source.size = strlen(tests[i].src); + info.source_type = VKD3D_SHADER_SOURCE_HLSL; + info.target_type = tests[i].target_type; + info.log_level = VKD3D_SHADER_LOG_INFO; + + rc = vkd3d_shader_scan(&info, NULL); + ok(rc == VKD3D_OK, "Got rc %d.\n", rc); + + ok(thread_group_size_info.x == tests[i].x, "Got x %u.\n", thread_group_size_info.x); + ok(thread_group_size_info.x == tests[i].x, "Got y %u.\n", thread_group_size_info.y); + ok(thread_group_size_info.x == tests[i].x, "Got z %u.\n", thread_group_size_info.z); + + memset(&thread_group_size_info, 0, sizeof(thread_group_size_info)); + thread_group_size_info.type = VKD3D_SHADER_STRUCTURE_TYPE_SCAN_THREAD_GROUP_SIZE_INFO; + rc = vkd3d_shader_compile(&info, &out, NULL); + ok(rc == VKD3D_OK, "Got rc %d.\n", rc); + + ok(thread_group_size_info.x == tests[i].x, "Got x %u.\n", thread_group_size_info.x); + ok(thread_group_size_info.x == tests[i].x, "Got y %u.\n", thread_group_size_info.y); + ok(thread_group_size_info.x == tests[i].x, "Got z %u.\n", thread_group_size_info.z); + + vkd3d_shader_free_shader_code(&out); + + vkd3d_test_pop_context(); + } +} + START_TEST(vkd3d_shader_api) { setlocale(LC_ALL, ""); @@ -2196,4 +2268,5 @@ START_TEST(vkd3d_shader_api) run_test(test_emit_signature); run_test(test_warning_options); run_test(test_parameters); + run_test(test_scan_thread_group_size); }