vkd3d-shader/hlsl: Fold some general conditional identities.

The following conditional identities are applied:

  c ? x : x -> x
  false ? x : y -> y; true ? x : y -> x
  c ? true : false -> c; c ? false : true -> !c
  !c ? x : y -> c ? y : x

Lastly, for expression chains x, y in a conditional expression
  c ? x : y,
we evaluate all conditionals in the expression chains with the
condition c, assuming c is true (for x), or false (for y).
This commit is contained in:
Shaun Ren
2025-08-01 19:24:12 -04:00
committed by Henri Verbeet
parent 245430002a
commit 320c3c9652
Notes: Henri Verbeet 2025-08-21 16:34:35 +02:00
Approved-by: Francisco Casas (@fcasas)
Approved-by: Elizabeth Figura (@zfigura)
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1648
5 changed files with 352 additions and 74 deletions

View File

@@ -1809,6 +1809,76 @@ struct hlsl_ir_node *hlsl_new_null_constant(struct hlsl_ctx *ctx, const struct v
return hlsl_new_constant(ctx, ctx->builtin_types.null, &value, loc);
}
bool hlsl_constant_is_zero(struct hlsl_ir_constant *c)
{
struct hlsl_type *data_type = c->node.data_type;
unsigned int k;
for (k = 0; k < data_type->e.numeric.dimx; ++k)
{
switch (data_type->e.numeric.type)
{
case HLSL_TYPE_FLOAT:
case HLSL_TYPE_HALF:
if (c->value.u[k].f != 0.0f)
return false;
break;
case HLSL_TYPE_DOUBLE:
if (c->value.u[k].d != 0.0)
return false;
break;
case HLSL_TYPE_UINT:
case HLSL_TYPE_INT:
case HLSL_TYPE_BOOL:
case HLSL_TYPE_MIN16UINT:
if (c->value.u[k].u != 0)
return false;
break;
}
}
return true;
}
bool hlsl_constant_is_one(struct hlsl_ir_constant *c)
{
struct hlsl_type *data_type = c->node.data_type;
unsigned int k;
for (k = 0; k < data_type->e.numeric.dimx; ++k)
{
switch (data_type->e.numeric.type)
{
case HLSL_TYPE_FLOAT:
case HLSL_TYPE_HALF:
if (c->value.u[k].f != 1.0f)
return false;
break;
case HLSL_TYPE_DOUBLE:
if (c->value.u[k].d != 1.0)
return false;
break;
case HLSL_TYPE_UINT:
case HLSL_TYPE_INT:
case HLSL_TYPE_MIN16UINT:
if (c->value.u[k].u != 1)
return false;
break;
case HLSL_TYPE_BOOL:
if (c->value.u[k].u != ~0)
return false;
break;
}
}
return true;
}
static struct hlsl_ir_node *hlsl_new_expr(struct hlsl_ctx *ctx, enum hlsl_ir_expr_op op,
struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS],
struct hlsl_type *data_type, const struct vkd3d_shader_location *loc)

View File

