diff --git a/libs/vkd3d-shader/hlsl.h b/libs/vkd3d-shader/hlsl.h index c6591478..0a20acd9 100644 --- a/libs/vkd3d-shader/hlsl.h +++ b/libs/vkd3d-shader/hlsl.h @@ -1642,6 +1642,7 @@ struct hlsl_reg hlsl_reg_from_deref(struct hlsl_ctx *ctx, const struct hlsl_dere bool hlsl_copy_propagation_execute(struct hlsl_ctx *ctx, struct hlsl_block *block); bool hlsl_fold_constant_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context); bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context); +bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context); bool hlsl_fold_constant_swizzles(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context); bool hlsl_transform_ir(struct hlsl_ctx *ctx, bool (*func)(struct hlsl_ctx *ctx, struct hlsl_ir_node *, void *), struct hlsl_block *block, void *context); diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index e6924aa7..e28f4512 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -6569,6 +6569,7 @@ void hlsl_run_const_passes(struct hlsl_ctx *ctx, struct hlsl_block *body) { progress = hlsl_transform_ir(ctx, hlsl_fold_constant_exprs, body, NULL); progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_identities, body, NULL); + progress |= hlsl_transform_ir(ctx, hlsl_normalize_binary_exprs, body, NULL); progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_swizzles, body, NULL); progress |= hlsl_copy_propagation_execute(ctx, body); progress |= hlsl_transform_ir(ctx, fold_swizzle_chains, body, NULL); diff --git a/libs/vkd3d-shader/hlsl_constant_ops.c b/libs/vkd3d-shader/hlsl_constant_ops.c index 716adb15..584038da 100644 --- a/libs/vkd3d-shader/hlsl_constant_ops.c +++ b/libs/vkd3d-shader/hlsl_constant_ops.c @@ -1544,6 +1544,149 @@ bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *in return false; } +static bool is_op_associative(enum hlsl_ir_expr_op op, enum hlsl_base_type type) +{ + switch (op) + { + case HLSL_OP2_ADD: + case HLSL_OP2_MUL: + return type == HLSL_TYPE_INT || type == HLSL_TYPE_UINT; + + case HLSL_OP2_BIT_AND: + case HLSL_OP2_BIT_OR: + case HLSL_OP2_BIT_XOR: + case HLSL_OP2_LOGIC_AND: + case HLSL_OP2_LOGIC_OR: + case HLSL_OP2_MAX: + case HLSL_OP2_MIN: + return true; + + default: + return false; + } +} + +static bool is_op_commutative(enum hlsl_ir_expr_op op) +{ + switch (op) + { + case HLSL_OP2_ADD: + case HLSL_OP2_BIT_AND: + case HLSL_OP2_BIT_OR: + case HLSL_OP2_BIT_XOR: + case HLSL_OP2_DOT: + case HLSL_OP2_LOGIC_AND: + case HLSL_OP2_LOGIC_OR: + case HLSL_OP2_MAX: + case HLSL_OP2_MIN: + case HLSL_OP2_MUL: + return true; + + default: + return false; + } +} + +bool hlsl_normalize_binary_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) +{ + struct hlsl_ir_node *arg1 , *arg2; + struct hlsl_ir_expr *expr; + enum hlsl_base_type type; + enum hlsl_ir_expr_op op; + bool progress = false; + + if (instr->type != HLSL_IR_EXPR) + return false; + expr = hlsl_ir_expr(instr); + + if (instr->data_type->class > HLSL_CLASS_VECTOR) + return false; + + arg1 = expr->operands[0].node; + arg2 = expr->operands[1].node; + type = instr->data_type->e.numeric.type; + op = expr->op; + + if (!arg1 || !arg2) + return false; + + if (is_op_commutative(op) && arg1->type == HLSL_IR_CONSTANT && arg2->type != HLSL_IR_CONSTANT) + { + /* a OP x -> x OP a */ + struct hlsl_ir_node *tmp = arg1; + + arg1 = arg2; + arg2 = tmp; + progress = true; + } + + if (is_op_associative(op, type)) + { + struct hlsl_ir_expr *e1 = arg1->type == HLSL_IR_EXPR ? hlsl_ir_expr(arg1) : NULL; + struct hlsl_ir_expr *e2 = arg2->type == HLSL_IR_EXPR ? hlsl_ir_expr(arg2) : NULL; + + if (e1 && e1->op == op && e1->operands[0].node->type != HLSL_IR_CONSTANT + && e1->operands[1].node->type == HLSL_IR_CONSTANT) + { + if (arg2->type == HLSL_IR_CONSTANT) + { + /* (x OP a) OP b -> x OP (a OP b) */ + struct hlsl_ir_node *ab; + + if (!(ab = hlsl_new_binary_expr(ctx, op, e1->operands[1].node, arg2))) + return false; + list_add_before(&instr->entry, &ab->entry); + + arg1 = e1->operands[0].node; + arg2 = ab; + progress = true; + } + else if (is_op_commutative(op)) + { + /* (x OP a) OP y -> (x OP y) OP a */ + struct hlsl_ir_node *xy; + + if (!(xy = hlsl_new_binary_expr(ctx, op, e1->operands[0].node, arg2))) + return false; + list_add_before(&instr->entry, &xy->entry); + + arg1 = xy; + arg2 = e1->operands[1].node; + progress = true; + } + } + + if (!progress && arg1->type != HLSL_IR_CONSTANT && e2 && e2->op == op + && e2->operands[0].node->type != HLSL_IR_CONSTANT && e2->operands[1].node->type == HLSL_IR_CONSTANT) + { + /* x OP (y OP a) -> (x OP y) OP a */ + struct hlsl_ir_node *xy; + + if (!(xy = hlsl_new_binary_expr(ctx, op, arg1, e2->operands[0].node))) + return false; + list_add_before(&instr->entry, &xy->entry); + + arg1 = xy; + arg2 = e2->operands[1].node; + progress = true; + } + + } + + if (progress) + { + struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS] = {arg1, arg2}; + struct hlsl_ir_node *res; + + if (!(res = hlsl_new_expr(ctx, op, operands, instr->data_type, &instr->loc))) + return false; + list_add_before(&instr->entry, &res->entry); + hlsl_replace_node(instr, res); + } + + return progress; +} + bool hlsl_fold_constant_swizzles(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) { struct hlsl_constant_value value; diff --git a/tests/hlsl/arithmetic-uint.shader_test b/tests/hlsl/arithmetic-uint.shader_test index 9c3ad646..31d86f18 100644 --- a/tests/hlsl/arithmetic-uint.shader_test +++ b/tests/hlsl/arithmetic-uint.shader_test @@ -52,3 +52,19 @@ float4 main() : SV_TARGET [test] draw quad probe (0, 0) rgba (0.0, 0.0, 0.0, 0.0) + + +% Test expression normalization and simplification. + +[pixel shader] +uniform uint4 x; + +float4 main() : SV_TARGET +{ + return 6 + (2 * x * 3) - 5; +} + +[test] +uniform 0 uint4 0 1 2 3 +todo(msl) draw quad +probe (0, 0) rgba (1.0, 7.0, 13.0, 19.0)