diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index 013667e2..541a45ac 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -2169,6 +2169,118 @@ static bool remove_trivial_conditional_branches(struct hlsl_ctx *ctx, struct hls return true; } +static bool normalize_switch_cases(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) +{ + struct hlsl_ir_switch_case *c, *def = NULL; + bool missing_terminal_break = false; + struct hlsl_ir_node *node; + struct hlsl_ir_jump *jump; + struct hlsl_ir_switch *s; + + if (instr->type != HLSL_IR_SWITCH) + return false; + s = hlsl_ir_switch(instr); + + LIST_FOR_EACH_ENTRY(c, &s->cases, struct hlsl_ir_switch_case, entry) + { + bool terminal_break = false; + + if (list_empty(&c->body.instrs)) + { + terminal_break = !!list_next(&s->cases, &c->entry); + } + else + { + node = LIST_ENTRY(list_tail(&c->body.instrs), struct hlsl_ir_node, entry); + if (node->type == HLSL_IR_JUMP) + { + jump = hlsl_ir_jump(node); + terminal_break = jump->type == HLSL_IR_JUMP_BREAK; + } + } + + missing_terminal_break |= !terminal_break; + + if (!terminal_break) + { + if (c->is_default) + { + hlsl_error(ctx, &c->loc, VKD3D_SHADER_ERROR_HLSL_INVALID_SYNTAX, + "The 'default' case block is not terminated with 'break' or 'return'."); + } + else + { + hlsl_error(ctx, &c->loc, VKD3D_SHADER_ERROR_HLSL_INVALID_SYNTAX, + "Switch case block '%u' is not terminated with 'break' or 'return'.", c->value); + } + } + } + + if (missing_terminal_break) + return true; + + LIST_FOR_EACH_ENTRY(c, &s->cases, struct hlsl_ir_switch_case, entry) + { + if (c->is_default) + { + def = c; + + /* Remove preceding empty cases. */ + while (list_prev(&s->cases, &def->entry)) + { + c = LIST_ENTRY(list_prev(&s->cases, &def->entry), struct hlsl_ir_switch_case, entry); + if (!list_empty(&c->body.instrs)) + break; + hlsl_free_ir_switch_case(c); + } + + if (list_empty(&def->body.instrs)) + { + /* Remove following empty cases. */ + while (list_next(&s->cases, &def->entry)) + { + c = LIST_ENTRY(list_next(&s->cases, &def->entry), struct hlsl_ir_switch_case, entry); + if (!list_empty(&c->body.instrs)) + break; + hlsl_free_ir_switch_case(c); + } + + /* Merge with the next case. */ + if (list_next(&s->cases, &def->entry)) + { + c = LIST_ENTRY(list_next(&s->cases, &def->entry), struct hlsl_ir_switch_case, entry); + c->is_default = true; + hlsl_free_ir_switch_case(def); + def = c; + } + } + + break; + } + } + + if (def) + { + list_remove(&def->entry); + } + else + { + struct hlsl_ir_node *jump; + + if (!(def = hlsl_new_switch_case(ctx, 0, true, NULL, &s->node.loc))) + return true; + if (!(jump = hlsl_new_jump(ctx, HLSL_IR_JUMP_BREAK, NULL, &s->node.loc))) + { + hlsl_free_ir_switch_case(def); + return true; + } + hlsl_block_add_instr(&def->body, jump); + } + list_add_tail(&s->cases, &def->entry); + + return true; +} + static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block) { struct hlsl_ir_node *idx; @@ -4629,6 +4741,7 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry } while (progress); remove_unreachable_code(ctx, body); + hlsl_transform_ir(ctx, normalize_switch_cases, body, NULL); lower_ir(ctx, lower_nonconstant_vector_derefs, body); lower_ir(ctx, lower_casts_to_bool, body); diff --git a/tests/hlsl/switch.shader_test b/tests/hlsl/switch.shader_test index 720672a7..aab12485 100644 --- a/tests/hlsl/switch.shader_test +++ b/tests/hlsl/switch.shader_test @@ -211,7 +211,7 @@ float4 main() : sv_target } % unterminated cases -[pixel shader fail(sm<6) todo] +[pixel shader fail(sm<6)] uint4 v; float4 main() : sv_target @@ -230,7 +230,7 @@ float4 main() : sv_target return c; } -[pixel shader fail todo] +[pixel shader fail] uint4 v; float4 main() : sv_target @@ -246,7 +246,7 @@ float4 main() : sv_target return 0.0; } -[pixel shader fail todo] +[pixel shader fail] uint4 v; float4 main() : sv_target @@ -262,7 +262,7 @@ float4 main() : sv_target return 0.0; } -[pixel shader fail(sm<6) todo] +[pixel shader fail(sm<6)] uint4 v; float4 main() : sv_target @@ -279,7 +279,7 @@ float4 main() : sv_target return 0.0; } -[pixel shader fail(sm<6) todo] +[pixel shader fail(sm<6)] uint4 v; float4 main() : sv_target @@ -296,7 +296,7 @@ float4 main() : sv_target return 0.0; } -[pixel shader fail(sm<6) todo] +[pixel shader fail(sm<6)] uint4 v; float4 main() : sv_target