diff --git a/libs/vkd3d-shader/hlsl_constant_ops.c b/libs/vkd3d-shader/hlsl_constant_ops.c index 8d112fb5..538f0f46 100644 --- a/libs/vkd3d-shader/hlsl_constant_ops.c +++ b/libs/vkd3d-shader/hlsl_constant_ops.c @@ -1588,7 +1588,7 @@ static bool is_op_left_distributive(enum hlsl_ir_expr_op opl, enum hlsl_ir_expr_ } /* Attempt to collect together the expression (x OPL a) OPR (x OPL b) -> x OPL (a OPR b). */ -static struct hlsl_ir_node *collect_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, +static struct hlsl_ir_node *collect_exprs(struct hlsl_ctx *ctx, struct hlsl_block *block, struct hlsl_ir_node *instr, enum hlsl_ir_expr_op opr, struct hlsl_ir_node *node1, struct hlsl_ir_node *node2) { enum hlsl_base_type type = instr->data_type->e.numeric.type; @@ -1612,14 +1612,14 @@ static struct hlsl_ir_node *collect_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_n if (!(ab = hlsl_new_binary_expr(ctx, opr, e1->operands[1].node, e2->operands[1].node))) return NULL; - list_add_before(&instr->entry, &ab->entry); + hlsl_block_add_instr(block, ab); operands[0] = e1->operands[0].node; operands[1] = ab; if (!(res = hlsl_new_expr(ctx, opl, operands, instr->data_type, &instr->loc))) return NULL; - list_add_before(&instr->entry, &res->entry); + hlsl_block_add_instr(block, res); return res; } @@ -1629,6 +1629,7 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst struct hlsl_ir_expr *expr; enum hlsl_base_type type; enum hlsl_ir_expr_op op; + struct hlsl_block block; bool progress = false; if (instr->type != HLSL_IR_EXPR) @@ -1638,6 +1639,8 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst if (instr->data_type->class > HLSL_CLASS_VECTOR) return false; + hlsl_block_init(&block); + arg1 = expr->operands[0].node; arg2 = expr->operands[1].node; type = instr->data_type->e.numeric.type; @@ -1646,9 +1649,10 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst if (!arg1 || !arg2) return false; - if ((tmp = collect_exprs(ctx, instr, op, arg1, arg2))) + if ((tmp = collect_exprs(ctx, &block, instr, op, arg1, arg2))) { /* (x OPL a) OPR (x OPL b) -> x OPL (a OPR b) */ + list_move_before(&instr->entry, &block.instrs); hlsl_replace_node(instr, tmp); return true; } @@ -1676,8 +1680,8 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst struct hlsl_ir_node *ab; if (!(ab = hlsl_new_binary_expr(ctx, op, e1->operands[1].node, arg2))) - return false; - list_add_before(&instr->entry, &ab->entry); + goto fail; + hlsl_block_add_instr(&block, ab); arg1 = e1->operands[0].node; arg2 = ab; @@ -1689,8 +1693,8 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst struct hlsl_ir_node *xy; if (!(xy = hlsl_new_binary_expr(ctx, op, e1->operands[0].node, arg2))) - return false; - list_add_before(&instr->entry, &xy->entry); + goto fail; + hlsl_block_add_instr(&block, xy); arg1 = xy; arg2 = e1->operands[1].node; @@ -1705,15 +1709,15 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst struct hlsl_ir_node *xy; if (!(xy = hlsl_new_binary_expr(ctx, op, arg1, e2->operands[0].node))) - return false; - list_add_before(&instr->entry, &xy->entry); + goto fail; + hlsl_block_add_instr(&block, xy); arg1 = xy; arg2 = e2->operands[1].node; progress = true; } - if (!progress && e1 && (tmp = collect_exprs(ctx, instr, op, e1->operands[1].node, arg2))) + if (!progress && e1 && (tmp = collect_exprs(ctx, &block, instr, op, e1->operands[1].node, arg2))) { /* (y OPR (x OPL a)) OPR (x OPL b) -> y OPR (x OPL (a OPR b)) */ arg1 = e1->operands[0].node; @@ -1722,7 +1726,7 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst } if (!progress && is_op_commutative(op) && e1 - && (tmp = collect_exprs(ctx, instr, op, e1->operands[0].node, arg2))) + && (tmp = collect_exprs(ctx, &block, instr, op, e1->operands[0].node, arg2))) { /* ((x OPL a) OPR y) OPR (x OPL b) -> (x OPL (a OPR b)) OPR y */ arg1 = tmp; @@ -1730,7 +1734,7 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst progress = true; } - if (!progress && e2 && (tmp = collect_exprs(ctx, instr, op, arg1, e2->operands[0].node))) + if (!progress && e2 && (tmp = collect_exprs(ctx, &block, instr, op, arg1, e2->operands[0].node))) { /* (x OPL a) OPR ((x OPL b) OPR y) -> (x OPL (a OPR b)) OPR y */ arg1 = tmp; @@ -1739,7 +1743,7 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst } if (!progress && is_op_commutative(op) && e2 - && (tmp = collect_exprs(ctx, instr, op, arg1, e2->operands[1].node))) + && (tmp = collect_exprs(ctx, &block, instr, op, arg1, e2->operands[1].node))) { /* (x OPL a) OPR (y OPR (x OPL b)) -> (x OPL (a OPR b)) OPR y */ arg1 = tmp; @@ -1754,12 +1758,18 @@ bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst struct hlsl_ir_node *res; if (!(res = hlsl_new_expr(ctx, op, operands, instr->data_type, &instr->loc))) - return false; - list_add_before(&instr->entry, &res->entry); + goto fail; + hlsl_block_add_instr(&block, res); + + list_move_before(&instr->entry, &block.instrs); hlsl_replace_node(instr, res); } return progress; + +fail: + hlsl_block_cleanup(&block); + return false; } bool hlsl_fold_constant_swizzles(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)