vkd3d-shader/hlsl: Lower return statements.

This commit is contained in:
Zebediah Figura 2021-09-13 23:08:34 -05:00 committed by Alexandre Julliard
parent 0cf39f3c63
commit 8bdee6681b
Notes: Alexandre Julliard 2023-02-13 22:20:34 +01:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Francisco Casas (@fcasas)
Approved-by: Henri Verbeet (@hverbeet)
Approved-by: Alexandre Julliard (@julliard)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/85
5 changed files with 236 additions and 7 deletions

View File

@ -186,6 +186,30 @@ static inline void list_move_tail( struct list *dst, struct list *src )
list_move_before( dst, src );
}
/* detach the inclusive range [begin, end] from its list and insert it
 * immediately after the head of dst */
static inline void list_move_slice_head( struct list *dst, struct list *begin, struct list *end )
{
    struct list *old_first = dst->next;
    struct list *before = begin->prev;
    struct list *after = end->next;

    /* close the gap left where the slice used to be */
    before->next = after;
    after->prev = before;

    /* splice the slice in between dst and its previous first element */
    dst->next = begin;
    begin->prev = dst;
    end->next = old_first;
    old_first->prev = end;
}
/* detach the inclusive range [begin, end] from its list and append it
 * just before the tail of dst */
static inline void list_move_slice_tail( struct list *dst, struct list *begin, struct list *end )
{
    struct list *old_last = dst->prev;
    struct list *before = begin->prev;
    struct list *after = end->next;

    /* close the gap left where the slice used to be */
    before->next = after;
    after->prev = before;

    /* splice the slice in between the previous last element and dst */
    old_last->next = begin;
    begin->prev = old_last;
    end->next = dst;
    dst->prev = end;
}
/* iterate through the list; "cursor" must not be unlinked from the list
 * during the walk, since each step dereferences (cursor)->next after the
 * loop body has run */
#define LIST_FOR_EACH(cursor,list) \
    for ((cursor) = (list)->next; (cursor) != (list); (cursor) = (cursor)->next)

View File

@ -1209,6 +1209,8 @@ struct hlsl_ir_function_decl *hlsl_new_func_decl(struct hlsl_ctx *ctx,
const struct hlsl_semantic *semantic, const struct vkd3d_shader_location *loc)
{
struct hlsl_ir_function_decl *decl;
struct hlsl_ir_constant *constant;
struct hlsl_ir_store *store;
if (!(decl = hlsl_alloc(ctx, sizeof(*decl))))
return NULL;
@ -1227,6 +1229,18 @@ struct hlsl_ir_function_decl *hlsl_new_func_decl(struct hlsl_ctx *ctx,
decl->return_var->semantic = *semantic;
}
if (!(decl->early_return_var = hlsl_new_synthetic_var(ctx, "early_return",
hlsl_get_scalar_type(ctx, HLSL_TYPE_BOOL), loc)))
return decl;
if (!(constant = hlsl_new_bool_constant(ctx, false, loc)))
return decl;
list_add_tail(&decl->body.instrs, &constant->node.entry);
if (!(store = hlsl_new_simple_store(ctx, decl->early_return_var, &constant->node)))
return decl;
list_add_tail(&decl->body.instrs, &store->node.entry);
return decl;
}

View File

@ -429,6 +429,11 @@ struct hlsl_ir_function_decl
* Not to be confused with the function parameters! */
unsigned int attr_count;
const struct hlsl_attribute *const *attrs;
/* Synthetic boolean variable marking whether a return statement has been
* executed. Needed to deal with return statements in non-uniform control
* flow, since some backends can't handle them. */
struct hlsl_ir_var *early_return_var;
};
struct hlsl_ir_call

View File

@ -499,6 +499,185 @@ static bool find_recursive_calls(struct hlsl_ctx *ctx, struct hlsl_ir_node *inst
return false;
}
/* Emit "if (early_return) break;" immediately after "cf_instr", so that a
 * return recorded inside that CF instruction also exits the loop that
 * contains it. On allocation failure the function bails out silently,
 * leaving the IR partially patched. */
static void insert_early_return_break(struct hlsl_ctx *ctx,
        struct hlsl_ir_function_decl *func, struct hlsl_ir_node *cf_instr)
{
    struct hlsl_ir_load *cond;
    struct hlsl_ir_if *guard;
    struct hlsl_ir_jump *brk;

    /* Load the synthetic "a return has executed" flag. */
    if (!(cond = hlsl_new_var_load(ctx, func->early_return_var, cf_instr->loc)))
        return;
    list_add_after(&cf_instr->entry, &cond->node.entry);