@@ -1699,6 +1699,9 @@ struct hlsl_type *hlsl_new_stream_output_type(struct hlsl_ctx *ctx,
struct hlsl_ir_node *hlsl_new_ternary_expr(struct hlsl_ctx *ctx, enum hlsl_ir_expr_op op,
struct hlsl_ir_node *arg1, struct hlsl_ir_node *arg2, struct hlsl_ir_node *arg3);
bool hlsl_constant_is_zero(struct hlsl_ir_constant *c);
bool hlsl_constant_is_one(struct hlsl_ir_constant *c);
void hlsl_init_simple_deref_from_var(struct hlsl_deref *deref, struct hlsl_ir_var *var);
struct hlsl_ir_load *hlsl_new_var_load(struct hlsl_ctx *ctx, struct hlsl_ir_var *var,

View File

@@ -8387,6 +8387,161 @@ static bool fold_unary_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *ins
return false;
}
static bool nodes_are_equivalent(const struct hlsl_ir_node *c1, const struct hlsl_ir_node *c2)
{
if (c1 == c2)
return true;
if (c1->type == HLSL_IR_SWIZZLE && c2->type == HLSL_IR_SWIZZLE
&& hlsl_types_are_equal(c1->data_type, c2->data_type))
{
const struct hlsl_ir_swizzle *s1 = hlsl_ir_swizzle(c1), *s2 = hlsl_ir_swizzle(c2);
VKD3D_ASSERT(c1->data_type->class <= HLSL_CLASS_VECTOR);
if (s1->val.node == s2->val.node && s1->u.vector == s2->u.vector)
return true;
}
return false;
}
/* Replaces all conditionals in an expression chain of the form (cond ? x : y)
* with x or y, assuming cond = cond_value. */
static struct hlsl_ir_node *evaluate_conditionals_recurse(struct hlsl_ctx *ctx,
struct hlsl_block *block, const struct hlsl_ir_node *cond, bool cond_value,
struct hlsl_ir_node *instr, const struct vkd3d_shader_location *loc)
{
struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS] = {0};
struct hlsl_ir_expr *expr;
struct hlsl_ir_node *res;
bool progress = false;
unsigned int i;
if (instr->type != HLSL_IR_EXPR)
return NULL;
expr = hlsl_ir_expr(instr);
if (expr->op == HLSL_OP3_TERNARY && nodes_are_equivalent(cond, expr->operands[0].node))
{
struct hlsl_ir_node *x = cond_value ? expr->operands[1].node : expr->operands[2].node;
res = evaluate_conditionals_recurse(ctx, block, cond, cond_value, x, loc);
return res ? res : x;
}
for (i = 0; i < HLSL_MAX_OPERANDS; ++i)
{
if (!expr->operands[i].node)
break;
operands[i] = evaluate_conditionals_recurse(ctx, block, cond, cond_value, expr->operands[i].node, loc);
if (operands[i])
progress = true;
else
operands[i] = expr->operands[i].node;
}
if (progress)
return hlsl_block_add_expr(ctx, block, expr->op, operands, expr->node.data_type, loc);
return NULL;
}
static bool fold_conditional_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
{
struct hlsl_ir_node *c, *x, *y, *res_x, *res_y;
struct hlsl_ir_node *res = NULL;
struct hlsl_ir_expr *expr, *ec;
struct hlsl_block block;
if (instr->type != HLSL_IR_EXPR)
return false;
if (instr->data_type->class > HLSL_CLASS_VECTOR)
return false;
expr = hlsl_ir_expr(instr);
if (expr->op != HLSL_OP3_TERNARY)
return false;
c = expr->operands[0].node;
x = expr->operands[1].node;
y = expr->operands[2].node;
VKD3D_ASSERT(c->data_type->e.numeric.type == HLSL_TYPE_BOOL);
if (nodes_are_equivalent(x, y))
{
/* c ? x : x -> x */
hlsl_replace_node(instr, x);
return true;
}
if (c->type == HLSL_IR_CONSTANT)
{
if (hlsl_constant_is_zero(hlsl_ir_constant(c)))
{
/* false ? x : y -> y */
hlsl_replace_node(instr, y);
return true;
}
if (hlsl_constant_is_one(hlsl_ir_constant(c)))
{
/* true ? x : y -> x */
hlsl_replace_node(instr, x);
return true;
}
}
hlsl_block_init(&block);
if (x->type == HLSL_IR_CONSTANT && y->type == HLSL_IR_CONSTANT
&& hlsl_types_are_equal(c->data_type, x->data_type)
&& hlsl_types_are_equal(c->data_type, y->data_type))
{
if (hlsl_constant_is_one(hlsl_ir_constant(x)) && hlsl_constant_is_zero(hlsl_ir_constant(y)))
{
/* c ? true : false -> c */
res = c;
goto done;
}
if (hlsl_constant_is_zero(hlsl_ir_constant(x)) && hlsl_constant_is_one(hlsl_ir_constant(y)))
{
/* c ? false : true -> !c */
res = hlsl_block_add_unary_expr(ctx, &block, HLSL_OP1_LOGIC_NOT, c, &instr->loc);
goto done;
}
}
ec = c->type == HLSL_IR_EXPR ? hlsl_ir_expr(c) : NULL;
if (ec && ec->op == HLSL_OP1_LOGIC_NOT)
{
/* !c ? x : y -> c ? y : x */
res = hlsl_add_conditional(ctx, &block, ec->operands[0].node, y, x);
goto done;
}
res_x = evaluate_conditionals_recurse(ctx, &block, c, true, x, &instr->loc);
res_y = evaluate_conditionals_recurse(ctx, &block, c, false, y, &instr->loc);
if (res_x || res_y)
res = hlsl_add_conditional(ctx, &block, c, res_x ? res_x : x, res_y ? res_y : y);
done:
if (res)
{
list_move_before(&instr->entry, &block.instrs);
hlsl_replace_node(instr, res);
return true;
}
hlsl_block_cleanup(&block);
return false;
}
static bool simplify_exprs(struct hlsl_ctx *ctx, struct hlsl_block *block)
{
bool progress, any_progress = false;
@@ -8396,6 +8551,7 @@ static bool simplify_exprs(struct hlsl_ctx *ctx, struct hlsl_block *block)
progress = hlsl_transform_ir(ctx, hlsl_fold_constant_exprs, block, NULL);
progress |= hlsl_transform_ir(ctx, hlsl_normalize_binary_exprs, block, NULL);
progress |= hlsl_transform_ir(ctx, fold_unary_identities, block, NULL);
progress |= hlsl_transform_ir(ctx, fold_conditional_identities, block, NULL);
progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_identities, block, NULL);
progress |= hlsl_transform_ir(ctx, hlsl_fold_constant_swizzles, block, NULL);

