From c7b209124bcb4edf5c08d713b891e64e7aaee22c Mon Sep 17 00:00:00 2001 From: "Anna (navi) Figueiredo Gomes" Date: Thu, 16 Jan 2025 00:52:14 +0100 Subject: [PATCH] vkd3d-shader/hlsl: Consider conversions across all parameters in function_compare(). For example, given two arguments, half3 and float, and two functions, func(float, float) and func(float3, float3), fxc/d3dcompiler prefers to widen both arguments to float3. --- libs/vkd3d-shader/hlsl.y | 85 +++++++++++++----------- tests/hlsl/function-overload.shader_test | 24 +++---- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/libs/vkd3d-shader/hlsl.y b/libs/vkd3d-shader/hlsl.y index da2f482b..9f2ebe7d 100644 --- a/libs/vkd3d-shader/hlsl.y +++ b/libs/vkd3d-shader/hlsl.y @@ -3047,80 +3047,87 @@ static unsigned int hlsl_base_type_width(enum hlsl_base_type t) return 0; } -static int function_parameter_compare(const struct hlsl_ir_var *candidate, - const struct hlsl_ir_var *ref, const struct hlsl_ir_node *arg) +static uint32_t get_argument_conversion_mask(const struct hlsl_ir_var *parameter, const struct hlsl_ir_node *arg) { + enum + { + COMPONENT_COUNT_WIDENING = 1u << 0, + COMPONENT_TYPE_NARROWING = 1u << 1, + COMPONENT_TYPE_MISMATCH = 1u << 2, + COMPONENT_TYPE_CLASS_MISMATCH = 1u << 3, + COMPONENT_COUNT_NARROWING = 1u << 4, + }; struct { enum hlsl_base_type type; enum hlsl_base_type class; unsigned int count, width; - } c, r, a; - int ret; + } p, a; + uint32_t mask = 0; /* TODO: Non-numeric types. */ if (!hlsl_is_numeric_type(arg->data_type)) return 0; - c.type = candidate->data_type->e.numeric.type; - c.class = hlsl_base_type_class(c.type); - c.count = hlsl_type_component_count(candidate->data_type); - c.width = hlsl_base_type_width(c.type); - - r.type = ref->data_type->e.numeric.type; - r.class = hlsl_base_type_class(r.type); - r.count = hlsl_type_component_count(ref->data_type); - r.width = hlsl_base_type_width(r.type); + p.type = parameter->data_type->e.numeric.type; + p.class = hlsl_base_type_class(p.type); + p.count = hlsl_type_component_count(parameter->data_type); + p.width = hlsl_base_type_width(p.type); a.type = arg->data_type->e.numeric.type; a.class = hlsl_base_type_class(a.type); a.count = hlsl_type_component_count(arg->data_type); a.width = hlsl_base_type_width(a.type); - /* Prefer candidates without component count narrowing. E.g., given an - * float4 argument, half4 is a better match than float2. */ - if ((ret = (a.count > r.count) - (a.count > c.count))) - return ret; + /* Component count narrowing. E.g., passing a float4 argument to a float2 + * or int2 parameter. */ + if (a.count > p.count) + mask |= COMPONENT_COUNT_NARROWING; + /* Different component type classes. E.g., passing an int argument to a + * float parameter. */ + if (a.class != p.class) + mask |= COMPONENT_TYPE_CLASS_MISMATCH; + /* Different component types. E.g., passing an int argument to an uint + * parameter. */ + if (a.type != p.type) + mask |= COMPONENT_TYPE_MISMATCH; + /* Component type narrowing. E.g., passing a float argument to a half + * parameter. */ + if (a.width > p.width) + mask |= COMPONENT_TYPE_NARROWING; + /* Component count widening. E.g., passing an int2 argument to an int4 + * parameter. */ + if (a.count < p.count) + mask |= COMPONENT_COUNT_WIDENING; - /* Prefer candidates with matching component type classes. E.g., given a - * float argument, double is a better match than int. */ - if ((ret = (a.class == c.class) - (a.class == r.class))) - return ret; - - /* Prefer candidates with matching component types. E.g., given an int - * argument, int4 is a better match than uint4. */ - if ((ret = (a.type == c.type) - (a.type == r.type))) - return ret; - - /* Prefer candidates without component type narrowing. E.g., given a float - * argument, double is a better match than half. */ - if ((ret = (a.width > r.width) - (a.width > c.width))) - return ret; - - /* Prefer candidates without component count widening. E.g. given a float - * argument, float is a better match than float2. */ - return (a.count < r.count) - (a.count < c.count); + return mask; } static int function_compare(const struct hlsl_ir_function_decl *candidate, const struct hlsl_ir_function_decl *ref, const struct parse_initializer *args) { + uint32_t candidate_mask = 0, ref_mask = 0, c, r; bool any_worse = false, any_better = false; unsigned int i; int ret; for (i = 0; i < args->args_count; ++i) { - ret = function_parameter_compare(candidate->parameters.vars[i], ref->parameters.vars[i], args->args[i]); - if (ret < 0) + candidate_mask |= (c = get_argument_conversion_mask(candidate->parameters.vars[i], args->args[i])); + ref_mask |= (r = get_argument_conversion_mask(ref->parameters.vars[i], args->args[i])); + + if (c > r) any_worse = true; - else if (ret > 0) + else if (c < r) any_better = true; } /* We consider a candidate better if at least one parameter is a better * match, and none are a worse match. */ - return any_better - any_worse; + if ((ret = any_better - any_worse)) + return ret; + /* Otherwise, consider the kind of conversions across all parameters. */ + return vkd3d_u32_compare(ref_mask, candidate_mask); } static struct hlsl_ir_function_decl *find_function_call(struct hlsl_ctx *ctx, diff --git a/tests/hlsl/function-overload.shader_test b/tests/hlsl/function-overload.shader_test index 33c36c7f..512f8cc2 100644 --- a/tests/hlsl/function-overload.shader_test +++ b/tests/hlsl/function-overload.shader_test @@ -445,7 +445,7 @@ probe (0, 0) rgba(2.0, 2.0, 2.0, 2.0) % argument, but with more components) for the first function and "double" for the % second function. -[pixel shader todo(sm<6)] +[pixel shader] float func(float2 x, double y) { return 1.0; @@ -465,10 +465,10 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(2.0) -[pixel shader todo(sm<6)] +[pixel shader] float func(uint4 x, double y) { return 1.0; @@ -488,10 +488,10 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(1.0) -[pixel shader todo(sm<6)] +[pixel shader] float func(uint4 x, double y) { return 1.0; @@ -511,10 +511,10 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(1.0) -[pixel shader todo(sm<6)] +[pixel shader] float func(half4 x, half y, half z, half a) { return 1.0; @@ -534,7 +534,7 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(1.0) [require] @@ -683,7 +683,7 @@ float4 main() : sv_target draw quad probe (0, 0) rgba(42.0, 42.0, 42.0, 42.0) -[pixel shader todo(sm<6)] +[pixel shader] float func(float x, float y) { return 1.0; @@ -702,10 +702,10 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(2.0) -[pixel shader todo(sm<6)] +[pixel shader] float func(float3 a, float3 b, float3 c, float3 d) { return 1.0; @@ -724,5 +724,5 @@ float4 main() : sv_target } [test] -todo(sm<6) draw quad +draw quad probe (0, 0) r(1.0)