diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index 9eb65dc0..09c21f16 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -3020,6 +3020,127 @@ static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, stru return true; } +static bool lower_comparison_operators(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, + struct hlsl_block *block) +{ + struct hlsl_ir_node *arg1, *arg1_cast, *arg2, *arg2_cast, *slt, *res, *ret; + struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS]; + struct hlsl_type *float_type; + struct hlsl_ir_expr *expr; + bool negate = false; + + if (instr->type != HLSL_IR_EXPR) + return false; + expr = hlsl_ir_expr(instr); + if (expr->op != HLSL_OP2_EQUAL && expr->op != HLSL_OP2_NEQUAL && expr->op != HLSL_OP2_LESS + && expr->op != HLSL_OP2_GEQUAL) + return false; + + arg1 = expr->operands[0].node; + arg2 = expr->operands[1].node; + float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->dimx); + + if (!(arg1_cast = hlsl_new_cast(ctx, arg1, float_type, &instr->loc))) + return false; + hlsl_block_add_instr(block, arg1_cast); + + if (!(arg2_cast = hlsl_new_cast(ctx, arg2, float_type, &instr->loc))) + return false; + hlsl_block_add_instr(block, arg2_cast); + + switch (expr->op) + { + case HLSL_OP2_EQUAL: + case HLSL_OP2_NEQUAL: + { + struct hlsl_ir_node *neg, *sub, *abs, *abs_neg; + + if (!(neg = hlsl_new_unary_expr(ctx, HLSL_OP1_NEG, arg2_cast, &instr->loc))) + return false; + hlsl_block_add_instr(block, neg); + + if (!(sub = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, arg1_cast, neg))) + return false; + hlsl_block_add_instr(block, sub); + + if (ctx->profile->major_version >= 3) + { + if (!(abs = hlsl_new_unary_expr(ctx, HLSL_OP1_ABS, sub, &instr->loc))) + return false; + hlsl_block_add_instr(block, abs); + } + else + { + /* Use MUL as a precarious ABS. */ + if (!(abs = hlsl_new_binary_expr(ctx, HLSL_OP2_MUL, sub, sub))) + return false; + hlsl_block_add_instr(block, abs); + } + + if (!(abs_neg = hlsl_new_unary_expr(ctx, HLSL_OP1_NEG, abs, &instr->loc))) + return false; + hlsl_block_add_instr(block, abs_neg); + + if (!(slt = hlsl_new_binary_expr(ctx, HLSL_OP2_SLT, abs_neg, abs))) + return false; + hlsl_block_add_instr(block, slt); + + negate = (expr->op == HLSL_OP2_EQUAL); + break; + } + + case HLSL_OP2_GEQUAL: + case HLSL_OP2_LESS: + { + if (!(slt = hlsl_new_binary_expr(ctx, HLSL_OP2_SLT, arg1_cast, arg2_cast))) + return false; + hlsl_block_add_instr(block, slt); + + negate = (expr->op == HLSL_OP2_GEQUAL); + break; + } + + default: + vkd3d_unreachable(); + } + + if (negate) + { + struct hlsl_constant_value one_value; + struct hlsl_ir_node *one, *slt_neg; + + one_value.u[0].f = 1.0; + one_value.u[1].f = 1.0; + one_value.u[2].f = 1.0; + one_value.u[3].f = 1.0; + if (!(one = hlsl_new_constant(ctx, float_type, &one_value, &instr->loc))) + return false; + hlsl_block_add_instr(block, one); + + if (!(slt_neg = hlsl_new_unary_expr(ctx, HLSL_OP1_NEG, slt, &instr->loc))) + return false; + hlsl_block_add_instr(block, slt_neg); + + if (!(res = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, one, slt_neg))) + return false; + hlsl_block_add_instr(block, res); + } + else + { + res = slt; + } + + /* We need a REINTERPRET so that the HLSL IR code is valid. SLT and its arguments must be FLOAT, + * and casts to BOOL have already been lowered to "!= 0". */ + memset(operands, 0, sizeof(operands)); + operands[0] = res; + if (!(ret = hlsl_new_expr(ctx, HLSL_OP1_REINTERPRET, operands, instr->data_type, &instr->loc))) + return false; + hlsl_block_add_instr(block, ret); + + return true; +} + static bool lower_casts_to_bool(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block) { struct hlsl_type *type = instr->data_type, *arg_type; @@ -5209,6 +5330,7 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry lower_ir(ctx, lower_round, body); lower_ir(ctx, lower_ceil, body); lower_ir(ctx, lower_floor, body); + lower_ir(ctx, lower_comparison_operators, body); } if (profile->major_version < 2) diff --git a/tests/hlsl/vertex-shader-ops.shader_test b/tests/hlsl/vertex-shader-ops.shader_test index 34d49672..38f3db65 100644 --- a/tests/hlsl/vertex-shader-ops.shader_test +++ b/tests/hlsl/vertex-shader-ops.shader_test @@ -15,7 +15,7 @@ float4 main(in float4 res : COLOR1) : sv_target % Check that -0.0f is not less than 0.0f -[vertex shader todo(sm<4)] +[vertex shader] float a; void main(out float4 res : COLOR1, in float4 pos : position, out float4 out_pos : sv_position) @@ -28,11 +28,11 @@ void main(out float4 res : COLOR1, in float4 pos : position, out float4 out_pos [test] if(sm<4) uniform 0 float 0.0 if(sm>=4) uniform 0 float4 0.0 0.0 0.0 0.0 -todo(sm<4) draw quad +draw quad probe all rgba (0.0, 0.0, 0.0, 0.0) -[vertex shader todo(sm<4)] +[vertex shader] int a, b; void main(out float4 res : COLOR1, in float4 pos : position, out float4 out_pos : sv_position) @@ -49,12 +49,12 @@ void main(out float4 res : COLOR1, in float4 pos : position, out float4 out_pos if(sm<4) uniform 0 float 3 if(sm<4) uniform 4 float 4 if(sm>=4) uniform 0 int4 3 4 0 0 -todo(sm<4) draw quad +draw quad probe all rgba (0.0, 1.0, 0.0, 1.0) if(sm<4) uniform 0 float -2 if(sm<4) uniform 4 float -2 if(sm>=4) uniform 0 int4 -2 -2 0 0 -todo(sm<4) draw quad +draw quad probe all rgba (1.0, 0.0, 0.0, 1.0)