vkd3d-shader/msl: Handle SV_VERTEX_ID inputs.

This commit is contained in:
Henri Verbeet
2025-05-19 00:40:37 +02:00
parent e948098ae3
commit cf312e14a9
Notes: Henri Verbeet 2025-06-05 16:18:57 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1537
4 changed files with 41 additions and 11 deletions

View File

@@ -48,6 +48,7 @@ struct msl_generator
const char *prefix;
bool failed;
bool read_vertex_id;
bool write_depth;
const struct vkd3d_shader_interface_info *interface_info;
@@ -1059,6 +1060,13 @@ static void msl_generate_input_struct_declarations(struct msl_generator *gen)
vkd3d_string_buffer_printf(buffer, "float4 position [[position]];\n");
continue;
case VKD3D_SHADER_SV_VERTEX_ID:
if (type != VKD3D_SHADER_TYPE_VERTEX)
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL,
"Internal compiler error: Unhandled SV_VERTEX_ID in shader type #%x.", type);
gen->read_vertex_id = true;
continue;
case VKD3D_SHADER_SV_IS_FRONT_FACE:
if (type != VKD3D_SHADER_TYPE_PIXEL)
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL,
@@ -1107,15 +1115,15 @@ static void msl_generate_input_struct_declarations(struct msl_generator *gen)
break;
}
vkd3d_string_buffer_printf(buffer, "shader_in_%u ", i);
vkd3d_string_buffer_printf(buffer, "shader_in_%u [[", i);
switch (type)
{
case VKD3D_SHADER_TYPE_VERTEX:
vkd3d_string_buffer_printf(gen->buffer, "[[attribute(%u)]]", e->target_location);
vkd3d_string_buffer_printf(gen->buffer, "attribute(%u)", e->target_location);
break;
case VKD3D_SHADER_TYPE_PIXEL:
vkd3d_string_buffer_printf(gen->buffer, "[[user(locn%u)]]", e->target_location);
vkd3d_string_buffer_printf(gen->buffer, "user(locn%u)", e->target_location);
break;
default:
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL,
@@ -1129,13 +1137,16 @@ static void msl_generate_input_struct_declarations(struct msl_generator *gen)
case VKD3DSIM_LINEAR:
case VKD3DSIM_NONE:
break;
case VKD3DSIM_CONSTANT:
vkd3d_string_buffer_printf(gen->buffer, ", flat");
break;
default:
msl_compiler_error(gen, VKD3D_SHADER_ERROR_MSL_INTERNAL,
"Internal compiler error: Unhandled interpolation mode %#x.", e->interpolation_mode);
break;
}
vkd3d_string_buffer_printf(buffer, ";\n");
vkd3d_string_buffer_printf(buffer, "]];\n");
}
vkd3d_string_buffer_printf(buffer, "};\n\n");
@@ -1289,6 +1300,12 @@ static void msl_generate_entrypoint_prologue(struct msl_generator *gen)
vkd3d_string_buffer_printf(buffer, " = float4(input.position.xyz, 1.0f / input.position.w)");
break;
case VKD3D_SHADER_SV_VERTEX_ID:
msl_print_register_datatype(buffer, gen, VKD3D_DATA_UINT);
msl_print_write_mask(buffer, e->mask);
vkd3d_string_buffer_printf(buffer, " = uint4(vertex_id, 0u, 0u, 0u)");
break;
case VKD3D_SHADER_SV_IS_FRONT_FACE:
msl_print_register_datatype(buffer, gen, VKD3D_DATA_UINT);
msl_print_write_mask(buffer, e->mask);
@@ -1374,6 +1391,12 @@ static void msl_generate_entrypoint(struct msl_generator *gen)
"constant descriptor *descriptors [[buffer(0)]],\n");
}
if (gen->read_vertex_id)
{
msl_print_indent(gen->buffer, 2);
vkd3d_string_buffer_printf(gen->buffer, "uint vertex_id [[vertex_id]],\n");
}
msl_print_indent(gen->buffer, 2);
vkd3d_string_buffer_printf(gen->buffer, "vkd3d_%s_in input [[stage_in]])\n{\n", gen->prefix);
@@ -1388,6 +1411,8 @@ static void msl_generate_entrypoint(struct msl_generator *gen)
msl_generate_entrypoint_prologue(gen);
vkd3d_string_buffer_printf(gen->buffer, " %s_main(%s_in, %s_out", gen->prefix, gen->prefix, gen->prefix);
if (gen->read_vertex_id)
vkd3d_string_buffer_printf(gen->buffer, ", vertex_id");
if (gen->write_depth)
vkd3d_string_buffer_printf(gen->buffer, ", shader_out_depth");
if (gen->program->descriptors.descriptor_count)
@@ -1449,6 +1474,9 @@ static int msl_generator_generate(struct msl_generator *gen, struct vkd3d_shader
"static void %s_main(thread vkd3d_vec4 *v, "
"thread vkd3d_vec4 *o",
gen->prefix);
if (gen->read_vertex_id)
vkd3d_string_buffer_printf(gen->buffer, ", uint vertex_id");
if (gen->write_depth)
vkd3d_string_buffer_printf(gen->buffer, ", thread float& o_depth");
if (gen->program->descriptors.descriptor_count)