diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index d477bfa1c..7bb8307b1 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -106,6 +106,68 @@ static void msl_print_register_datatype(struct vkd3d_string_buffer *buffer, } } +static bool msl_check_shader_visibility(const struct msl_generator *gen, + enum vkd3d_shader_visibility visibility) +{ + enum vkd3d_shader_type t = gen->program->shader_version.type; + + switch (visibility) + { + case VKD3D_SHADER_VISIBILITY_ALL: + return true; + case VKD3D_SHADER_VISIBILITY_VERTEX: + return t == VKD3D_SHADER_TYPE_VERTEX; + case VKD3D_SHADER_VISIBILITY_HULL: + return t == VKD3D_SHADER_TYPE_HULL; + case VKD3D_SHADER_VISIBILITY_DOMAIN: + return t == VKD3D_SHADER_TYPE_DOMAIN; + case VKD3D_SHADER_VISIBILITY_GEOMETRY: + return t == VKD3D_SHADER_TYPE_GEOMETRY; + case VKD3D_SHADER_VISIBILITY_PIXEL: + return t == VKD3D_SHADER_TYPE_PIXEL; + case VKD3D_SHADER_VISIBILITY_COMPUTE: + return t == VKD3D_SHADER_TYPE_COMPUTE; + default: + WARN("Invalid shader visibility %#x.\n", visibility); + return false; + } +} + +static const struct vkd3d_shader_descriptor_binding *msl_get_cbv_binding(const struct msl_generator *gen, + unsigned int register_space, unsigned int register_idx) +{ + const struct vkd3d_shader_interface_info *interface_info = gen->interface_info; + unsigned int i; + + if (!interface_info) + return NULL; + + for (i = 0; i < interface_info->binding_count; ++i) + { + const struct vkd3d_shader_resource_binding *binding = &interface_info->bindings[i]; + + if (binding->type != VKD3D_SHADER_DESCRIPTOR_TYPE_CBV) + continue; + if (binding->register_space != register_space) + continue; + if (binding->register_index != register_idx) + continue; + if (!msl_check_shader_visibility(gen, binding->shader_visibility)) + continue; + if (!(binding->flags & VKD3D_SHADER_BINDING_FLAG_BUFFER)) + continue; + + return &binding->binding; + } + + return NULL; +} + +static void msl_print_cbv_name(struct vkd3d_string_buffer *buffer, unsigned int binding) +{ + vkd3d_string_buffer_printf(buffer, "descriptors[%u].buf()", binding); +} + static void msl_print_register_name(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, const struct vkd3d_shader_register *reg) { @@ -220,23 +282,36 @@ static void msl_print_register_name(struct vkd3d_string_buffer *buffer, break; case VKD3DSPR_CONSTBUFFER: - if (reg->idx_count != 3) { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, - "Internal compiler error: Unhandled constant buffer register index count %u.", reg->idx_count); - vkd3d_string_buffer_printf(buffer, "", reg->type); + const struct vkd3d_shader_descriptor_binding *binding; + + if (reg->idx_count != 3) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Unhandled constant buffer register index count %u.", + reg->idx_count); + vkd3d_string_buffer_printf(buffer, "", reg->type); + break; + } + if (reg->idx[0].rel_addr || reg->idx[1].rel_addr || reg->idx[2].rel_addr) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Unhandled constant buffer register indirect addressing."); + vkd3d_string_buffer_printf(buffer, "", reg->type); + break; + } + if (!(binding = msl_get_cbv_binding(gen, 0, reg->idx[1].offset))) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "Cannot finding binding for CBV register %u.", reg->idx[0].offset); + vkd3d_string_buffer_printf(buffer, "", reg->type); + break; + } + msl_print_cbv_name(buffer, binding->binding); + vkd3d_string_buffer_printf(buffer, "[%u]", reg->idx[2].offset); + msl_print_register_datatype(buffer, gen, reg->data_type); break; } - if (reg->idx[0].rel_addr || reg->idx[2].rel_addr) - { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, - "Internal compiler error: Unhandled constant buffer register indirect addressing."); - vkd3d_string_buffer_printf(buffer, "", reg->type); - break; - } - vkd3d_string_buffer_printf(buffer, "descriptors.cb_%u[%u]", reg->idx[0].offset, reg->idx[2].offset); - msl_print_register_datatype(buffer, gen, reg->data_type); - break; default: msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, @@ -715,144 +790,6 @@ static void msl_handle_instruction(struct msl_generator *gen, const struct vkd3d } } -static bool msl_check_shader_visibility(const struct msl_generator *gen, - enum vkd3d_shader_visibility visibility) -{ - enum vkd3d_shader_type t = gen->program->shader_version.type; - - switch (visibility) - { - case VKD3D_SHADER_VISIBILITY_ALL: - return true; - case VKD3D_SHADER_VISIBILITY_VERTEX: - return t == VKD3D_SHADER_TYPE_VERTEX; - case VKD3D_SHADER_VISIBILITY_HULL: - return t == VKD3D_SHADER_TYPE_HULL; - case VKD3D_SHADER_VISIBILITY_DOMAIN: - return t == VKD3D_SHADER_TYPE_DOMAIN; - case VKD3D_SHADER_VISIBILITY_GEOMETRY: - return t == VKD3D_SHADER_TYPE_GEOMETRY; - case VKD3D_SHADER_VISIBILITY_PIXEL: - return t == VKD3D_SHADER_TYPE_PIXEL; - case VKD3D_SHADER_VISIBILITY_COMPUTE: - return t == VKD3D_SHADER_TYPE_COMPUTE; - default: - WARN("Invalid shader visibility %#x.\n", visibility); - return false; - } -} - -static bool msl_get_cbv_binding(const struct msl_generator *gen, - unsigned int register_space, unsigned int register_idx, unsigned int *binding_idx) -{ - const struct vkd3d_shader_interface_info *interface_info = gen->interface_info; - const struct vkd3d_shader_resource_binding *binding; - unsigned int i; - - if (!interface_info) - return false; - - for (i = 0; i < interface_info->binding_count; ++i) - { - binding = &interface_info->bindings[i]; - - if (binding->type != VKD3D_SHADER_DESCRIPTOR_TYPE_CBV) - continue; - if (binding->register_space != register_space) - continue; - if (binding->register_index != register_idx) - continue; - if (!msl_check_shader_visibility(gen, binding->shader_visibility)) - continue; - if (!(binding->flags & VKD3D_SHADER_BINDING_FLAG_BUFFER)) - continue; - *binding_idx = i; - return true; - } - - return false; -} - -static void msl_generate_cbv_declaration(struct msl_generator *gen, - const struct vkd3d_shader_descriptor_info1 *cbv) -{ - const struct vkd3d_shader_descriptor_binding *binding; - struct vkd3d_string_buffer *buffer = gen->buffer; - unsigned int binding_idx; - size_t size; - - if (cbv->count != 1) - { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, - "Constant buffer %u has unsupported descriptor array size %u.", cbv->register_id, cbv->count); - return; - } - - if (!msl_get_cbv_binding(gen, cbv->register_space, cbv->register_index, &binding_idx)) - { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, - "No descriptor binding specified for constant buffer %u.", cbv->register_id); - return; - } - - binding = &gen->interface_info->bindings[binding_idx].binding; - - if (binding->set != 0) - { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, - "Unsupported binding set %u specified for constant buffer %u.", binding->set, cbv->register_id); - return; - } - - if (binding->count != 1) - { - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, - "Unsupported binding count %u specified for constant buffer %u.", binding->count, cbv->register_id); - return; - } - - size = align(cbv->buffer_size, VKD3D_VEC4_SIZE * sizeof(uint32_t)); - size /= VKD3D_VEC4_SIZE * sizeof(uint32_t); - - vkd3d_string_buffer_printf(buffer, - "constant vkd3d_vec4 *cb_%u [[id(%u)]];", cbv->register_id, binding->binding); -}; - -static void msl_generate_descriptor_struct_declarations(struct msl_generator *gen) -{ - const struct vkd3d_shader_scan_descriptor_info1 *info = &gen->program->descriptors; - const struct vkd3d_shader_descriptor_info1 *descriptor; - struct vkd3d_string_buffer *buffer = gen->buffer; - unsigned int i; - - if (!info->descriptor_count) - return; - - vkd3d_string_buffer_printf(buffer, "struct vkd3d_%s_descriptors\n{\n", gen->prefix); - - for (i = 0; i < info->descriptor_count; ++i) - { - descriptor = &info->descriptors[i]; - - msl_print_indent(buffer, 1); - switch (descriptor->type) - { - case VKD3D_SHADER_DESCRIPTOR_TYPE_CBV: - msl_generate_cbv_declaration(gen, descriptor); - break; - - default: - vkd3d_string_buffer_printf(buffer, "/* */", descriptor->type); - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, - "Internal compiler error: Unhandled descriptor type %#x.", descriptor->type); - break; - } - vkd3d_string_buffer_printf(buffer, "\n"); - } - - vkd3d_string_buffer_printf(buffer, "};\n\n"); -} - static void msl_generate_input_struct_declarations(struct msl_generator *gen) { const struct shader_signature *signature = &gen->program->input_signature; @@ -1175,7 +1112,7 @@ static void msl_generate_entrypoint(struct msl_generator *gen) msl_print_indent(gen->buffer, 2); /* TODO: Configurable argument buffer binding location. */ vkd3d_string_buffer_printf(gen->buffer, - "constant vkd3d_%s_descriptors& descriptors [[buffer(0)]],\n", gen->prefix); + "constant descriptor *descriptors [[buffer(0)]],\n"); } msl_print_indent(gen->buffer, 2); @@ -1223,7 +1160,22 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader vkd3d_string_buffer_printf(gen->buffer, " int4 i;\n"); vkd3d_string_buffer_printf(gen->buffer, " float4 f;\n};\n\n"); - msl_generate_descriptor_struct_declarations(gen); + if (gen->program->descriptors.descriptor_count > 0) + { + vkd3d_string_buffer_printf(gen->buffer, + "struct descriptor\n" + "{\n" + " const device void *ptr;\n" + "\n" + " template\n" + " const device T * constant &buf() constant\n" + " {\n" + " return reinterpret_cast(this->ptr);\n" + " }\n" + "};\n" + "\n"); + } + msl_generate_input_struct_declarations(gen); msl_generate_output_struct_declarations(gen); @@ -1234,7 +1186,7 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader if (gen->write_depth) vkd3d_string_buffer_printf(gen->buffer, ", thread float& o_depth"); if (gen->program->descriptors.descriptor_count) - vkd3d_string_buffer_printf(gen->buffer, ", constant vkd3d_%s_descriptors& descriptors", gen->prefix); + vkd3d_string_buffer_printf(gen->buffer, ", constant descriptor *descriptors"); vkd3d_string_buffer_printf(gen->buffer, ")\n{\n"); ++gen->indent;