diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 67701ccf..ae0b1958 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -4726,6 +4726,15 @@ static void vkd3d_dxbc_compiler_enter_shader_phase(struct vkd3d_dxbc_compiler *c vkd3d_spirv_build_op_name(builder, phase->function_id, "%s%u", name, id); } +static const struct vkd3d_shader_phase *vkd3d_dxbc_compiler_get_current_shader_phase( + const struct vkd3d_dxbc_compiler *compiler) +{ + if (!compiler->shader_phase_count) + return NULL; + + return &compiler->shader_phases[compiler->shader_phase_count - 1]; +} + static void vkd3d_dxbc_compiler_emit_hull_shader_main(struct vkd3d_dxbc_compiler *compiler) { struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; @@ -5392,9 +5401,13 @@ static void vkd3d_dxbc_compiler_emit_return(struct vkd3d_dxbc_compiler *compiler const struct vkd3d_shader_instruction *instruction) { struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_phase *phase; - if (compiler->shader_type != VKD3D_SHADER_TYPE_GEOMETRY) + if (compiler->shader_type != VKD3D_SHADER_TYPE_GEOMETRY + && (!(phase = vkd3d_dxbc_compiler_get_current_shader_phase(compiler)) + || phase->type == VKD3DSIH_HS_CONTROL_POINT_PHASE)) vkd3d_dxbc_compiler_emit_shader_epilogue_invocation(compiler); + vkd3d_spirv_build_op_return(builder); }