    /* Guard the break on that flag. */
    if (!(guard = hlsl_new_if(ctx, &cond->node, cf_instr->loc)))
        return;
    list_add_after(&cond->node.entry, &guard->node.entry);

    if (!(brk = hlsl_new_jump(ctx, HLSL_IR_JUMP_BREAK, cf_instr->loc)))
        return;
    list_add_tail(&guard->then_instrs.instrs, &brk->node.entry);
}
/* Remove HLSL_IR_JUMP_RETURN calls by altering subsequent control flow.
 *
 * "func" is the function whose early_return_var is used for tracking;
 * "block" is the block currently being lowered (initially func's body);
 * "in_loop" is true when "block" is (transitively) inside a loop body,
 * in which case returns become breaks instead of guards. */
static void lower_return(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *func,
        struct hlsl_block *block, bool in_loop)
{
    struct hlsl_ir_node *return_instr = NULL, *cf_instr = NULL;
    struct hlsl_ir_node *instr, *next;

    /* SM1 has no function calls. SM4 does, but native d3dcompiler inlines
     * everything anyway. We are safest following suit.
     *
     * The basic idea is to keep track of whether the function has executed an
     * early return in a synthesized boolean variable (func->early_return_var)
     * and guard all code after the return on that variable being false. In the
     * case of loops we also replace the return with a break.
     *
     * The following algorithm loops over instructions in a block, recursing
     * into inferior CF blocks, until it hits one of the following two things:
     *
     * - A return statement. In this case, we remove everything after the return
     *   statement in this block. We have to stop and do this in a separate
     *   loop, because instructions must be deleted in reverse order (due to
     *   def-use chains.)
     *
     *   If we're inside of a loop CF block, we can instead just turn the
     *   return into a break, which offers the right semantics — except that it
     *   won't break out of nested loops.
     *
     * - A CF block which might contain a return statement. After calling
     *   lower_return() on the CF block body, we stop, pull out everything after
     *   the CF instruction, shove it into an if block, and then lower that if
     *   block.
     *
     *   (We could return a "did we make progress" boolean like transform_ir()
     *   and run this pass multiple times, but we already know the only block
     *   that still needs to be addressed, so there's not much point.)
     *
     *   If we're inside of a loop CF block, we again do things differently. We
     *   already turned any returns into breaks. If the block we just processed
     *   was conditional, then "break" did our work for us. If it was a loop,
     *   we need to propagate that break to the outer loop.
     */
    LIST_FOR_EACH_ENTRY_SAFE(instr, next, &block->instrs, struct hlsl_ir_node, entry)
    {
        if (instr->type == HLSL_IR_CALL)
        {
            struct hlsl_ir_call *call = hlsl_ir_call(instr);

            /* Recurse into the callee's body; a call site is never itself
             * inside a loop of the callee, hence in_loop = false. */
            lower_return(ctx, call->decl, &call->decl->body, false);
        }
        else if (instr->type == HLSL_IR_IF)
        {
            struct hlsl_ir_if *iff = hlsl_ir_if(instr);

            lower_return(ctx, func, &iff->then_instrs, in_loop);
            lower_return(ctx, func, &iff->else_instrs, in_loop);

            /* If we're in a loop, we don't need to do anything here. We
             * turned the return into a break, and that will already skip
             * anything that comes after this "if" block. */
            if (!in_loop)
            {
                cf_instr = instr;
                break;
            }
        }
        else if (instr->type == HLSL_IR_LOOP)
        {
            lower_return(ctx, func, &hlsl_ir_loop(instr)->body, true);

            if (in_loop)
            {
                /* "instr" is a nested loop. "return" breaks out of all
                 * loops, so break out of this one too now. */
                insert_early_return_break(ctx, func, instr);
            }
            else
            {
                cf_instr = instr;
                break;
            }
        }
        else if (instr->type == HLSL_IR_JUMP)
        {
            struct hlsl_ir_jump *jump = hlsl_ir_jump(instr);
            struct hlsl_ir_constant *constant;
            struct hlsl_ir_store *store;

            if (jump->type == HLSL_IR_JUMP_RETURN)
            {
                /* Record the early return by storing "true" into
                 * func->early_return_var, just before the jump. */
                if (!(constant = hlsl_new_bool_constant(ctx, true, &jump->node.loc)))
                    return;
                list_add_before(&jump->node.entry, &constant->node.entry);

                if (!(store = hlsl_new_simple_store(ctx, func->early_return_var, &constant->node)))
                    return;
                list_add_after(&constant->node.entry, &store->node.entry);

                if (in_loop)
                {
                    jump->type = HLSL_IR_JUMP_BREAK;
                }
                else
                {
                    return_instr = instr;
                    break;
                }
            }
        }
    }

