vkd3d-shader/hlsl: Consider conversions across all parameters in function_compare().

For example, given two arguments, half3 and float, and two functions,
func(float, float) and func(float3, float3), fxc/d3dcompiler prefers to
widen both arguments to float3.
This commit is contained in:
Anna (navi) Figueiredo Gomes 2025-01-16 00:52:14 +01:00 committed by Henri Verbeet
parent de57afd672
commit c7b209124b
Notes: Henri Verbeet 2025-01-22 15:03:48 +01:00
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1341
2 changed files with 58 additions and 51 deletions

View File

@ -3047,80 +3047,87 @@ static unsigned int hlsl_base_type_width(enum hlsl_base_type t)
return 0; return 0;
} }
static int function_parameter_compare(const struct hlsl_ir_var *candidate, static uint32_t get_argument_conversion_mask(const struct hlsl_ir_var *parameter, const struct hlsl_ir_node *arg)
const struct hlsl_ir_var *ref, const struct hlsl_ir_node *arg)
{ {
enum
{
COMPONENT_COUNT_WIDENING = 1u << 0,
COMPONENT_TYPE_NARROWING = 1u << 1,
COMPONENT_TYPE_MISMATCH = 1u << 2,
COMPONENT_TYPE_CLASS_MISMATCH = 1u << 3,
COMPONENT_COUNT_NARROWING = 1u << 4,
};
struct struct
{ {
enum hlsl_base_type type; enum hlsl_base_type type;
enum hlsl_base_type class; enum hlsl_base_type class;
unsigned int count, width; unsigned int count, width;
} c, r, a; } p, a;
int ret; uint32_t mask = 0;
/* TODO: Non-numeric types. */ /* TODO: Non-numeric types. */
if (!hlsl_is_numeric_type(arg->data_type)) if (!hlsl_is_numeric_type(arg->data_type))
return 0; return 0;
c.type = candidate->data_type->e.numeric.type; p.type = parameter->data_type->e.numeric.type;
c.class = hlsl_base_type_class(c.type); p.class = hlsl_base_type_class(p.type);
c.count = hlsl_type_component_count(candidate->data_type); p.count = hlsl_type_component_count(parameter->data_type);
c.width = hlsl_base_type_width(c.type); p.width = hlsl_base_type_width(p.type);
r.type = ref->data_type->e.numeric.type;
r.class = hlsl_base_type_class(r.type);
r.count = hlsl_type_component_count(ref->data_type);
r.width = hlsl_base_type_width(r.type);
a.type = arg->data_type->e.numeric.type; a.type = arg->data_type->e.numeric.type;
a.class = hlsl_base_type_class(a.type); a.class = hlsl_base_type_class(a.type);
a.count = hlsl_type_component_count(arg->data_type); a.count = hlsl_type_component_count(arg->data_type);
a.width = hlsl_base_type_width(a.type); a.width = hlsl_base_type_width(a.type);
/* Prefer candidates without component count narrowing. E.g., given an /* Component count narrowing. E.g., passing a float4 argument to a float2
* float4 argument, half4 is a better match than float2. */ * or int2 parameter. */
if ((ret = (a.count > r.count) - (a.count > c.count))) if (a.count > p.count)
return ret; mask |= COMPONENT_COUNT_NARROWING;
/* Different component type classes. E.g., passing an int argument to a
* float parameter. */
if (a.class != p.class)
mask |= COMPONENT_TYPE_CLASS_MISMATCH;
/* Different component types. E.g., passing an int argument to an uint
* parameter. */
if (a.type != p.type)
mask |= COMPONENT_TYPE_MISMATCH;
/* Component type narrowing. E.g., passing a float argument to a half
* parameter. */
if (a.width > p.width)
mask |= COMPONENT_TYPE_NARROWING;
/* Component count widening. E.g., passing an int2 argument to an int4
* parameter. */
if (a.count < p.count)
mask |= COMPONENT_COUNT_WIDENING;
/* Prefer candidates with matching component type classes. E.g., given a return mask;
* float argument, double is a better match than int. */
if ((ret = (a.class == c.class) - (a.class == r.class)))
return ret;
/* Prefer candidates with matching component types. E.g., given an int
* argument, int4 is a better match than uint4. */
if ((ret = (a.type == c.type) - (a.type == r.type)))
return ret;
/* Prefer candidates without component type narrowing. E.g., given a float
* argument, double is a better match than half. */
if ((ret = (a.width > r.width) - (a.width > c.width)))
return ret;
/* Prefer candidates without component count widening. E.g. given a float
* argument, float is a better match than float2. */
return (a.count < r.count) - (a.count < c.count);
} }
static int function_compare(const struct hlsl_ir_function_decl *candidate, static int function_compare(const struct hlsl_ir_function_decl *candidate,
const struct hlsl_ir_function_decl *ref, const struct parse_initializer *args) const struct hlsl_ir_function_decl *ref, const struct parse_initializer *args)
{ {
uint32_t candidate_mask = 0, ref_mask = 0, c, r;
bool any_worse = false, any_better = false; bool any_worse = false, any_better = false;
unsigned int i; unsigned int i;
int ret; int ret;
for (i = 0; i < args->args_count; ++i) for (i = 0; i < args->args_count; ++i)
{ {
ret = function_parameter_compare(candidate->parameters.vars[i], ref->parameters.vars[i], args->args[i]); candidate_mask |= (c = get_argument_conversion_mask(candidate->parameters.vars[i], args->args[i]));
if (ret < 0) ref_mask |= (r = get_argument_conversion_mask(ref->parameters.vars[i], args->args[i]));
if (c > r)
any_worse = true; any_worse = true;
else if (ret > 0) else if (c < r)
any_better = true; any_better = true;
} }
/* We consider a candidate better if at least one parameter is a better /* We consider a candidate better if at least one parameter is a better
* match, and none are a worse match. */ * match, and none are a worse match. */
return any_better - any_worse; if ((ret = any_better - any_worse))
return ret;
/* Otherwise, consider the kind of conversions across all parameters. */
return vkd3d_u32_compare(ref_mask, candidate_mask);
} }
static struct hlsl_ir_function_decl *find_function_call(struct hlsl_ctx *ctx, static struct hlsl_ir_function_decl *find_function_call(struct hlsl_ctx *ctx,

View File

@ -445,7 +445,7 @@ probe (0, 0) rgba(2.0, 2.0, 2.0, 2.0)
% argument, but with more components) for the first function and "double" for the % argument, but with more components) for the first function and "double" for the
% second function. % second function.
[pixel shader todo(sm<6)] [pixel shader]
float func(float2 x, double y) float func(float2 x, double y)
{ {
return 1.0; return 1.0;
@ -465,10 +465,10 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(2.0) probe (0, 0) r(2.0)
[pixel shader todo(sm<6)] [pixel shader]
float func(uint4 x, double y) float func(uint4 x, double y)
{ {
return 1.0; return 1.0;
@ -488,10 +488,10 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(1.0) probe (0, 0) r(1.0)
[pixel shader todo(sm<6)] [pixel shader]
float func(uint4 x, double y) float func(uint4 x, double y)
{ {
return 1.0; return 1.0;
@ -511,10 +511,10 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(1.0) probe (0, 0) r(1.0)
[pixel shader todo(sm<6)] [pixel shader]
float func(half4 x, half y, half z, half a) float func(half4 x, half y, half z, half a)
{ {
return 1.0; return 1.0;
@ -534,7 +534,7 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(1.0) probe (0, 0) r(1.0)
[require] [require]
@ -683,7 +683,7 @@ float4 main() : sv_target
draw quad draw quad
probe (0, 0) rgba(42.0, 42.0, 42.0, 42.0) probe (0, 0) rgba(42.0, 42.0, 42.0, 42.0)
[pixel shader todo(sm<6)] [pixel shader]
float func(float x, float y) float func(float x, float y)
{ {
return 1.0; return 1.0;
@ -702,10 +702,10 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(2.0) probe (0, 0) r(2.0)
[pixel shader todo(sm<6)] [pixel shader]
float func(float3 a, float3 b, float3 c, float3 d) float func(float3 a, float3 b, float3 c, float3 d)
{ {
return 1.0; return 1.0;
@ -724,5 +724,5 @@ float4 main() : sv_target
} }
[test] [test]
todo(sm<6) draw quad draw quad
probe (0, 0) r(1.0) probe (0, 0) r(1.0)