diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index 2df88f57..283a67e6 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -57,13 +57,24 @@ static void replace_node(struct hlsl_ir_node *old, struct hlsl_ir_node *new) hlsl_free_instr(old); } +static bool is_vec1(const struct hlsl_type *type) +{ + return (type->type == HLSL_CLASS_SCALAR) || (type->type == HLSL_CLASS_VECTOR && type->dimx == 1); +} + static bool fold_redundant_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) { if (instr->type == HLSL_IR_EXPR) { struct hlsl_ir_expr *expr = hlsl_ir_expr(instr); + const struct hlsl_type *src_type = expr->operands[0].node->data_type; + const struct hlsl_type *dst_type = expr->node.data_type; - if (expr->op == HLSL_IR_UNOP_CAST && hlsl_type_compare(expr->node.data_type, expr->operands[0].node->data_type)) + if (expr->op != HLSL_IR_UNOP_CAST) + return false; + + if (hlsl_type_compare(src_type, dst_type) + || (src_type->base_type == dst_type->base_type && is_vec1(src_type) && is_vec1(dst_type))) { replace_node(&expr->node, expr->operands[0].node); return true;