From 1a20395e6fb303b4a81602abac0378f986dfeb0e Mon Sep 17 00:00:00 2001 From: Henri Verbeet Date: Wed, 28 May 2025 22:52:18 +0200 Subject: [PATCH] vkd3d-shader/msl: Implement VSIR_OP_STORE_UAV_TYPED. --- libs/vkd3d-shader/msl.c | 138 ++++++++++++++++++++++++++++ tests/hlsl/uav-rwbuffer.shader_test | 4 +- tests/shader_runner_metal.m | 45 +++++++-- 3 files changed, 178 insertions(+), 9 deletions(-) diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index ec4dbacd3..9f6127f46 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -299,6 +299,41 @@ static const struct vkd3d_shader_descriptor_binding *msl_get_srv_binding(const s return NULL; } +static const struct vkd3d_shader_descriptor_binding *msl_get_uav_binding(const struct msl_generator *gen, + unsigned int register_space, unsigned int register_idx, enum vkd3d_shader_resource_type resource_type) +{ + const struct vkd3d_shader_interface_info *interface_info = gen->interface_info; + const struct vkd3d_shader_resource_binding *binding; + enum vkd3d_shader_binding_flag resource_type_flag; + unsigned int i; + + if (!interface_info) + return NULL; + + resource_type_flag = resource_type == VKD3D_SHADER_RESOURCE_BUFFER + ? VKD3D_SHADER_BINDING_FLAG_BUFFER : VKD3D_SHADER_BINDING_FLAG_IMAGE; + + for (i = 0; i < interface_info->binding_count; ++i) + { + binding = &interface_info->bindings[i]; + + if (binding->type != VKD3D_SHADER_DESCRIPTOR_TYPE_UAV) + continue; + if (binding->register_space != register_space) + continue; + if (binding->register_index != register_idx) + continue; + if (!msl_check_shader_visibility(gen, binding->shader_visibility)) + continue; + if (!(binding->flags & resource_type_flag)) + continue; + + return &binding->binding; + } + + return NULL; +} + static void msl_print_cbv_name(struct vkd3d_string_buffer *buffer, unsigned int binding) { vkd3d_string_buffer_printf(buffer, "descriptors[%u].buf()", binding); @@ -319,6 +354,16 @@ static void msl_print_srv_name(struct vkd3d_string_buffer *buffer, struct msl_ge vkd3d_string_buffer_printf(buffer, ">>()"); } +static void msl_print_uav_name(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, unsigned int binding, + const struct msl_resource_type_info *resource_type_info, enum vkd3d_data_type resource_data_type) +{ + vkd3d_string_buffer_printf(buffer, "descriptors[%u].astype_suffix, + resource_type_info->array ? "_array" : ""); + msl_print_resource_datatype(gen, buffer, resource_data_type); + vkd3d_string_buffer_printf(buffer, ", access::read_write>>()"); +} + static enum msl_data_type msl_print_register_name(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, const struct vkd3d_shader_register *reg) { @@ -1204,6 +1249,96 @@ static void msl_sample(struct msl_generator *gen, const struct vkd3d_shader_inst msl_dst_cleanup(&dst, &gen->string_buffers); } +static void msl_store_uav_typed(struct msl_generator *gen, const struct vkd3d_shader_instruction *ins) +{ + const struct msl_resource_type_info *resource_type_info; + const struct vkd3d_shader_descriptor_binding *binding; + const struct vkd3d_shader_descriptor_info1 *d; + enum vkd3d_shader_resource_type resource_type; + unsigned int uav_id, uav_idx, uav_space; + struct vkd3d_string_buffer *image_data; + enum vkd3d_data_type data_type; + unsigned int uav_binding; + uint32_t coord_mask; + + if (ins->dst[0].reg.idx[0].rel_addr || ins->dst[0].reg.idx[1].rel_addr) + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_UNSUPPORTED, + "Descriptor indexing is not supported."); + + uav_id = ins->dst[0].reg.idx[0].offset; + uav_idx = ins->dst[0].reg.idx[1].offset; + if ((d = vkd3d_shader_find_descriptor(&gen->program->descriptors, + VKD3D_SHADER_DESCRIPTOR_TYPE_UAV, uav_id))) + { + uav_space = d->register_space; + resource_type = d->resource_type; + data_type = d->resource_data_type; + } + else + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Undeclared UAV descriptor %u.", uav_id); + uav_space = 0; + resource_type = VKD3D_SHADER_RESOURCE_TEXTURE_2D; + data_type = VKD3D_DATA_FLOAT; + } + + if (!(resource_type_info = msl_get_resource_type_info(resource_type))) + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Unhandled resource type %#x.", resource_type); + resource_type_info = msl_get_resource_type_info(VKD3D_SHADER_RESOURCE_TEXTURE_2D); + } + coord_mask = vkd3d_write_mask_from_component_count(resource_type_info->coord_size); + + if ((binding = msl_get_uav_binding(gen, uav_space, uav_idx, resource_type))) + { + uav_binding = binding->binding; + } + else + { + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, + "No descriptor binding specified for UAV %u (index %u, space %u).", + uav_id, uav_idx, uav_space); + uav_binding = 0; + } + + image_data = vkd3d_string_buffer_get(&gen->string_buffers); + + if (ins->src[1].reg.dimension == VSIR_DIMENSION_SCALAR) + { + switch (data_type) + { + case VKD3D_DATA_UINT: + vkd3d_string_buffer_printf(image_data, "uint4("); + break; + case VKD3D_DATA_INT: + vkd3d_string_buffer_printf(image_data, "int4("); + break; + default: + msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, + "Internal compiler error: Unhandled data type %#x.", data_type); + /* fall through */ + case VKD3D_DATA_FLOAT: + case VKD3D_DATA_UNORM: + case VKD3D_DATA_SNORM: + vkd3d_string_buffer_printf(image_data, "float4("); + break; + } + } + msl_print_src_with_type(image_data, gen, &ins->src[1], VKD3DSP_WRITEMASK_ALL, data_type); + if (ins->src[1].reg.dimension == VSIR_DIMENSION_SCALAR) + vkd3d_string_buffer_printf(image_data, ", 0, 0, 0)"); + + msl_print_indent(gen->buffer, gen->indent); + msl_print_uav_name(gen->buffer, gen, uav_binding, resource_type_info, data_type); + vkd3d_string_buffer_printf(gen->buffer, ".write(%s, ", image_data->buffer); + msl_print_src_with_type(gen->buffer, gen, &ins->src[0], coord_mask, VKD3D_DATA_UINT); + vkd3d_string_buffer_printf(gen->buffer, ");\n"); + + vkd3d_string_buffer_release(&gen->string_buffers, image_data); +} + static void msl_unary_op(struct msl_generator *gen, const struct vkd3d_shader_instruction *ins, const char *op) { struct msl_src src; @@ -1458,6 +1593,9 @@ static void msl_handle_instruction(struct msl_generator *gen, const struct vkd3d case VSIR_OP_SQRT: msl_intrinsic(gen, ins, "sqrt"); break; + case VSIR_OP_STORE_UAV_TYPED: + msl_store_uav_typed(gen, ins); + break; case VSIR_OP_SWITCH: msl_switch(gen, ins); break; diff --git a/tests/hlsl/uav-rwbuffer.shader_test b/tests/hlsl/uav-rwbuffer.shader_test index dc5d0d703..0b8cd804a 100644 --- a/tests/hlsl/uav-rwbuffer.shader_test +++ b/tests/hlsl/uav-rwbuffer.shader_test @@ -184,7 +184,7 @@ float4 main() : sv_target } [test] -todo(msl) draw quad +draw quad probe uav 1 (0) i32(11, -12, 13, -14) probe uav 1 (1) i32(-15, 16, -17, 18) @@ -203,7 +203,7 @@ float4 main() : sv_target1 } [test] -todo(msl) draw quad +draw quad probe uav 2 (0) f32(11.1, 12.2, 13.3, 14.4) [require] diff --git a/tests/shader_runner_metal.m b/tests/shader_runner_metal.m index b3a20f2aa..751d50243 100644 --- a/tests/shader_runner_metal.m +++ b/tests/shader_runner_metal.m @@ -293,6 +293,9 @@ static void init_resource_texture(struct metal_runner *runner, break; case RESOURCE_TYPE_UAV: + desc.usage = MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite; + break; + case RESOURCE_TYPE_VERTEX_BUFFER: break; } @@ -357,15 +360,13 @@ static struct resource *metal_runner_create_resource(struct shader_runner *r, co case RESOURCE_TYPE_RENDER_TARGET: case RESOURCE_TYPE_DEPTH_STENCIL: case RESOURCE_TYPE_TEXTURE: + case RESOURCE_TYPE_UAV: init_resource_texture(runner, resource, params); break; case RESOURCE_TYPE_VERTEX_BUFFER: init_resource_buffer(runner, resource, params); break; - - case RESOURCE_TYPE_UAV: - break; } return &resource->r; @@ -444,9 +445,24 @@ static bool compile_shader(struct metal_runner *runner, enum shader_type type, s ++interface_info.binding_count; break; + case RESOURCE_TYPE_UAV: + binding = &bindings[interface_info.binding_count]; + binding->type = VKD3D_SHADER_DESCRIPTOR_TYPE_UAV; + binding->register_space = 0; + binding->register_index = resource->r.desc.slot; + binding->shader_visibility = VKD3D_SHADER_VISIBILITY_ALL; + if (resource->r.desc.dimension == RESOURCE_DIMENSION_BUFFER) + binding->flags = VKD3D_SHADER_BINDING_FLAG_BUFFER; + else + binding->flags = VKD3D_SHADER_BINDING_FLAG_IMAGE; + binding->binding.set = 0; + binding->binding.binding = interface_info.binding_count; + binding->binding.count = 1; + ++interface_info.binding_count; + break; + case RESOURCE_TYPE_RENDER_TARGET: case RESOURCE_TYPE_DEPTH_STENCIL: - case RESOURCE_TYPE_UAV: case RESOURCE_TYPE_VERTEX_BUFFER: break; @@ -538,9 +554,17 @@ static bool encode_argument_buffer(struct metal_runner *runner, [argument_descriptors addObject:arg_desc]; break; + case RESOURCE_TYPE_UAV: + arg_desc = [MTLArgumentDescriptor argumentDescriptor]; + arg_desc.dataType = MTLDataTypeTexture; + arg_desc.index = [argument_descriptors count]; + arg_desc.access = MTLBindingAccessReadWrite; + arg_desc.textureType = [resource->texture textureType]; + [argument_descriptors addObject:arg_desc]; + break; + case RESOURCE_TYPE_RENDER_TARGET: case RESOURCE_TYPE_DEPTH_STENCIL: - case RESOURCE_TYPE_UAV: case RESOURCE_TYPE_VERTEX_BUFFER: break; } @@ -589,9 +613,15 @@ static bool encode_argument_buffer(struct metal_runner *runner, stages:MTLRenderStageVertex | MTLRenderStageFragment]; break; + case RESOURCE_TYPE_UAV: + [encoder setTexture:resource->texture atIndex:index++]; + [command_encoder useResource:resource->texture + usage:MTLResourceUsageRead | MTLResourceUsageWrite + stages:MTLRenderStageVertex | MTLRenderStageFragment]; + break; + case RESOURCE_TYPE_RENDER_TARGET: case RESOURCE_TYPE_DEPTH_STENCIL: - case RESOURCE_TYPE_UAV: case RESOURCE_TYPE_VERTEX_BUFFER: break; } @@ -908,7 +938,8 @@ static struct resource_readback *metal_runner_get_resource_readback(struct shade id src_texture; unsigned int layer, level; - if (resource->r.desc.dimension != RESOURCE_DIMENSION_2D) + if (resource->r.desc.dimension != RESOURCE_DIMENSION_BUFFER + && resource->r.desc.dimension != RESOURCE_DIMENSION_2D) fatal_error("Unhandled resource dimension %#x.\n", resource->r.desc.dimension); rb = malloc(sizeof(*rb));