diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index f28c6bd3d..d6fcbe304 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -18,6 +18,13 @@ #include "vkd3d_shader_private.h" +enum msl_data_type +{ + MSL_DATA_FLOAT, + MSL_DATA_UINT, + MSL_DATA_UNION, +}; + struct msl_src { struct vkd3d_string_buffer *str; @@ -267,15 +274,14 @@ static void msl_print_srv_name(struct vkd3d_string_buffer *buffer, struct msl_ge vkd3d_string_buffer_printf(buffer, ">>()"); } -static void msl_print_register_name(struct vkd3d_string_buffer *buffer, +static enum msl_data_type msl_print_register_name(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, const struct vkd3d_shader_register *reg) { switch (reg->type) { case VKD3DSPR_TEMP: vkd3d_string_buffer_printf(buffer, "r[%u]", reg->idx[0].offset); - msl_print_register_datatype(buffer, gen, reg->data_type); - break; + return MSL_DATA_UNION; case VKD3DSPR_INPUT: if (reg->idx_count != 1) @@ -283,18 +289,17 @@ static void msl_print_register_name(struct vkd3d_string_buffer *buffer, msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled input register index count %u.", reg->idx_count); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } if (reg->idx[0].rel_addr) { msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled input register indirect addressing."); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } vkd3d_string_buffer_printf(buffer, "v[%u]", reg->idx[0].offset); - msl_print_register_datatype(buffer, gen, reg->data_type); - break; + return MSL_DATA_UNION; case VKD3DSPR_OUTPUT: if (reg->idx_count != 1) @@ -302,18 +307,17 @@ static void msl_print_register_name(struct vkd3d_string_buffer *buffer, msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled output register index count %u.", reg->idx_count); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } if (reg->idx[0].rel_addr) { msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled output register indirect addressing."); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } vkd3d_string_buffer_printf(buffer, "o[%u]", reg->idx[0].offset); - msl_print_register_datatype(buffer, gen, reg->data_type); - break; + return MSL_DATA_UNION; case VKD3DSPR_DEPTHOUT: if (gen->program->shader_version.type != VKD3D_SHADER_TYPE_PIXEL) @@ -321,64 +325,27 @@ static void msl_print_register_name(struct vkd3d_string_buffer *buffer, "Internal compiler error: Unhandled depth output in shader type #%x.", gen->program->shader_version.type); vkd3d_string_buffer_printf(buffer, "o_depth"); - break; + return MSL_DATA_FLOAT; case VKD3DSPR_IMMCONST: switch (reg->dimension) { case VSIR_DIMENSION_SCALAR: - switch (reg->data_type) - { - case VKD3D_DATA_INT: - vkd3d_string_buffer_printf(buffer, "as_type(%#xu)", reg->u.immconst_u32[0]); - break; - case VKD3D_DATA_UINT: - vkd3d_string_buffer_printf(buffer, "%#xu", reg->u.immconst_u32[0]); - break; - case VKD3D_DATA_FLOAT: - vkd3d_string_buffer_printf(buffer, "as_type(%#xu)", reg->u.immconst_u32[0]); - break; - default: - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, - "Internal compiler error: Unhandled immconst datatype %#x.", reg->data_type); - vkd3d_string_buffer_printf(buffer, "", reg->data_type); - break; - } - break; + vkd3d_string_buffer_printf(buffer, "%#xu", reg->u.immconst_u32[0]); + return MSL_DATA_UINT; case VSIR_DIMENSION_VEC4: - switch (reg->data_type) - { - case VKD3D_DATA_INT: - vkd3d_string_buffer_printf(buffer, "as_type(uint4(%#xu, %#xu, %#xu, %#xu))", - reg->u.immconst_u32[0], reg->u.immconst_u32[1], - reg->u.immconst_u32[2], reg->u.immconst_u32[3]); - break; - case VKD3D_DATA_UINT: - vkd3d_string_buffer_printf(buffer, "uint4(%#xu, %#xu, %#xu, %#xu)", - reg->u.immconst_u32[0], reg->u.immconst_u32[1], - reg->u.immconst_u32[2], reg->u.immconst_u32[3]); - break; - case VKD3D_DATA_FLOAT: - vkd3d_string_buffer_printf(buffer, "as_type(uint4(%#xu, %#xu, %#xu, %#xu))", - reg->u.immconst_u32[0], reg->u.immconst_u32[1], - reg->u.immconst_u32[2], reg->u.immconst_u32[3]); - break; - default: - msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, - "Internal compiler error: Unhandled immconst datatype %#x.", reg->data_type); - vkd3d_string_buffer_printf(buffer, "", reg->data_type); - break; - } - break; + vkd3d_string_buffer_printf(buffer, "uint4(%#xu, %#xu, %#xu, %#xu)", + reg->u.immconst_u32[0], reg->u.immconst_u32[1], + reg->u.immconst_u32[2], reg->u.immconst_u32[3]); + return MSL_DATA_UINT; default: vkd3d_string_buffer_printf(buffer, "", reg->dimension); msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled dimension %#x.", reg->dimension); - break; + return MSL_DATA_UINT; } - break; case VKD3DSPR_CONSTBUFFER: { @@ -390,33 +357,32 @@ static void msl_print_register_name(struct vkd3d_string_buffer *buffer, "Internal compiler error: Unhandled constant buffer register index count %u.", reg->idx_count); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } if (reg->idx[0].rel_addr || reg->idx[1].rel_addr || reg->idx[2].rel_addr) { msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled constant buffer register indirect addressing."); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } if (!(binding = msl_get_cbv_binding(gen, 0, reg->idx[1].offset))) { msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_BINDING_NOT_FOUND, "Cannot finding binding for CBV register %u.", reg->idx[0].offset); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UNION; } msl_print_cbv_name(buffer, binding->binding); vkd3d_string_buffer_printf(buffer, "[%u]", reg->idx[2].offset); - msl_print_register_datatype(buffer, gen, reg->data_type); - break; + return MSL_DATA_UNION; } default: msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, "Internal compiler error: Unhandled register type %#x.", reg->type); vkd3d_string_buffer_printf(buffer, "", reg->type); - break; + return MSL_DATA_UINT; } } @@ -451,11 +417,52 @@ static void msl_src_cleanup(struct msl_src *src, struct vkd3d_string_buffer_cach vkd3d_string_buffer_release(cache, src->str); } -static void msl_print_src(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, - const struct vkd3d_shader_src_param *vsir_src, uint32_t mask) +static void msl_print_bitcast(struct vkd3d_string_buffer *dst, struct msl_generator *gen, const char *src, + enum vkd3d_data_type dst_data_type, enum msl_data_type src_data_type, enum vsir_dimension dimension) +{ + bool write_cast = false; + + if (dst_data_type == VKD3D_DATA_UNORM || dst_data_type == VKD3D_DATA_SNORM) + dst_data_type = VKD3D_DATA_FLOAT; + + switch (src_data_type) + { + case MSL_DATA_FLOAT: + write_cast = dst_data_type != VKD3D_DATA_FLOAT; + break; + + case MSL_DATA_UINT: + write_cast = dst_data_type != VKD3D_DATA_UINT; + break; + + case MSL_DATA_UNION: + break; + } + + if (write_cast) + { + vkd3d_string_buffer_printf(dst, "as_type<"); + msl_print_resource_datatype(gen, dst, dst_data_type); + vkd3d_string_buffer_printf(dst, "%s>(", dimension == VSIR_DIMENSION_VEC4 ? "4" : ""); + } + + vkd3d_string_buffer_printf(dst, "%s", src); + + if (write_cast) + vkd3d_string_buffer_printf(dst, ")"); + + if (src_data_type == MSL_DATA_UNION) + msl_print_register_datatype(dst, gen, dst_data_type); +} + +static void msl_print_src_with_type(struct vkd3d_string_buffer *buffer, struct msl_generator *gen, + const struct vkd3d_shader_src_param *vsir_src, uint32_t mask, enum vkd3d_data_type data_type) { const struct vkd3d_shader_register *reg = &vsir_src->reg; - struct vkd3d_string_buffer *str; + struct vkd3d_string_buffer *register_name, *str; + enum msl_data_type src_data_type; + + register_name = vkd3d_string_buffer_get(&gen->string_buffers); if (reg->non_uniform) msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, @@ -466,7 +473,8 @@ static void msl_print_src(struct vkd3d_string_buffer *buffer, struct msl_generat else str = vkd3d_string_buffer_get(&gen->string_buffers); - msl_print_register_name(str, gen, reg); + src_data_type = msl_print_register_name(register_name, gen, reg); + msl_print_bitcast(str, gen, register_name->buffer, data_type, src_data_type, reg->dimension); if (reg->dimension == VSIR_DIMENSION_VEC4) msl_print_swizzle(str, vsir_src->swizzle, mask); @@ -496,7 +504,7 @@ static void msl_src_init(struct msl_src *msl_src, struct msl_generator *gen, const struct vkd3d_shader_src_param *vsir_src, uint32_t mask) { msl_src->str = vkd3d_string_buffer_get(&gen->string_buffers); - msl_print_src(msl_src->str, gen, vsir_src, mask); + msl_print_src_with_type(msl_src->str, gen, vsir_src, mask, vsir_src->reg.data_type); } static void msl_dst_cleanup(struct msl_dst *dst, struct vkd3d_string_buffer_cache *cache) @@ -509,6 +517,7 @@ static uint32_t msl_dst_init(struct msl_dst *msl_dst, struct msl_generator *gen, const struct vkd3d_shader_instruction *ins, const struct vkd3d_shader_dst_param *vsir_dst) { uint32_t write_mask = vsir_dst->write_mask; + enum msl_data_type dst_data_type; if (ins->flags & VKD3DSI_PRECISE_XYZW) msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, @@ -521,7 +530,9 @@ static uint32_t msl_dst_init(struct msl_dst *msl_dst, struct msl_generator *gen, msl_dst->register_name = vkd3d_string_buffer_get(&gen->string_buffers); msl_dst->mask = vkd3d_string_buffer_get(&gen->string_buffers); - msl_print_register_name(msl_dst->register_name, gen, &vsir_dst->reg); + dst_data_type = msl_print_register_name(msl_dst->register_name, gen, &vsir_dst->reg); + if (dst_data_type == MSL_DATA_UNION) + msl_print_register_datatype(msl_dst->mask, gen, vsir_dst->reg.data_type); if (vsir_dst->reg.dimension == VSIR_DIMENSION_VEC4) msl_print_write_mask(msl_dst->mask, write_mask);