diff --git a/libs/vkd3d-shader/hlsl.c b/libs/vkd3d-shader/hlsl.c index 62335086e..d8eb18e39 100644 --- a/libs/vkd3d-shader/hlsl.c +++ b/libs/vkd3d-shader/hlsl.c @@ -1809,6 +1809,76 @@ struct hlsl_ir_node *hlsl_new_null_constant(struct hlsl_ctx *ctx, const struct v return hlsl_new_constant(ctx, ctx->builtin_types.null, &value, loc); } +bool hlsl_constant_is_zero(struct hlsl_ir_constant *c) +{ + struct hlsl_type *data_type = c->node.data_type; + unsigned int k; + + for (k = 0; k < data_type->e.numeric.dimx; ++k) + { + switch (data_type->e.numeric.type) + { + case HLSL_TYPE_FLOAT: + case HLSL_TYPE_HALF: + if (c->value.u[k].f != 0.0f) + return false; + break; + + case HLSL_TYPE_DOUBLE: + if (c->value.u[k].d != 0.0) + return false; + break; + + case HLSL_TYPE_UINT: + case HLSL_TYPE_INT: + case HLSL_TYPE_BOOL: + case HLSL_TYPE_MIN16UINT: + if (c->value.u[k].u != 0) + return false; + break; + } + } + + return true; +} + +bool hlsl_constant_is_one(struct hlsl_ir_constant *c) +{ + struct hlsl_type *data_type = c->node.data_type; + unsigned int k; + + for (k = 0; k < data_type->e.numeric.dimx; ++k) + { + switch (data_type->e.numeric.type) + { + case HLSL_TYPE_FLOAT: + case HLSL_TYPE_HALF: + if (c->value.u[k].f != 1.0f) + return false; + break; + + case HLSL_TYPE_DOUBLE: + if (c->value.u[k].d != 1.0) + return false; + break; + + case HLSL_TYPE_UINT: + case HLSL_TYPE_INT: + case HLSL_TYPE_MIN16UINT: + if (c->value.u[k].u != 1) + return false; + break; + + case HLSL_TYPE_BOOL: + if (c->value.u[k].u != ~0) + return false; + break; + } + } + + return true; +} + static struct hlsl_ir_node *hlsl_new_expr(struct hlsl_ctx *ctx, enum hlsl_ir_expr_op op, struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS], struct hlsl_type *data_type, const struct vkd3d_shader_location *loc) diff --git a/libs/vkd3d-shader/hlsl.h b/libs/vkd3d-shader/hlsl.h index 0cb5d1b90..762da0f87 100644 --- a/libs/vkd3d-shader/hlsl.h +++ b/libs/vkd3d-shader/hlsl.h @@ -1699,6 +1699,9 @@ struct hlsl_type *hlsl_new_stream_output_type(struct hlsl_ctx *ctx, struct hlsl_ir_node *hlsl_new_ternary_expr(struct hlsl_ctx *ctx, enum hlsl_ir_expr_op op, struct hlsl_ir_node *arg1, struct hlsl_ir_node *arg2, struct hlsl_ir_node *arg3); +bool hlsl_constant_is_zero(struct hlsl_ir_constant *c); +bool hlsl_constant_is_one(struct hlsl_ir_constant *c); + void hlsl_init_simple_deref_from_var(struct hlsl_deref *deref, struct hlsl_ir_var *var); struct hlsl_ir_load *hlsl_new_var_load(struct hlsl_ctx *ctx, struct hlsl_ir_var *var, diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index 9d02f831c..d17883439 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -8387,6 +8387,161 @@ static bool fold_unary_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *ins return false; } +static bool nodes_are_equivalent(const struct hlsl_ir_node *c1, const struct hlsl_ir_node *c2) +{ + if (c1 == c2) + return true; + + if (c1->type == HLSL_IR_SWIZZLE && c2->type == HLSL_IR_SWIZZLE + && hlsl_types_are_equal(c1->data_type, c2->data_type)) + { + const struct hlsl_ir_swizzle *s1 = hlsl_ir_swizzle(c1), *s2 = hlsl_ir_swizzle(c2); + + VKD3D_ASSERT(c1->data_type->class <= HLSL_CLASS_VECTOR); + + if (s1->val.node == s2->val.node && s1->u.vector == s2->u.vector) + return true; + } + + return false; +} + +/* Replaces all conditionals in an expression chain of the form (cond ? x : y) + * with x or y, assuming cond = cond_value. */ +static struct hlsl_ir_node *evaluate_conditionals_recurse(struct hlsl_ctx *ctx, + struct hlsl_block *block, const struct hlsl_ir_node *cond, bool cond_value, + struct hlsl_ir_node *instr, const struct vkd3d_shader_location *loc) +{ + struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS] = {0}; + struct hlsl_ir_expr *expr; + struct hlsl_ir_node *res; + bool progress = false; + unsigned int i; + + if (instr->type != HLSL_IR_EXPR) + return NULL; + expr = hlsl_ir_expr(instr); + + if (expr->op == HLSL_OP3_TERNARY && nodes_are_equivalent(cond, expr->operands[0].node)) + { + struct hlsl_ir_node *x = cond_value ? expr->operands[1].node : expr->operands[2].node; + + res = evaluate_conditionals_recurse(ctx, block, cond, cond_value, x, loc); + return res ? res : x; + } + + for (i = 0; i < HLSL_MAX_OPERANDS; ++i) + { + if (!expr->operands[i].node) + break; + + operands[i] = evaluate_conditionals_recurse(ctx, block, cond, cond_value, expr->operands[i].node, loc); + + if (operands[i]) + progress = true; + else + operands[i] = expr->operands[i].node; + } + + if (progress) + return hlsl_block_add_expr(ctx, block, expr->op, operands, expr->node.data_type, loc); + + return NULL; +} + +static bool fold_conditional_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) +{ + struct hlsl_ir_node *c, *x, *y, *res_x, *res_y; + struct hlsl_ir_node *res = NULL; + struct hlsl_ir_expr *expr, *ec; + struct hlsl_block block; + + if (instr->type != HLSL_IR_EXPR) + return false; + + if (instr->data_type->class > HLSL_CLASS_VECTOR) + return false; + + expr = hlsl_ir_expr(instr); + if (expr->op != HLSL_OP3_TERNARY) + return false; + + c = expr->operands[0].node; + x = expr->operands[1].node; + y = expr->operands[2].node; + + VKD3D_ASSERT(c->data_type->e.numeric.type == HLSL_TYPE_BOOL); + + if (nodes_are_equivalent(x, y)) + { + /* c ? x : x -> x */ + hlsl_replace_node(instr, x); + return true; + } + + if (c->type == HLSL_IR_CONSTANT) + { + if (hlsl_constant_is_zero(hlsl_ir_constant(c))) + { + /* false ? x : y -> y */ + hlsl_replace_node(instr, y); + return true; + } + + if (hlsl_constant_is_one(hlsl_ir_constant(c))) + { + /* true ? x : y -> x */ + hlsl_replace_node(instr, x); + return true; + } + } + + hlsl_block_init(&block); + + if (x->type == HLSL_IR_CONSTANT && y->type == HLSL_IR_CONSTANT + && hlsl_types_are_equal(c->data_type, x->data_type) + && hlsl_types_are_equal(c->data_type, y->data_type)) + { + if (hlsl_constant_is_one(hlsl_ir_constant(x)) && hlsl_constant_is_zero(hlsl_ir_constant(y))) + { + /* c ? true : false -> c */ + res = c; + goto done; + } + + if (hlsl_constant_is_zero(hlsl_ir_constant(x)) && hlsl_constant_is_one(hlsl_ir_constant(y))) + { + /* c ? false : true -> !c */ + res = hlsl_block_add_unary_expr(ctx, &block, HLSL_OP1_LOGIC_NOT, c, &instr->loc); + goto done; + } + } + + ec = c->type == HLSL_IR_EXPR ? hlsl_ir_expr(c) : NULL; + if (ec && ec->op == HLSL_OP1_LOGIC_NOT) + { + /* !c ? x : y -> c ? y : x */ + res = hlsl_add_conditional(ctx, &block, ec->operands[0].node, y, x); + goto done; + } + + res_x = evaluate_conditionals_recurse(ctx, &block, c, true, x, &instr->loc); + res_y = evaluate_conditionals_recurse(ctx, &block, c, false, y, &instr->loc); + if (res_x || res_y) + res = hlsl_add_conditional(ctx, &block, c, res_x ? res_x : x, res_y ? res_y : y); + +done: + if (res) + { + list_move_before(&instr->entry, &block.instrs); + hlsl_replace_node(instr, res); + return true; + } + + hlsl_block_cleanup(&block); + return false; +} + static bool simplify_exprs(struct hlsl_ctx *ctx, struct hlsl_block *block) { bool progress, any_progress = false; @@ -8396,6 +8551,7 @@ static bool simplify_exprs(struct hlsl_ctx *ctx, struct hlsl_block *block) progress = hlsl_transform_ir(ctx, hlsl_fold_constant_exprs, block, NULL); progress |= hlsl_transform_ir(ctx, hlsl_normalize_binary_exprs, block, NULL); progress |= hlsl_transform_ir(ctx, fold_unary_identities, block, NULL); + progress |= hlsl_transform_ir(ctx, fold_conditional_identities, block, NULL); progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_identities, block, NULL); progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_swizzles, block, NULL); diff --git a/libs/vkd3d-shader/hlsl_constant_ops.c b/libs/vkd3d-shader/hlsl_constant_ops.c index 7e9410e06..4cd47a063 100644 --- a/libs/vkd3d-shader/hlsl_constant_ops.c +++ b/libs/vkd3d-shader/hlsl_constant_ops.c @@ -1393,74 +1393,6 @@ bool hlsl_fold_constant_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, return success; } -static bool constant_is_zero(struct hlsl_ir_constant *const_arg) -{ - struct hlsl_type *data_type = const_arg->node.data_type; - unsigned int k; - - for (k = 0; k < data_type->e.numeric.dimx; ++k) - { - switch (data_type->e.numeric.type) - { - case HLSL_TYPE_FLOAT: - case HLSL_TYPE_HALF: - if (const_arg->value.u[k].f != 0.0f) - return false; - break; - - case HLSL_TYPE_DOUBLE: - if (const_arg->value.u[k].d != 0.0) - return false; - break; - - case HLSL_TYPE_UINT: - case HLSL_TYPE_INT: - case HLSL_TYPE_BOOL: - case HLSL_TYPE_MIN16UINT: - if (const_arg->value.u[k].u != 0) - return false; - break; - } - } - return true; -} - -static bool constant_is_one(struct hlsl_ir_constant *const_arg) -{ - struct hlsl_type *data_type = const_arg->node.data_type; - unsigned int k; - - for (k = 0; k < data_type->e.numeric.dimx; ++k) - { - switch (data_type->e.numeric.type) - { - case HLSL_TYPE_FLOAT: - case HLSL_TYPE_HALF: - if (const_arg->value.u[k].f != 1.0f) - return false; - break; - - case HLSL_TYPE_DOUBLE: - if (const_arg->value.u[k].d != 1.0) - return false; - break; - - case HLSL_TYPE_UINT: - case HLSL_TYPE_INT: - case HLSL_TYPE_MIN16UINT: - if (const_arg->value.u[k].u != 1) - return false; - break; - - case HLSL_TYPE_BOOL: - if (const_arg->value.u[k].u != ~0) - return false; - break; - } - } - return true; -} - bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) { struct hlsl_ir_constant *const_arg = NULL; @@ -1502,26 +1434,26 @@ bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *in switch (expr->op) { case HLSL_OP2_ADD: - if (constant_is_zero(const_arg)) + if (hlsl_constant_is_zero(const_arg)) res_node = mut_arg; break; case HLSL_OP2_MUL: - if (constant_is_one(const_arg)) + if (hlsl_constant_is_one(const_arg)) res_node = mut_arg; break; case HLSL_OP2_LOGIC_AND: - if (constant_is_zero(const_arg)) + if (hlsl_constant_is_zero(const_arg)) res_node = &const_arg->node; - else if (constant_is_one(const_arg)) + else if (hlsl_constant_is_one(const_arg)) res_node = mut_arg; break; case HLSL_OP2_LOGIC_OR: - if (constant_is_zero(const_arg)) + if (hlsl_constant_is_zero(const_arg)) res_node = mut_arg; - else if (constant_is_one(const_arg)) + else if (hlsl_constant_is_one(const_arg)) res_node = &const_arg->node; break; diff --git a/tests/hlsl/ternary.shader_test b/tests/hlsl/ternary.shader_test index fd1b3efd3..1d75c858c 100644 --- a/tests/hlsl/ternary.shader_test +++ b/tests/hlsl/ternary.shader_test @@ -160,6 +160,123 @@ float4 main() : sv_target draw quad probe (0, 0) rgba (1.0, 6.0, 7.0, 4.0) +[pixel shader] +static float4 cond = {1, 0, 0, 1}; +uniform float4 a; + +float4 main() : sv_target +{ + return cond ? a : a; +} + +[test] +uniform 0 float4 1.0 2.0 3.0 4.0 +draw quad +probe (0, 0) f32(1.0, 2.0, 3.0, 4.0) + +[pixel shader] +uniform float4 a; + +float4 main() : sv_target +{ + return float4(false ? a.x : a.y, true ? a.x : a.y, 0, 0); +} + +[test] +uniform 0 float4 1.0 2.0 0.0 0.0 +draw quad +probe (0, 0) f32(2.0, 1.0, 0.0, 0.0) + +[pixel shader todo(sm<4)] +uniform float c; + +float4 main() : sv_target +{ + bool cond = c >= 0; + return float4(cond ? true : false, cond ? false : true, 0, 0); +} + +[test] +uniform 0 float -1.0 +todo(sm<4) draw quad +probe (0, 0) f32(0.0, 1.0, 0.0, 0.0) + +[pixel shader] +uniform bool cond; +uniform float4 a, b; + +float4 main() : sv_target +{ + return !cond ? a : b; +} + +[test] +if(sm<4) uniform 0 float 0.0 +if(sm>=4) uniform 0 uint 0 +uniform 4 float4 1.0 2.0 3.0 4.0 +uniform 8 float4 5.0 6.0 7.0 8.0 +draw quad +probe (0, 0) f32(1.0, 2.0, 3.0, 4.0) + +[pixel shader] +uniform float4 a; + +float4 main() : sv_target +{ + bool cond1 = a.x > 0, cond2 = a.x > 2; + float4 ret; + + ret.x = cond1 ? (cond1 ? a.x : a.y) : a.z; + ret.y = cond1 ? a.x : (cond1 ? a.y : a.z); + ret.z = cond2 ? (cond2 ? a.x : a.y) : a.z; + ret.w = cond2 ? a.x : (cond2 ? a.y : a.z); + + return ret; +} + +[test] +uniform 0 float4 1.0 2.0 3.0 0.0 +draw quad +probe (0, 0) f32(1.0, 1.0, 3.0, 3.0) + +[pixel shader] +uniform float a; + +float4 main() : sv_target +{ + bool cond1 = a > 0, cond2 = a > 2; + float4 ret; + + ret.xy = cond1 ? (cond1 ? float2(0, 1) : float2(1, 0)) : float2(1, 1); + ret.zw = cond2 ? float2(0, 2) : (cond2 ? float2(2, 0) : float2(2, 2)); + + return ret; +} + +[test] +uniform 0 float 1.0 +draw quad +probe (0, 0) f32(0.0, 1.0, 2.0, 2.0) + +[pixel shader] +uniform float4 a; + +float4 main() : sv_target +{ + bool cond1 = a.x > 2, cond2 = a.x > 0, cond3 = a.y > 3; + + float b = cond1 ? a.x : a.y; + float c = cond2 ? a.z : b; + float d = cond3 ? a.w : c + 10; + float e = cond1 ? a.x : d; + + return e; +} + +[test] +uniform 0 float4 1.0 2.0 3.0 4.0 +draw quad +probe (0, 0) f32(13.0, 13.0, 13.0, 13.0) [pixel shader fail]