View File

@@ -1393,74 +1393,6 @@ bool hlsl_fold_constant_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
return success;
}
static bool constant_is_zero(struct hlsl_ir_constant *const_arg)
{
struct hlsl_type *data_type = const_arg->node.data_type;
unsigned int k;
for (k = 0; k < data_type->e.numeric.dimx; ++k)
{
switch (data_type->e.numeric.type)
{
case HLSL_TYPE_FLOAT:
case HLSL_TYPE_HALF:
if (const_arg->value.u[k].f != 0.0f)
return false;
break;
case HLSL_TYPE_DOUBLE:
if (const_arg->value.u[k].d != 0.0)
return false;
break;
case HLSL_TYPE_UINT:
case HLSL_TYPE_INT:
case HLSL_TYPE_BOOL:
case HLSL_TYPE_MIN16UINT:
if (const_arg->value.u[k].u != 0)
return false;
break;
}
}
return true;
}
static bool constant_is_one(struct hlsl_ir_constant *const_arg)
{
struct hlsl_type *data_type = const_arg->node.data_type;
unsigned int k;
for (k = 0; k < data_type->e.numeric.dimx; ++k)
{
switch (data_type->e.numeric.type)
{
case HLSL_TYPE_FLOAT:
case HLSL_TYPE_HALF:
if (const_arg->value.u[k].f != 1.0f)
return false;
break;
case HLSL_TYPE_DOUBLE:
if (const_arg->value.u[k].d != 1.0)
return false;
break;
case HLSL_TYPE_UINT:
case HLSL_TYPE_INT:
case HLSL_TYPE_MIN16UINT:
if (const_arg->value.u[k].u != 1)
return false;
break;
case HLSL_TYPE_BOOL:
if (const_arg->value.u[k].u != ~0)
return false;
break;
}
}
return true;
}
bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
{
struct hlsl_ir_constant *const_arg = NULL;
@@ -1502,26 +1434,26 @@ bool hlsl_fold_constant_identities(struct hlsl_ctx *ctx, struct hlsl_ir_node *in
switch (expr->op)
{
case HLSL_OP2_ADD:
if (constant_is_zero(const_arg))
if (hlsl_constant_is_zero(const_arg))
res_node = mut_arg;
break;
case HLSL_OP2_MUL:
if (constant_is_one(const_arg))
if (hlsl_constant_is_one(const_arg))
res_node = mut_arg;
break;
case HLSL_OP2_LOGIC_AND:
if (constant_is_zero(const_arg))
if (hlsl_constant_is_zero(const_arg))
res_node = &const_arg->node;
else if (constant_is_one(const_arg))
else if (hlsl_constant_is_one(const_arg))
res_node = mut_arg;
break;
case HLSL_OP2_LOGIC_OR:
if (constant_is_zero(const_arg))
if (hlsl_constant_is_zero(const_arg))
res_node = mut_arg;
else if (constant_is_one(const_arg))
else if (hlsl_constant_is_one(const_arg))
res_node = &const_arg->node;
break;

View File

@@ -160,6 +160,123 @@ float4 main() : sv_target
draw quad
probe (0, 0) rgba (1.0, 6.0, 7.0, 4.0)
[pixel shader]
static float4 cond = {1, 0, 0, 1};
uniform float4 a;
float4 main() : sv_target
{
return cond ? a : a;
}
[test]
uniform 0 float4 1.0 2.0 3.0 4.0
draw quad
probe (0, 0) f32(1.0, 2.0, 3.0, 4.0)
[pixel shader]
uniform float4 a;
float4 main() : sv_target
{
return float4(false ? a.x : a.y, true ? a.x : a.y, 0, 0);
}
[test]
uniform 0 float4 1.0 2.0 0.0 0.0
draw quad
probe (0, 0) f32(2.0, 1.0, 0.0, 0.0)
[pixel shader todo(sm<4)]
uniform float c;
float4 main() : sv_target
{
bool cond = c >= 0;
return float4(cond ? true : false, cond ? false : true, 0, 0);
}
[test]
uniform 0 float -1.0
todo(sm<4) draw quad
probe (0, 0) f32(0.0, 1.0, 0.0, 0.0)
[pixel shader]
uniform bool cond;
uniform float4 a, b;
float4 main() : sv_target
{
return !cond ? a : b;
}
[test]
if(sm<4) uniform 0 float 0.0
if(sm>=4) uniform 0 uint 0
uniform 4 float4 1.0 2.0 3.0 4.0
uniform 8 float4 5.0 6.0 7.0 8.0
draw quad
probe (0, 0) f32(1.0, 2.0, 3.0, 4.0)
[pixel shader]
uniform float4 a;
float4 main() : sv_target
{
bool cond1 = a.x > 0, cond2 = a.x > 2;
float4 ret;
ret.x = cond1 ? (cond1 ? a.x : a.y) : a.z;
ret.y = cond1 ? a.x : (cond1 ? a.y : a.z);
ret.z = cond2 ? (cond2 ? a.x : a.y) : a.z;
ret.w = cond2 ? a.x : (cond2 ? a.y : a.z);
return ret;
}
[test]
uniform 0 float4 1.0 2.0 3.0 0.0
draw quad
probe (0, 0) f32(1.0, 1.0, 3.0, 3.0)
[pixel shader]
uniform float a;
float4 main() : sv_target
{
bool cond1 = a > 0, cond2 = a > 2;
float4 ret;
ret.xy = cond1 ? (cond1 ? float2(0, 1) : float2(1, 0)) : float2(1, 1);
ret.zw = cond2 ? float2(0, 2) : (cond2 ? float2(2, 0) : float2(2, 2));
return ret;
}
[test]
uniform 0 float 1.0
draw quad
probe (0, 0) f32(0.0, 1.0, 2.0, 2.0)
[pixel shader]
uniform float4 a;
float4 main() : sv_target
{
bool cond1 = a.x > 2, cond2 = a.x > 0, cond3 = a.y > 3;
float b = cond1 ? a.x : a.y;
float c = cond2 ? a.z : b;
float d = cond3 ? a.w : c + 10;
float e = cond1 ? a.x : d;
return e;
}
[test]
uniform 0 float4 1.0 2.0 3.0 4.0
draw quad
probe (0, 0) f32(13.0, 13.0, 13.0, 13.0)
[pixel shader fail]