vkd3d-shader/hlsl: Add a hlsl_block_add_cast() helper.

This commit is contained in:
Elizabeth Figura
2024-12-08 22:12:41 -06:00
committed by Henri Verbeet
parent 2e09257d94
commit 858b6a3e0b
Notes: Henri Verbeet 2025-02-24 16:27:47 +01:00
Approved-by: Francisco Casas (@fcasas)
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1386
4 changed files with 34 additions and 99 deletions

View File

@@ -1696,6 +1696,12 @@ struct hlsl_ir_node *hlsl_new_cast(struct hlsl_ctx *ctx, struct hlsl_ir_node *no
return cast;
}
struct hlsl_ir_node *hlsl_block_add_cast(struct hlsl_ctx *ctx, struct hlsl_block *block,
struct hlsl_ir_node *arg, struct hlsl_type *type, const struct vkd3d_shader_location *loc)
{
return append_new_instr(ctx, block, hlsl_new_cast(ctx, arg, type, loc));
}
static struct hlsl_ir_node *hlsl_new_error_expr(struct hlsl_ctx *ctx)
{
static const struct vkd3d_shader_location loc = {.source_name = "<error>"};

View File

@@ -1505,6 +1505,8 @@ struct hlsl_ir_node *hlsl_add_conditional(struct hlsl_ctx *ctx, struct hlsl_bloc
void hlsl_add_function(struct hlsl_ctx *ctx, char *name, struct hlsl_ir_function_decl *decl);
void hlsl_add_var(struct hlsl_ctx *ctx, struct hlsl_ir_var *decl);
struct hlsl_ir_node *hlsl_block_add_cast(struct hlsl_ctx *ctx, struct hlsl_block *block,
struct hlsl_ir_node *arg, struct hlsl_type *type, const struct vkd3d_shader_location *loc);
struct hlsl_ir_node *hlsl_block_add_float_constant(struct hlsl_ctx *ctx, struct hlsl_block *block,
float f, const struct vkd3d_shader_location *loc);
struct hlsl_ir_node *hlsl_block_add_int_constant(struct hlsl_ctx *ctx, struct hlsl_block *block,

View File

@@ -351,7 +351,6 @@ static struct hlsl_ir_node *add_cast(struct hlsl_ctx *ctx, struct hlsl_block *bl
struct hlsl_ir_node *node, struct hlsl_type *dst_type, const struct vkd3d_shader_location *loc)
{
struct hlsl_type *src_type = node->data_type;
struct hlsl_ir_node *cast;
if (hlsl_types_are_equal(src_type, dst_type))
return node;
@@ -359,11 +358,7 @@ static struct hlsl_ir_node *add_cast(struct hlsl_ctx *ctx, struct hlsl_block *bl
if (src_type->class == HLSL_CLASS_NULL)
return node;
if (!(cast = hlsl_new_cast(ctx, node, dst_type, loc)))
return NULL;
hlsl_block_add_instr(block, cast);
return cast;
return hlsl_block_add_cast(ctx, block, node, dst_type, loc);
}
static struct hlsl_ir_node *add_implicit_conversion(struct hlsl_ctx *ctx, struct hlsl_block *block,
@@ -940,7 +935,7 @@ static bool add_array_access(struct hlsl_ctx *ctx, struct hlsl_block *block, str
struct hlsl_ir_node *index, const struct vkd3d_shader_location *loc)
{
const struct hlsl_type *expr_type = array->data_type, *index_type = index->data_type;
struct hlsl_ir_node *return_index, *cast;
struct hlsl_ir_node *return_index;
if (array->data_type->class == HLSL_CLASS_ERROR || index->data_type->class == HLSL_CLASS_ERROR)
{
@@ -981,10 +976,7 @@ static bool add_array_access(struct hlsl_ctx *ctx, struct hlsl_block *block, str
return false;
}
if (!(cast = hlsl_new_cast(ctx, index, hlsl_get_scalar_type(ctx, HLSL_TYPE_UINT), &index->loc)))
return false;
hlsl_block_add_instr(block, cast);
index = cast;
index = hlsl_block_add_cast(ctx, block, index, hlsl_get_scalar_type(ctx, HLSL_TYPE_UINT), &index->loc);
if (expr_type->class != HLSL_CLASS_ARRAY && expr_type->class != HLSL_CLASS_VECTOR && expr_type->class != HLSL_CLASS_MATRIX)
{

View File

@@ -456,9 +456,7 @@ static void prepend_input_copy(struct hlsl_ctx *ctx, struct hlsl_ir_function_dec
hlsl_block_add_instr(block, &load->node);
}
if (!(cast = hlsl_new_cast(ctx, &load->node, vector_type_dst, &var->loc)))
return;
hlsl_block_add_instr(block, cast);
cast = hlsl_block_add_cast(ctx, block, &load->node, vector_type_dst, &var->loc);
if (type->class == HLSL_CLASS_MATRIX)
{
@@ -1187,9 +1185,7 @@ static bool lower_complex_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr
if (!(component_load = hlsl_add_load_component(ctx, block, arg, src_idx, &arg->loc)))
return false;
if (!(cast = hlsl_new_cast(ctx, component_load, dst_comp_type, &arg->loc)))
return false;
hlsl_block_add_instr(block, cast);
cast = hlsl_block_add_cast(ctx, block, component_load, dst_comp_type, &arg->loc);
if (!hlsl_new_store_component(ctx, &store_block, &var_deref, dst_idx, cast))
return false;
@@ -1368,9 +1364,7 @@ static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, s
dst_scalar_type = hlsl_get_scalar_type(ctx, dst_type->e.numeric.type);
/* We need to preserve the cast since it might be doing more than just
* turning the scalar into a vector. */
if (!(new_cast = hlsl_new_cast(ctx, cast->operands[0].node, dst_scalar_type, &cast->node.loc)))
return false;
hlsl_block_add_instr(block, new_cast);
new_cast = hlsl_block_add_cast(ctx, block, cast->operands[0].node, dst_scalar_type, &cast->node.loc);
if (dst_type->e.numeric.dimx != 1)
{
@@ -2582,9 +2576,7 @@ static bool lower_narrowing_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *ins
dst_vector_type = hlsl_get_vector_type(ctx, dst_type->e.numeric.type, src_type->e.numeric.dimx);
/* We need to preserve the cast since it might be doing more than just
* narrowing the vector. */
if (!(new_cast = hlsl_new_cast(ctx, cast->operands[0].node, dst_vector_type, &cast->node.loc)))
return false;
hlsl_block_add_instr(block, new_cast);
new_cast = hlsl_block_add_cast(ctx, block, cast->operands[0].node, dst_vector_type, &cast->node.loc);
if (!(swizzle = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, Y, Z, W),
dst_type->e.numeric.dimx, new_cast, &cast->node.loc)))
@@ -2830,9 +2822,7 @@ static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir
return false;
hlsl_block_add_instr(block, eq);
if (!(eq = hlsl_new_cast(ctx, eq, type, &instr->loc)))
return false;
hlsl_block_add_instr(block, eq);
eq = hlsl_block_add_cast(ctx, block, eq, type, &instr->loc);
op = HLSL_OP2_DOT;
if (width == 1)
@@ -3686,9 +3676,7 @@ static bool lower_logic_not(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, st
/* If this is happens, it means we failed to cast the argument to boolean somewhere. */
VKD3D_ASSERT(arg->data_type->e.numeric.type == HLSL_TYPE_BOOL);
if (!(arg_cast = hlsl_new_cast(ctx, arg, float_type, &arg->loc)))
return false;
hlsl_block_add_instr(block, arg_cast);
arg_cast = hlsl_block_add_cast(ctx, block, arg, float_type, &arg->loc);
neg = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_NEG, arg_cast, &instr->loc);
@@ -3742,11 +3730,7 @@ static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, stru
type = hlsl_get_numeric_type(ctx, instr->data_type->class, HLSL_TYPE_FLOAT,
instr->data_type->e.numeric.dimx, instr->data_type->e.numeric.dimy);
if (!(float_cond = hlsl_new_cast(ctx, cond, type, &instr->loc)))
return false;
hlsl_block_add_instr(block, float_cond);
float_cond = hlsl_block_add_cast(ctx, block, cond, type, &instr->loc);
neg = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_NEG, float_cond, &instr->loc);
memset(operands, 0, sizeof(operands));
@@ -3825,13 +3809,8 @@ static bool lower_comparison_operators(struct hlsl_ctx *ctx, struct hlsl_ir_node
arg2 = expr->operands[1].node;
float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->e.numeric.dimx);
if (!(arg1_cast = hlsl_new_cast(ctx, arg1, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, arg1_cast);
if (!(arg2_cast = hlsl_new_cast(ctx, arg2, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, arg2_cast);
arg1_cast = hlsl_block_add_cast(ctx, block, arg1, float_type, &instr->loc);
arg2_cast = hlsl_block_add_cast(ctx, block, arg2, float_type, &instr->loc);
switch (expr->op)
{
@@ -3943,14 +3922,8 @@ static bool lower_slt(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct h
arg2 = expr->operands[1].node;
float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->e.numeric.dimx);
if (!(arg1_cast = hlsl_new_cast(ctx, arg1, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, arg1_cast);
if (!(arg2_cast = hlsl_new_cast(ctx, arg2, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, arg2_cast);
arg1_cast = hlsl_block_add_cast(ctx, block, arg1, float_type, &instr->loc);
arg2_cast = hlsl_block_add_cast(ctx, block, arg2, float_type, &instr->loc);
neg = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_NEG, arg2_cast, &instr->loc);
if (!(sub = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, arg1_cast, neg)))
@@ -4004,10 +3977,7 @@ static bool lower_cmp(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct h
for (i = 0; i < 3; ++i)
{
args[i] = expr->operands[i].node;
if (!(args_cast[i] = hlsl_new_cast(ctx, args[i], float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, args_cast[i]);
args_cast[i] = hlsl_block_add_cast(ctx, block, args[i], float_type, &instr->loc);
}
memset(&zero_value, 0, sizeof(zero_value));
@@ -4095,10 +4065,7 @@ struct hlsl_ir_node *hlsl_add_conditional(struct hlsl_ctx *ctx, struct hlsl_bloc
{
cond_type = hlsl_get_numeric_type(ctx, cond_type->class, HLSL_TYPE_BOOL,
cond_type->e.numeric.dimx, cond_type->e.numeric.dimy);
if (!(condition = hlsl_new_cast(ctx, condition, cond_type, &condition->loc)))
return NULL;
hlsl_block_add_instr(instrs, condition);
condition = hlsl_block_add_cast(ctx, instrs, condition, cond_type, &condition->loc);
}
operands[0] = condition;
@@ -4147,27 +4114,16 @@ static bool lower_int_division(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
hlsl_block_add_instr(block, and);
abs1 = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_ABS, arg1, &instr->loc);
if (!(cast1 = hlsl_new_cast(ctx, abs1, utype, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast1);
cast1 = hlsl_block_add_cast(ctx, block, abs1, utype, &instr->loc);
abs2 = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_ABS, arg2, &instr->loc);
if (!(cast2 = hlsl_new_cast(ctx, abs2, utype, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast2);
cast2 = hlsl_block_add_cast(ctx, block, abs2, utype, &instr->loc);
if (!(div = hlsl_new_binary_expr(ctx, HLSL_OP2_DIV, cast1, cast2)))
return false;
hlsl_block_add_instr(block, div);
if (!(cast3 = hlsl_new_cast(ctx, div, type, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast3);
cast3 = hlsl_block_add_cast(ctx, block, div, type, &instr->loc);
neg = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_NEG, cast3, &instr->loc);
return hlsl_add_conditional(ctx, block, and, neg, cast3);
}
@@ -4203,27 +4159,16 @@ static bool lower_int_modulus(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
hlsl_block_add_instr(block, and);
abs1 = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_ABS, arg1, &instr->loc);
if (!(cast1 = hlsl_new_cast(ctx, abs1, utype, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast1);
cast1 = hlsl_block_add_cast(ctx, block, abs1, utype, &instr->loc);
abs2 = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_ABS, arg2, &instr->loc);
if (!(cast2 = hlsl_new_cast(ctx, abs2, utype, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast2);
cast2 = hlsl_block_add_cast(ctx, block, abs2, utype, &instr->loc);
if (!(div = hlsl_new_binary_expr(ctx, HLSL_OP2_MOD, cast1, cast2)))
return false;
hlsl_block_add_instr(block, div);
if (!(cast3 = hlsl_new_cast(ctx, div, type, &instr->loc)))
return false;
hlsl_block_add_instr(block, cast3);
cast3 = hlsl_block_add_cast(ctx, block, div, type, &instr->loc);
neg = hlsl_block_add_unary_expr(ctx, block, HLSL_OP1_NEG, cast3, &instr->loc);
return hlsl_add_conditional(ctx, block, and, neg, cast3);
}
@@ -4389,7 +4334,7 @@ static bool lower_nonfloat_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst
case HLSL_OP2_MUL:
{
struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS] = {0};
struct hlsl_ir_node *arg, *arg_cast, *float_expr, *ret;
struct hlsl_ir_node *arg, *float_expr;
struct hlsl_type *float_type;
unsigned int i;
@@ -4400,11 +4345,7 @@ static bool lower_nonfloat_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst
continue;
float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, arg->data_type->e.numeric.dimx);
if (!(arg_cast = hlsl_new_cast(ctx, arg, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, arg_cast);
operands[i] = arg_cast;
operands[i] = hlsl_block_add_cast(ctx, block, arg, float_type, &instr->loc);
}
float_type = hlsl_get_vector_type(ctx, HLSL_TYPE_FLOAT, instr->data_type->e.numeric.dimx);
@@ -4412,10 +4353,7 @@ static bool lower_nonfloat_exprs(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst
return false;
hlsl_block_add_instr(block, float_expr);
if (!(ret = hlsl_new_cast(ctx, float_expr, instr->data_type, &instr->loc)))
return false;
hlsl_block_add_instr(block, ret);
hlsl_block_add_cast(ctx, block, float_expr, instr->data_type, &instr->loc);
return true;
}
default:
@@ -4497,10 +4435,7 @@ static bool lower_discard_nz(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, v
hlsl_block_init(&block);
if (!(cond_cast = hlsl_new_cast(ctx, cond, float_type, &instr->loc)))
return false;
hlsl_block_add_instr(&block, cond_cast);
cond_cast = hlsl_block_add_cast(ctx, &block, cond, float_type, &instr->loc);
abs = hlsl_block_add_unary_expr(ctx, &block, HLSL_OP1_ABS, cond_cast, &instr->loc);
neg = hlsl_block_add_unary_expr(ctx, &block, HLSL_OP1_NEG, abs, &instr->loc);