vkd3d-shader/msl: Generate shader descriptor structure declarations.

This commit is contained in:
Feifan He 2024-10-09 16:36:06 +08:00 committed by Henri Verbeet
parent 9cb4207c92
commit 2d7832e738
Notes: Henri Verbeet 2024-10-10 20:08:53 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1164
3 changed files with 169 additions and 6 deletions

View File

@ -39,6 +39,8 @@ struct msl_generator
struct vkd3d_shader_message_context *message_context;
unsigned int indent;
const char *prefix;
const struct vkd3d_shader_interface_info *interface_info;
const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info;
};
static void VKD3D_PRINTF_FUNC(3, 4) msl_compiler_error(struct msl_generator *gen,
@ -266,6 +268,144 @@ static void msl_handle_instruction(struct msl_generator *gen, const struct vkd3d
}
}
static bool msl_check_shader_visibility(const struct msl_generator *gen,
enum vkd3d_shader_visibility visibility)
{
enum vkd3d_shader_type t = gen->program->shader_version.type;
switch (visibility)
{
case VKD3D_SHADER_VISIBILITY_ALL:
return true;
case VKD3D_SHADER_VISIBILITY_VERTEX:
return t == VKD3D_SHADER_TYPE_VERTEX;
case VKD3D_SHADER_VISIBILITY_HULL:
return t == VKD3D_SHADER_TYPE_HULL;
case VKD3D_SHADER_VISIBILITY_DOMAIN:
return t == VKD3D_SHADER_TYPE_DOMAIN;
case VKD3D_SHADER_VISIBILITY_GEOMETRY:
return t == VKD3D_SHADER_TYPE_GEOMETRY;
case VKD3D_SHADER_VISIBILITY_PIXEL:
return t == VKD3D_SHADER_TYPE_PIXEL;
case VKD3D_SHADER_VISIBILITY_COMPUTE:
return t == VKD3D_SHADER_TYPE_COMPUTE;
default:
WARN("Invalid shader visibility %#x.\n", visibility);
return false;
}
}
static bool msl_get_cbv_binding(const struct msl_generator *gen,
unsigned int register_space, unsigned int register_idx, unsigned int *binding_idx)
{
const struct vkd3d_shader_interface_info *interface_info = gen->interface_info;
const struct vkd3d_shader_resource_binding *binding;
unsigned int i;
if (!interface_info)
return false;
for (i = 0; i < interface_info->binding_count; ++i)
{
binding = &interface_info->bindings[i];
if (binding->type != VKD3D_SHADER_DESCRIPTOR_TYPE_CBV)
continue;
if (binding->register_space != register_space)
continue;
if (binding->register_index != register_idx)
continue;
if (!msl_check_shader_visibility(gen, binding->shader_visibility))
continue;
if (!(binding->flags & VKD3D_SHADER_BINDING_FLAG_BUFFER))
continue;
*binding_idx = i;
return true;
}
return false;
}
static void msl_generate_cbv_declaration(struct msl_generator *gen,
const struct vkd3d_shader_descriptor_info1 *cbv)
{
const struct vkd3d_shader_descriptor_binding *binding;
struct vkd3d_string_buffer *buffer = gen->buffer;
unsigned int binding_idx;
size_t size;
if (cbv->count != 1)
{
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND,
"Constant buffer %u has unsupported descriptor array size %u.", cbv->register_id, cbv->count);
return;
}
if (!msl_get_cbv_binding(gen, cbv->register_space, cbv->register_index, &binding_idx))
{
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND,
"No descriptor binding specified for constant buffer %u.", cbv->register_id);
return;
}
binding = &gen->interface_info->bindings[binding_idx].binding;
if (binding->set != 0)
{
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND,
"Unsupported binding set %u specified for constant buffer %u.", binding->set, cbv->register_id);
return;
}
if (binding->count != 1)
{
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND,
"Unsupported binding count %u specified for constant buffer %u.", binding->count, cbv->register_id);
return;
}
size = align(cbv->buffer_size, VKD3D_VEC4_SIZE * sizeof(uint32_t));
size /= VKD3D_VEC4_SIZE * sizeof(uint32_t);
vkd3d_string_buffer_printf(buffer,
"constant vkd3d_vec4 (&cb_%u)[%zu] [[id(%u)]];", cbv->register_id, size, binding->binding);
};
static void msl_generate_descriptor_struct_declarations(struct msl_generator *gen)
{
const struct vkd3d_shader_scan_descriptor_info1 *info = gen->descriptor_info;
const struct vkd3d_shader_descriptor_info1 *descriptor;
struct vkd3d_string_buffer *buffer = gen->buffer;
unsigned int i;
if (!info->descriptor_count)
return;
vkd3d_string_buffer_printf(buffer, "struct vkd3d_%s_descriptors\n{\n", gen->prefix);
for (i = 0; i < info->descriptor_count; ++i)
{
descriptor = &info->descriptors[i];
msl_print_indent(buffer, 1);
switch (descriptor->type)
{
case VKD3D_SHADER_DESCRIPTOR_TYPE_CBV:
msl_generate_cbv_declaration(gen, descriptor);
break;
default:
vkd3d_string_buffer_printf(buffer, "/* <unhandled descriptor type %#x> */", descriptor->type);
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL,
"Internal compiler error: Unhandled descriptor type %#x.", descriptor->type);
break;
}
vkd3d_string_buffer_printf(buffer, "\n");
}
vkd3d_string_buffer_printf(buffer, "};\n\n");
}
static void msl_generate_input_struct_declarations(struct msl_generator *gen)
{
const struct shader_signature *signature = &gen->program->input_signature;
@ -550,9 +690,15 @@ static void msl_generate_entrypoint(struct msl_generator *gen)
vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_out shader_entry(\n", gen->prefix);
/* TODO: descriptor declaration */
if (gen->descriptor_info->descriptor_count)
{
msl_print_indent(gen->buffer, 2);
/* TODO: Configurable argument buffer binding location. */
vkd3d_string_buffer_printf(gen->buffer,
"constant vkd3d_%s_descriptors& descriptors [[buffer(0)]],\n", gen->prefix);
}
msl_print_indent(gen->buffer, 1);
msl_print_indent(gen->buffer, 2);
vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_in input [[stage_in]])\n{\n", gen->prefix);
/* TODO: declare #maximum_register + 1 */
@ -562,7 +708,10 @@ static void msl_generate_entrypoint(struct msl_generator *gen)
msl_generate_entrypoint_prologue(gen);
vkd3d_string_buffer_printf(gen->buffer, " %s_main(%s_in, %s_out);\n", gen->prefix, gen->prefix, gen->prefix);
vkd3d_string_buffer_printf(gen->buffer, " %s_main(%s_in, %s_out", gen->prefix, gen->prefix, gen->prefix);
if (gen->descriptor_info->descriptor_count)
vkd3d_string_buffer_printf(gen->buffer, ", descriptors");
vkd3d_string_buffer_printf(gen->buffer, ");\n");
msl_generate_entrypoint_epilogue(gen);
@ -583,13 +732,17 @@ static void msl_generator_generate(struct msl_generator *gen)
vkd3d_string_buffer_printf(gen->buffer, " int4 i;\n");
vkd3d_string_buffer_printf(gen->buffer, " float4 f;\n};\n\n");
msl_generate_descriptor_struct_declarations(gen);
msl_generate_input_struct_declarations(gen);
msl_generate_output_struct_declarations(gen);
vkd3d_string_buffer_printf(gen->buffer,
"void %s_main(thread vkd3d_vec4 *v, "
"thread vkd3d_vec4 *o)\n{\n",
"thread vkd3d_vec4 *o",
gen->prefix);
if (gen->descriptor_info->descriptor_count)
vkd3d_string_buffer_printf(gen->buffer, ", constant vkd3d_%s_descriptors& descriptors", gen->prefix);
vkd3d_string_buffer_printf(gen->buffer, ")\n{\n");
++gen->indent;
@ -621,6 +774,8 @@ static void msl_generator_cleanup(struct msl_generator *gen)
}
static int msl_generator_init(struct msl_generator *gen, struct vsir_program *program,
const struct vkd3d_shader_compile_info *compile_info,
const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info,
struct vkd3d_shader_message_context *message_context)
{
enum vkd3d_shader_type type = program->shader_version.type;
@ -640,11 +795,14 @@ static int msl_generator_init(struct msl_generator *gen, struct vsir_program *pr
"Internal compiler error: Unhandled shader type %#x.", type);
return VKD3D_ERROR_INVALID_SHADER;
}
gen->interface_info = vkd3d_find_struct(compile_info->next, INTERFACE_INFO);
gen->descriptor_info = descriptor_info;
return VKD3D_OK;
}
int msl_compile(struct vsir_program *program, uint64_t config_flags,
const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info,
const struct vkd3d_shader_compile_info *compile_info, struct vkd3d_shader_message_context *message_context)
{
struct msl_generator generator;
@ -653,7 +811,7 @@ int msl_compile(struct vsir_program *program, uint64_t config_flags,
if ((ret = vsir_program_transform(program, config_flags, compile_info, message_context)) < 0)
return ret;
if ((ret = msl_generator_init(&generator, program, message_context)) < 0)
if ((ret = msl_generator_init(&generator, program, compile_info, descriptor_info, message_context)) < 0)
return ret;
msl_generator_generate(&generator);
msl_generator_cleanup(&generator);

View File

@ -1655,7 +1655,10 @@ int vsir_program_compile(struct vsir_program *program, uint64_t config_flags,
break;
case VKD3D_SHADER_TARGET_MSL:
ret = msl_compile(program, config_flags, compile_info, message_context);
if ((ret = vsir_program_scan(program, &scan_info, message_context, &scan_descriptor_info)) < 0)
return ret;
ret = msl_compile(program, config_flags, &scan_descriptor_info, compile_info, message_context);
vkd3d_shader_free_scan_descriptor_info1(&scan_descriptor_info);
break;
default:

View File

@ -250,6 +250,7 @@ enum vkd3d_shader_error
VKD3D_SHADER_WARNING_VSIR_DYNAMIC_DESCRIPTOR_ARRAY = 9300,
VKD3D_SHADER_ERROR_MSL_INTERNAL = 10000,
VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND = 10001,
};
enum vkd3d_shader_opcode
@ -1617,6 +1618,7 @@ int spirv_compile(struct vsir_program *program, uint64_t config_flags,
struct vkd3d_shader_code *out, struct vkd3d_shader_message_context *message_context);
int msl_compile(struct vsir_program *program, uint64_t config_flags,
const struct vkd3d_shader_scan_descriptor_info1 *descriptor_info,
const struct vkd3d_shader_compile_info *compile_info, struct vkd3d_shader_message_context *message_context);
enum vkd3d_md5_variant