vkd3d-shader/hlsl: Use lower_ir() for more passes.

This commit is contained in:
Zebediah Figura 2023-06-25 18:46:10 -05:00 committed by Alexandre Julliard
parent 976fd67f51
commit 65bf6e997c
Notes: Alexandre Julliard 2023-09-25 22:28:00 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Francisco Casas (@fcasas)
Approved-by: Henri Verbeet (@hverbeet)
Approved-by: Alexandre Julliard (@julliard)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/334

View File

@ -989,7 +989,7 @@ static bool lower_matrix_swizzles(struct hlsl_ctx *ctx, struct hlsl_ir_node *ins
* For the latter case, this pass takes care of lowering hlsl_ir_indexes into individual
* hlsl_ir_loads, or individual hlsl_ir_resource_loads, in case the indexing is a
* resource access. */
static bool lower_index_loads(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_index_loads(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *val, *store;
struct hlsl_deref var_deref;
@ -1023,8 +1023,7 @@ static bool lower_index_loads(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
if (!(load = hlsl_new_resource_load(ctx, &params, &instr->loc)))
return false;
list_add_before(&instr->entry, &load->entry);
hlsl_replace_node(instr, load);
hlsl_block_add_instr(block, load);
return true;
}
@ -1034,7 +1033,7 @@ static bool lower_index_loads(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
if (!(store = hlsl_new_simple_store(ctx, var, val)))
return false;
list_add_before(&instr->entry, &store->entry);
hlsl_block_add_instr(block, store);
if (hlsl_index_is_noncontiguous(index))
{
@ -1054,38 +1053,36 @@ static bool lower_index_loads(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr,
if (!(c = hlsl_new_uint_constant(ctx, i, &instr->loc)))
return false;
list_add_before(&instr->entry, &c->entry);
hlsl_block_add_instr(block, c);
if (!(load = hlsl_new_load_index(ctx, &var_deref, c, &instr->loc)))
return false;
list_add_before(&instr->entry, &load->node.entry);
hlsl_block_add_instr(block, &load->node);
if (!(load = hlsl_new_load_index(ctx, &load->src, index->idx.node, &instr->loc)))
return false;
list_add_before(&instr->entry, &load->node.entry);
hlsl_block_add_instr(block, &load->node);
if (!(store = hlsl_new_store_index(ctx, &row_deref, c, &load->node, 0, &instr->loc)))
return false;
list_add_before(&instr->entry, &store->entry);
hlsl_block_add_instr(block, store);
}
if (!(load = hlsl_new_var_load(ctx, var, &instr->loc)))
return false;
list_add_before(&instr->entry, &load->node.entry);
hlsl_replace_node(instr, &load->node);
hlsl_block_add_instr(block, &load->node);
}
else
{
if (!(load = hlsl_new_load_index(ctx, &var_deref, index->idx.node, &instr->loc)))
return false;
list_add_before(&instr->entry, &load->node.entry);
hlsl_replace_node(instr, &load->node);
hlsl_block_add_instr(block, &load->node);
}
return true;
}
/* Lower casts from vec1 to vecN to swizzles. */
static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
const struct hlsl_type *src_type, *dst_type;
struct hlsl_type *dst_scalar_type;
@ -1101,25 +1098,22 @@ static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, v
if (src_type->class <= HLSL_CLASS_VECTOR && dst_type->class <= HLSL_CLASS_VECTOR && src_type->dimx == 1)
{
struct hlsl_ir_node *replacement, *new_cast, *swizzle;
struct hlsl_ir_node *new_cast, *swizzle;
dst_scalar_type = hlsl_get_scalar_type(ctx, dst_type->base_type);
/* We need to preserve the cast since it might be doing more than just
* turning the scalar into a vector. */
if (!(new_cast = hlsl_new_cast(ctx, cast->operands[0].node, dst_scalar_type, &cast->node.loc)))
return false;
list_add_after(&cast->node.entry, &new_cast->entry);
replacement = new_cast;
hlsl_block_add_instr(block, new_cast);
if (dst_type->dimx != 1)
{
if (!(swizzle = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, X, X, X), dst_type->dimx, replacement, &cast->node.loc)))
if (!(swizzle = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, X, X, X), dst_type->dimx, new_cast, &cast->node.loc)))
return false;
list_add_after(&new_cast->entry, &swizzle->entry);
replacement = swizzle;
hlsl_block_add_instr(block, swizzle);
}
hlsl_replace_node(&cast->node, replacement);
return true;
}
@ -1981,7 +1975,7 @@ static bool split_matrix_copies(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr
return true;
}
static bool lower_narrowing_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_narrowing_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
const struct hlsl_type *src_type, *dst_type;
struct hlsl_type *dst_vector_type;
@ -2004,12 +1998,12 @@ static bool lower_narrowing_casts(struct hlsl_ctx *ctx, struct hlsl_ir_node *ins
* narrowing the vector. */
if (!(new_cast = hlsl_new_cast(ctx, cast->operands[0].node, dst_vector_type, &cast->node.loc)))
return false;
list_add_after(&cast->node.entry, &new_cast->entry);
hlsl_block_add_instr(block, new_cast);
if (!(swizzle = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, Y, Z, W), dst_type->dimx, new_cast, &cast->node.loc)))
return false;
list_add_after(&new_cast->entry, &swizzle->entry);
hlsl_block_add_instr(block, swizzle);
hlsl_replace_node(&cast->node, swizzle);
return true;
}
@ -2068,7 +2062,7 @@ static bool remove_trivial_swizzles(struct hlsl_ctx *ctx, struct hlsl_ir_node *i
return true;
}
static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *idx;
struct hlsl_deref *deref;
@ -2099,11 +2093,11 @@ static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir
if (!(vector_load = hlsl_new_load_parent(ctx, deref, &instr->loc)))
return false;
list_add_before(&instr->entry, &vector_load->node.entry);
hlsl_block_add_instr(block, &vector_load->node);
if (!(swizzle = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, X, X, X), type->dimx, idx, &instr->loc)))
return false;
list_add_before(&instr->entry, &swizzle->entry);
hlsl_block_add_instr(block, swizzle);
value.u[0].u = 0;
value.u[1].u = 1;
@ -2111,18 +2105,18 @@ static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir
value.u[3].u = 3;
if (!(c = hlsl_new_constant(ctx, hlsl_get_vector_type(ctx, HLSL_TYPE_UINT, type->dimx), &value, &instr->loc)))
return false;
list_add_before(&instr->entry, &c->entry);
hlsl_block_add_instr(block, c);
operands[0] = swizzle;
operands[1] = c;
if (!(eq = hlsl_new_expr(ctx, HLSL_OP2_EQUAL, operands,
hlsl_get_vector_type(ctx, HLSL_TYPE_BOOL, type->dimx), &instr->loc)))
return false;
list_add_before(&instr->entry, &eq->entry);
hlsl_block_add_instr(block, eq);
if (!(eq = hlsl_new_cast(ctx, eq, type, &instr->loc)))
return false;
list_add_before(&instr->entry, &eq->entry);
hlsl_block_add_instr(block, eq);
op = HLSL_OP2_DOT;
if (type->dimx == 1)
@ -2134,8 +2128,7 @@ static bool lower_nonconstant_vector_derefs(struct hlsl_ctx *ctx, struct hlsl_ir
operands[1] = eq;
if (!(dot = hlsl_new_expr(ctx, op, operands, instr->data_type, &instr->loc)))
return false;
list_add_before(&instr->entry, &dot->entry);
hlsl_replace_node(instr, dot);
hlsl_block_add_instr(block, dot);
return true;
}
@ -2317,7 +2310,7 @@ static bool lower_sqrt(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *c
}
/* Lower DP2 to MUL + ADD */
static bool lower_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *arg1, *arg2, *mul, *replacement, *zero, *add_x, *add_y;
struct hlsl_ir_expr *expr;
@ -2338,7 +2331,7 @@ static bool lower_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *co
if (!(zero = hlsl_new_float_constant(ctx, 0.0f, &expr->node.loc)))
return false;
list_add_before(&instr->entry, &zero->entry);
hlsl_block_add_instr(block, zero);
operands[0] = arg1;
operands[1] = arg2;
@ -2351,27 +2344,26 @@ static bool lower_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *co
{
if (!(mul = hlsl_new_binary_expr(ctx, HLSL_OP2_MUL, expr->operands[0].node, expr->operands[1].node)))
return false;
list_add_before(&instr->entry, &mul->entry);
hlsl_block_add_instr(block, mul);
if (!(add_x = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(X, X, X, X), instr->data_type->dimx, mul, &expr->node.loc)))
return false;
list_add_before(&instr->entry, &add_x->entry);
hlsl_block_add_instr(block, add_x);
if (!(add_y = hlsl_new_swizzle(ctx, HLSL_SWIZZLE(Y, Y, Y, Y), instr->data_type->dimx, mul, &expr->node.loc)))
return false;
list_add_before(&instr->entry, &add_y->entry);
hlsl_block_add_instr(block, add_y);
if (!(replacement = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, add_x, add_y)))
return false;
}
list_add_before(&instr->entry, &replacement->entry);
hlsl_block_add_instr(block, replacement);
hlsl_replace_node(instr, replacement);
return true;
}
/* Lower ABS to MAX */
static bool lower_abs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_abs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *arg, *neg, *replacement;
struct hlsl_ir_expr *expr;
@ -2385,18 +2377,17 @@ static bool lower_abs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *co
if (!(neg = hlsl_new_unary_expr(ctx, HLSL_OP1_NEG, arg, &instr->loc)))
return false;
list_add_before(&instr->entry, &neg->entry);
hlsl_block_add_instr(block, neg);
if (!(replacement = hlsl_new_binary_expr(ctx, HLSL_OP2_MAX, neg, arg)))
return false;
list_add_before(&instr->entry, &replacement->entry);
hlsl_block_add_instr(block, replacement);
hlsl_replace_node(instr, replacement);
return true;
}
/* Lower ROUND using FRC, ROUND(x) -> ((x + 0.5) - FRC(x + 0.5)). */
static bool lower_round(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_round(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *arg, *neg, *sum, *frc, *half, *replacement;
struct hlsl_type *type = instr->data_type;
@ -2417,31 +2408,29 @@ static bool lower_round(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *
half_value.u[i].f = 0.5f;
if (!(half = hlsl_new_constant(ctx, type, &half_value, &expr->node.loc)))
return false;
list_add_before(&instr->entry, &half->entry);
hlsl_block_add_instr(block, half);
if (!(sum = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, arg, half)))
return false;
list_add_before(&instr->entry, &sum->entry);
hlsl_block_add_instr(block, sum);
if (!(frc = hlsl_new_unary_expr(ctx, HLSL_OP1_FRACT, sum, &instr->loc)))
return false;
list_add_before(&instr->entry, &frc->entry);
hlsl_block_add_instr(block, frc);
if (!(neg = hlsl_new_unary_expr(ctx, HLSL_OP1_NEG, frc, &instr->loc)))
return false;
list_add_before(&instr->entry, &neg->entry);
hlsl_block_add_instr(block, neg);
if (!(replacement = hlsl_new_binary_expr(ctx, HLSL_OP2_ADD, sum, neg)))
return false;
list_add_before(&instr->entry, &replacement->entry);
hlsl_block_add_instr(block, replacement);
hlsl_replace_node(instr, replacement);
return true;
}
/* Use 'movc' for the ternary operator. */
static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *operands[HLSL_MAX_OPERANDS], *replacement;
struct hlsl_ir_node *zero, *cond, *first, *second;
@ -2464,7 +2453,7 @@ static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
{
if (!(zero = hlsl_new_constant(ctx, cond->data_type, &zero_value, &instr->loc)))
return false;
list_add_tail(&instr->entry, &zero->entry);
hlsl_block_add_instr(block, zero);
memset(operands, 0, sizeof(operands));
operands[0] = zero;
@ -2473,7 +2462,7 @@ static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
type = hlsl_get_numeric_type(ctx, type->class, HLSL_TYPE_BOOL, type->dimx, type->dimy);
if (!(cond = hlsl_new_expr(ctx, HLSL_OP2_NEQUAL, operands, type, &instr->loc)))
return false;
list_add_before(&instr->entry, &cond->entry);
hlsl_block_add_instr(block, cond);
}
memset(operands, 0, sizeof(operands));
@ -2482,9 +2471,7 @@ static bool lower_ternary(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
operands[2] = second;
if (!(replacement = hlsl_new_expr(ctx, HLSL_OP3_MOVC, operands, first->data_type, &instr->loc)))
return false;
list_add_before(&instr->entry, &replacement->entry);
hlsl_replace_node(instr, replacement);
hlsl_block_add_instr(block, replacement);
return true;
}
@ -2695,7 +2682,7 @@ static bool lower_int_abs(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
return true;
}
static bool lower_int_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
static bool lower_int_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, struct hlsl_block *block)
{
struct hlsl_ir_node *arg1, *arg2, *mult, *comps[4] = {0}, *res;
struct hlsl_type *type = instr->data_type;
@ -2721,7 +2708,7 @@ static bool lower_int_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
if (!(mult = hlsl_new_binary_expr(ctx, is_bool ? HLSL_OP2_LOGIC_AND : HLSL_OP2_MUL, arg1, arg2)))
return false;
list_add_before(&instr->entry, &mult->entry);
hlsl_block_add_instr(block, mult);
for (i = 0; i < dimx; ++i)
{
@ -2729,7 +2716,7 @@ static bool lower_int_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
if (!(comps[i] = hlsl_new_swizzle(ctx, s, 1, mult, &instr->loc)))
return false;
list_add_before(&instr->entry, &comps[i]->entry);
hlsl_block_add_instr(block, comps[i]);
}
res = comps[0];
@ -2737,10 +2724,9 @@ static bool lower_int_dot(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void
{
if (!(res = hlsl_new_binary_expr(ctx, is_bool ? HLSL_OP2_LOGIC_OR : HLSL_OP2_ADD, res, comps[i])))
return false;
list_add_before(&instr->entry, &res->entry);
hlsl_block_add_instr(block, res);
}
hlsl_replace_node(instr, res);
return true;
}
@ -4328,7 +4314,7 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
while (hlsl_transform_ir(ctx, lower_calls, body, NULL));
lower_ir(ctx, lower_matrix_swizzles, body);
hlsl_transform_ir(ctx, lower_index_loads, body, NULL);
lower_ir(ctx, lower_index_loads, body);
LIST_FOR_EACH_ENTRY(var, &ctx->globals->vars, struct hlsl_ir_var, scope_entry)
{
@ -4391,7 +4377,7 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
{
hlsl_transform_ir(ctx, lower_discard_neg, body, NULL);
}
hlsl_transform_ir(ctx, lower_broadcasts, body, NULL);
lower_ir(ctx, lower_broadcasts, body);
while (hlsl_transform_ir(ctx, fold_redundant_casts, body, NULL));
do
{
@ -4401,9 +4387,9 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
while (progress);
hlsl_transform_ir(ctx, split_matrix_copies, body, NULL);
hlsl_transform_ir(ctx, lower_narrowing_casts, body, NULL);
lower_ir(ctx, lower_narrowing_casts, body);
hlsl_transform_ir(ctx, lower_casts_to_bool, body, NULL);
hlsl_transform_ir(ctx, lower_int_dot, body, NULL);
lower_ir(ctx, lower_int_dot, body);
lower_ir(ctx, lower_int_division, body);
lower_ir(ctx, lower_int_modulus, body);
hlsl_transform_ir(ctx, lower_int_abs, body, NULL);
@ -4419,9 +4405,9 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
}
while (progress);
hlsl_transform_ir(ctx, lower_nonconstant_vector_derefs, body, NULL);
lower_ir(ctx, lower_nonconstant_vector_derefs, body);
hlsl_transform_ir(ctx, lower_casts_to_bool, body, NULL);
hlsl_transform_ir(ctx, lower_int_dot, body, NULL);
lower_ir(ctx, lower_int_dot, body);
hlsl_transform_ir(ctx, validate_static_object_references, body, NULL);
hlsl_transform_ir(ctx, track_object_components_sampler_dim, body, NULL);
@ -4431,18 +4417,18 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
sort_synthetic_separated_samplers_first(ctx);
if (profile->major_version >= 4)
hlsl_transform_ir(ctx, lower_ternary, body, NULL);
lower_ir(ctx, lower_ternary, body);
if (profile->major_version < 4)
{
hlsl_transform_ir(ctx, lower_division, body, NULL);
hlsl_transform_ir(ctx, lower_sqrt, body, NULL);
hlsl_transform_ir(ctx, lower_dot, body, NULL);
hlsl_transform_ir(ctx, lower_round, body, NULL);
lower_ir(ctx, lower_dot, body);
lower_ir(ctx, lower_round, body);
}
if (profile->major_version < 2)
{
hlsl_transform_ir(ctx, lower_abs, body, NULL);
lower_ir(ctx, lower_abs, body);
}
/* TODO: move forward, remove when no longer needed */