diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index bfc01395..ac019b09 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -39,6 +39,8 @@ struct msl_generator struct vkd3d_shader_message_context *message_context; unsigned int indent; const char *prefix; + const struct vkd3d_shader_interface_info *interface_info; + const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info; }; static void VKD3D_PRINTF_FUNC(3, 4) msl_compiler_error(struct msl_generator *gen, @@ -266,6 +268,144 @@ static void msl_handle_instruction(struct msl_generator *gen, const struct vkd3d } } +static bool msl_check_shader_visibility(const struct msl_generator *gen, + enum vkd3d_shader_visibility visibility) +{ + enum vkd3d_shader_type t = gen->program->shader_version.type; + + switch (visibility) + { + case VKD3D_SHADER_VISIBILITY_ALL: + return true; + case VKD3D_SHADER_VISIBILITY_VERTEX: + return t == VKD3D_SHADER_TYPE_VERTEX; + case VKD3D_SHADER_VISIBILITY_HULL: + return t == VKD3D_SHADER_TYPE_HULL; + case VKD3D_SHADER_VISIBILITY_DOMAIN: + return t == VKD3D_SHADER_TYPE_DOMAIN; + case VKD3D_SHADER_VISIBILITY_GEOMETRY: + return t == VKD3D_SHADER_TYPE_GEOMETRY; + case VKD3D_SHADER_VISIBILITY_PIXEL: + return t == VKD3D_SHADER_TYPE_PIXEL; + case VKD3D_SHADER_VISIBILITY_COMPUTE: + return t == VKD3D_SHADER_TYPE_COMPUTE; + default: + WARN("Invalid shader visibility %#x.\n", visibility); + return false; + } +} + +static bool msl_get_cbv_binding(const struct msl_generator *gen, + unsigned int register_space, unsigned int register_idx, unsigned int *binding_idx) +{ + const struct vkd3d_shader_interface_info *interface_info = gen->interface_info; + const struct vkd3d_shader_resource_binding *binding; + unsigned int i; + + if (!interface_info) + return false; + + for (i = 0; i < interface_info->binding_count; ++i) + { + binding = &interface_info->bindings[i]; + + if (binding->type != VKD3D_SHADER_DESCRIPTOR_TYPE_CBV) + continue; + if (binding->register_space != register_space) + continue; + if (binding->register_index != register_idx) + continue; + if (!msl_check_shader_visibility(gen, binding->shader_visibility)) + continue; + if (!(binding->flags & VKD3D_SHADER_BINDING_FLAG_BUFFER)) + continue; + *binding_idx = i; + return true; + } + + return false; +} + +static void msl_generate_cbv_declaration(struct msl_generator *gen, + const struct vkd3d_shader_descriptor_info1 *cbv) +{ + const struct vkd3d_shader_descriptor_binding *binding; + struct vkd3d_string_buffer *buffer = gen->buffer; + unsigned int binding_idx; + size_t size; + + if (cbv->count != 1) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "Constant buffer %u has unsupported descriptor array size %u.", cbv->register_id, cbv->count); + return; + } + + if (!msl_get_cbv_binding(gen, cbv->register_space, cbv->register_index, &binding_idx)) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "No descriptor binding specified for constant buffer %u.", cbv->register_id); + return; + } + + binding = &gen->interface_info->bindings[binding_idx].binding; + + if (binding->set != 0) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "Unsupported binding set %u specified for constant buffer %u.", binding->set, cbv->register_id); + return; + } + + if (binding->count != 1) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "Unsupported binding count %u specified for constant buffer %u.", binding->count, cbv->register_id); + return; + } + + size = align(cbv->buffer_size, VKD3D_VEC4_SIZE * sizeof(uint32_t)); + size /= VKD3D_VEC4_SIZE * sizeof(uint32_t); + + vkd3d_string_buffer_printf(buffer, + "constant vkd3d_vec4 (&cb_%u)[%zu] [[id(%u)]];", cbv->register_id, size, binding->binding); +}; + +static void msl_generate_descriptor_struct_declarations(struct msl_generator *gen) +{ + const struct vkd3d_shader_scan_descriptor_info1 *info = gen->descriptor_info; + const struct vkd3d_shader_descriptor_info1 *descriptor; + struct vkd3d_string_buffer *buffer = gen->buffer; + unsigned int i; + + if (!info->descriptor_count) + return; + + vkd3d_string_buffer_printf(buffer, "struct vkd3d_%s_descriptors\n{\n", gen->prefix); + + for (i = 0; i < info->descriptor_count; ++i) + { + descriptor = &info->descriptors[i]; + + msl_print_indent(buffer, 1); + switch (descriptor->type) + { + case VKD3D_SHADER_DESCRIPTOR_TYPE_CBV: + msl_generate_cbv_declaration(gen, descriptor); + break; + + default: + vkd3d_string_buffer_printf(buffer, "/* */", descriptor->type); + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Unhandled descriptor type %#x.", descriptor->type); + break; + } + vkd3d_string_buffer_printf(buffer, "\n"); + } + + vkd3d_string_buffer_printf(buffer, "};\n\n"); +} + static void msl_generate_input_struct_declarations(struct msl_generator *gen) { const struct shader_signature *signature = &gen->program->input_signature; @@ -550,9 +690,15 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_out shader_entry(\n", gen->prefix); - /* TODO: descriptor declaration */ + if (gen->descriptor_info->descriptor_count) + { + msl_print_indent(gen->buffer, 2); + /* TODO: Configurable argument buffer binding location. */ + vkd3d_string_buffer_printf(gen->buffer, + "constant vkd3d_%s_descriptors& descriptors [[buffer(0)]],\n", gen->prefix); + } - msl_print_indent(gen->buffer, 1); + msl_print_indent(gen->buffer, 2); vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_in input [[stage_in]])\n{\n", gen->prefix); /* TODO: declare #maximum_register + 1 */ @@ -562,7 +708,10 @@ static void msl_generate_entrypoint(struct msl_generator *gen) msl_generate_entrypoint_prologue(gen); - vkd3d_string_buffer_printf(gen->buffer, " %s_main(%s_in, %s_out);\n", gen->prefix, gen->prefix, gen->prefix); + vkd3d_string_buffer_printf(gen->buffer, " %s_main(%s_in, %s_out", gen->prefix, gen->prefix, gen->prefix); + if (gen->descriptor_info->descriptor_count) + vkd3d_string_buffer_printf(gen->buffer, ", descriptors"); + vkd3d_string_buffer_printf(gen->buffer, ");\n"); msl_generate_entrypoint_epilogue(gen); @@ -583,13 +732,17 @@ static void msl_generator_generate(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, " int4 i;\n"); vkd3d_string_buffer_printf(gen->buffer, " float4 f;\n};\n\n"); + msl_generate_descriptor_struct_declarations(gen); msl_generate_input_struct_declarations(gen); msl_generate_output_struct_declarations(gen); vkd3d_string_buffer_printf(gen->buffer, "void %s_main(thread vkd3d_vec4 *v, " - "thread vkd3d_vec4 *o)\n{\n", + "thread vkd3d_vec4 *o", gen->prefix); + if (gen->descriptor_info->descriptor_count) + vkd3d_string_buffer_printf(gen->buffer, ", constant vkd3d_%s_descriptors& descriptors", gen->prefix); + vkd3d_string_buffer_printf(gen->buffer, ")\n{\n"); ++gen->indent; @@ -621,6 +774,8 @@ static void msl_generator_cleanup(struct msl_generator *gen) } static int msl_generator_init(struct msl_generator *gen, struct vsir_program *program, + const struct vkd3d_shader_compile_info *compile_info, + const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info, struct vkd3d_shader_message_context *message_context) { enum vkd3d_shader_type type = program->shader_version.type; @@ -640,11 +795,14 @@ static int msl_generator_init(struct msl_generator *gen, struct vsir_program *pr "Internal compiler error: Unhandled shader type %#x.", type); return VKD3D_ERROR_INVALID_SHADER; } + gen->interface_info = vkd3d_find_struct(compile_info->next, INTERFACE_INFO); + gen->descriptor_info = descriptor_info; return VKD3D_OK; } int msl_compile(struct vsir_program *program, uint64_t config_flags, + const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info, const struct vkd3d_shader_compile_info *compile_info, struct vkd3d_shader_message_context *message_context) { struct msl_generator generator; @@ -653,7 +811,7 @@ int msl_compile(struct vsir_program *program, uint64_t config_flags, if ((ret = vsir_program_transform(program, config_flags, compile_info, message_context)) < 0) return ret; - if ((ret = msl_generator_init(&generator, program, message_context)) < 0) + if ((ret = msl_generator_init(&generator, program, compile_info, descriptor_info, message_context)) < 0) return ret; msl_generator_generate(&generator); msl_generator_cleanup(&generator); diff --git a/libs/vkd3d-shader/vkd3d_shader_main.c b/libs/vkd3d-shader/vkd3d_shader_main.c index a996aaa1..885aeb18 100644 --- a/libs/vkd3d-shader/vkd3d_shader_main.c +++ b/libs/vkd3d-shader/vkd3d_shader_main.c @@ -1655,7 +1655,10 @@ int vsir_program_compile(struct vsir_program *program, uint64_t config_flags, break; case VKD3D_SHADER_TARGET_MSL: - ret = msl_compile(program, config_flags, compile_info, message_context); + if ((ret = vsir_program_scan(program, &scan_info, message_context, &scan_descriptor_info)) < 0) + return ret; + ret = msl_compile(program, config_flags, &scan_descriptor_info, compile_info, message_context); + vkd3d_shader_free_scan_descriptor_info1(&scan_descriptor_info); break; default: diff --git a/libs/vkd3d-shader/vkd3d_shader_private.h b/libs/vkd3d-shader/vkd3d_shader_private.h index de31d554..bf55fa4b 100644 --- a/libs/vkd3d-shader/vkd3d_shader_private.h +++ b/libs/vkd3d-shader/vkd3d_shader_private.h @@ -250,6 +250,7 @@ enum vkd3d_shader_error VKD3D_SHADER_WARNING_VSIR_DYNAMIC_DESCRIPTOR_ARRAY = 9300, VKD3D_SHADER_ERROR_MSL_INTERNAL = 10000, + VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND = 10001, }; enum vkd3d_shader_opcode @@ -1617,6 +1618,7 @@ int spirv_compile(struct vsir_program *program, uint64_t config_flags, struct vkd3d_shader_code *out, struct vkd3d_shader_message_context *message_context); int msl_compile(struct vsir_program *program, uint64_t config_flags, + const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info, const struct vkd3d_shader_compile_info *compile_info, struct vkd3d_shader_message_context *message_context); enum vkd3d_md5_variant