    if (return_instr)
    {
        /* If we're in a loop, we should have used "break" instead. */
        assert(!in_loop);

        /* Iterate in reverse, to avoid use-after-free when unlinking sources from
         * the "uses" list. Everything from the return to the end of the block
         * (inclusive) is dead. */
        LIST_FOR_EACH_ENTRY_SAFE_REV(instr, next, &block->instrs, struct hlsl_ir_node, entry)
        {
            list_remove(&instr->entry);
            hlsl_free_instr(instr);
            /* Yes, we just freed it, but we're comparing pointers. */
            if (instr == return_instr)
                break;
        }
    }
    else if (cf_instr)
    {
        struct list *tail = list_tail(&block->instrs);
        struct hlsl_ir_load *load;
        struct hlsl_ir_node *not;
        struct hlsl_ir_if *iff;

        /* If we're in a loop, we should have used "break" instead. */
        assert(!in_loop);

        /* Nothing follows the CF instruction; there is nothing to guard. */
        if (tail == &cf_instr->entry)
            return;

        /* Build "if (!early_return) { ...everything after cf_instr... }". */
        if (!(load = hlsl_new_var_load(ctx, func->early_return_var, cf_instr->loc)))
            return;
        list_add_tail(&block->instrs, &load->node.entry);

        if (!(not = hlsl_new_unary_expr(ctx, HLSL_OP1_LOGIC_NOT, &load->node, cf_instr->loc)))
            return;
        list_add_tail(&block->instrs, &not->entry);

        if (!(iff = hlsl_new_if(ctx, not, cf_instr->loc)))
            return;
        list_add_tail(&block->instrs, &iff->node.entry);

        /* Move the instructions between cf_instr and the (old) end of the
         * block into the new "if" body, then lower that body in turn, since
         * it may itself contain returns or CF blocks with returns. */
        list_move_slice_tail(&iff->then_instrs.instrs, list_next(&block->instrs, &cf_instr->entry), tail);
        lower_return(ctx, func, &iff->then_instrs, in_loop);
    }
}
/* Lower casts from vec1 to vecN to swizzles. */
static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
{
@ -2951,6 +3130,13 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry
transform_ir(ctx, find_recursive_calls, body, &recursive_call_ctx);
vkd3d_free(recursive_call_ctx.backtrace);
/* Avoid going into an infinite loop when processing call instructions.
* lower_return() recurses into inferior calls. */
if (ctx->result)
return ctx->result;
lower_return(ctx, entry_func, body, false);
LIST_FOR_EACH_ENTRY(var, &ctx->globals->vars, struct hlsl_ir_var, scope_entry)
{
if (var->storage_modifiers & HLSL_STORAGE_UNIFORM)

View File

@ -10,7 +10,7 @@ float4 main() : sv_target
[test]
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4)
probe all rgba (0.1, 0.2, 0.3, 0.4)
[pixel shader]
@ -23,7 +23,7 @@ void main(out float4 ret : sv_target)
[test]
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4)
probe all rgba (0.1, 0.2, 0.3, 0.4)
[pixel shader]
@ -39,7 +39,7 @@ float4 main() : sv_target
[test]
uniform 0 float 0.2
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4)
probe all rgba (0.1, 0.2, 0.3, 0.4)
uniform 0 float 0.8
draw quad
probe all rgba (0.5, 0.6, 0.7, 0.8)
@ -69,7 +69,7 @@ draw quad
probe all rgba (0.3, 0.4, 0.5, 0.6)
uniform 0 float 0.8
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4)
probe all rgba (0.1, 0.2, 0.3, 0.4)
[pixel shader]
@ -93,10 +93,10 @@ void main(out float4 ret : sv_target)
[test]
uniform 0 float 0.1
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4) 1
probe all rgba (0.1, 0.2, 0.3, 0.4) 1
uniform 0 float 0.5
draw quad
todo probe all rgba (0.2, 0.3, 0.4, 0.5) 1
probe all rgba (0.2, 0.3, 0.4, 0.5) 1
uniform 0 float 0.9
draw quad
probe all rgba (0.5, 0.6, 0.7, 0.8) 1
@ -120,7 +120,7 @@ void main(out float4 ret : sv_target)
[test]
uniform 0 float 0.1
draw quad
todo probe all rgba (0.1, 0.2, 0.3, 0.4) 1
probe all rgba (0.1, 0.2, 0.3, 0.4) 1
uniform 0 float 0.5
draw quad
probe all rgba (0.5, 0.6, 0.7, 0.8) 1