diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index abcead82..3df2f74b 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -3335,8 +3335,10 @@ static bool lower_casts_to_int(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, /* Lower DIV to RCP + MUL. */ static bool lower_division(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block) { + struct hlsl_ir_node *rcp, *ret, *operands[2]; + struct hlsl_type *float_type; struct hlsl_ir_expr *expr; - struct hlsl_ir_node *rcp; + bool is_float; if (instr->type != HLSL_IR_EXPR) return false; @@ -3344,8 +3346,22 @@ static bool lower_division(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, str if (expr->op != HLSL_OP2_DIV) return false; - rcp = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_RCP, expr->operands[1].node, &instr->loc); - hlsl_block_add_binary_expr(ctx, block, HLSL_OP2_MUL, expr->operands[0].node, rcp); + is_float = instr->data_type->e.numeric.type == HLSL_TYPE_FLOAT + || instr->data_type->e.numeric.type == HLSL_TYPE_HALF; + float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->e.numeric.dimx); + + for (unsigned int i = 0; i < 2; ++i) + { + operands[i] = expr->operands[i].node; + if (!is_float) + operands[i] = hlsl_block_add_cast(ctx, block, operands[i], float_type, &instr->loc); + } + + rcp = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_RCP, operands[1], &instr->loc); + ret = hlsl_block_add_binary_expr(ctx, block, HLSL_OP2_MUL, operands[0], rcp); + if (!is_float) + ret = hlsl_block_add_cast(ctx, block, ret, instr->data_type, &instr->loc); + return true; } @@ -4175,48 +4191,6 @@ static bool lower_float_modulus(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr return true; } -static bool lower_nonfloat_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block) -{ - struct hlsl_ir_expr *expr; - - if (instr->type != HLSL_IR_EXPR) - return false; - expr = hlsl_ir_expr(instr); - if (expr->op == HLSL_OP1_CAST || instr->data_type->e.numeric.type == HLSL_TYPE_FLOAT) - return false; - - switch (expr->op) - { - case HLSL_OP2_DIV: - { - struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS] = {0}; - struct hlsl_ir_node *arg, *float_expr; - struct hlsl_type *float_type; - unsigned int i; - - for (i = 0; i < HLSL_MAX_OPERANDS; ++i) - { - arg = expr->operands[i].node; - if (!arg) - continue; - - float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, arg->data_type->e.numeric.dimx); - operands[i] = hlsl_block_add_cast(ctx, block, arg, float_type, &instr->loc); - } - - float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->e.numeric.dimx); - if (!(float_expr = hlsl_new_expr(ctx, expr->op, operands, float_type, &instr->loc))) - return false; - hlsl_block_add_instr(block, float_expr); - - hlsl_block_add_cast(ctx, block, float_expr, instr->data_type, &instr->loc); - return true; - } - default: - return false; - } -} - static bool lower_discard_neg(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) { struct hlsl_ir_node *zero, *bool_false, *or, *cmp, *load; @@ -12273,14 +12247,12 @@ static void process_entry_function(struct hlsl_ctx *ctx, while (lower_ir(ctx, lower_nonconstant_array_loads, body)); lower_ir(ctx, lower_ternary, body); - - lower_ir(ctx, lower_nonfloat_exprs, body); + lower_ir(ctx, lower_division, body); /* Constants casted to float must be folded, and new casts to bool also need to be lowered. */ hlsl_transform_ir(ctx, hlsl_fold_constant_exprs, body, NULL); lower_ir(ctx, lower_casts_to_bool, body); lower_ir(ctx, lower_casts_to_int, body); - lower_ir(ctx, lower_division, body); lower_ir(ctx, lower_sqrt, body); lower_ir(ctx, lower_dot, body); lower_ir(ctx, lower_round, body);