diff --git a/libs/vkd3d-shader/msl.c b/libs/vkd3d-shader/msl.c index f113721eb..cf01772ef 100644 --- a/libs/vkd3d-shader/msl.c +++ b/libs/vkd3d-shader/msl.c @@ -481,7 +481,7 @@ static enum msl_data_type msl_print_register_name(struct vkd3d_string_buffer *bu "Internal compiler error: Unhandled sample coverage mask in shader type #%x.", gen->program->shader_version.type); vkd3d_string_buffer_printf(buffer, "o_mask"); - return MSL_DATA_FLOAT; + return MSL_DATA_UNION; default: msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL, @@ -1952,7 +1952,7 @@ static void msl_generate_entrypoint_epilogue(struct msl_generator *gen) } if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) - vkd3d_string_buffer_printf(gen->buffer, " output.shader_out_mask = as_type(o_mask);\n"); + vkd3d_string_buffer_printf(gen->buffer, " output.shader_out_mask = o_mask.u;\n"); } static void msl_generate_entrypoint(struct msl_generator *gen) @@ -1999,7 +1999,7 @@ static void msl_generate_entrypoint(struct msl_generator *gen) vkd3d_string_buffer_printf(gen->buffer, " vkd3d_vec4 %s_out[%u];\n", gen->prefix, 32); vkd3d_string_buffer_printf(gen->buffer, " vkd3d_%s_out output;\n", gen->prefix); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) - vkd3d_string_buffer_printf(gen->buffer, " float o_mask;\n"); + vkd3d_string_buffer_printf(gen->buffer, " vkd3d_scalar o_mask;\n"); vkd3d_string_buffer_printf(gen->buffer, "\n"); msl_generate_entrypoint_prologue(gen); @@ -2081,7 +2081,7 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_DEPTHOUT)) vkd3d_string_buffer_printf(gen->buffer, ", thread float &o_depth"); if (bitmap_is_set(gen->program->io_dcls, VKD3DSPR_SAMPLEMASK)) - vkd3d_string_buffer_printf(gen->buffer, ", thread float &o_mask"); + vkd3d_string_buffer_printf(gen->buffer, ", thread vkd3d_scalar &o_mask"); if (gen->program->descriptors.descriptor_count) vkd3d_string_buffer_printf(gen->buffer, ", constant descriptor *descriptors"); vkd3d_string_buffer_printf(gen->buffer, ")\n{\n");