vkd3d-shader/hlsl: Use elementwise_intrinsic_float_convert_args() in refract().

This commit is contained in:
Elizabeth Figura 2024-09-04 22:09:04 -05:00 committed by Henri Verbeet
parent 2688a24bde
commit 82773eb805
Notes: Henri Verbeet 2024-09-11 15:34:20 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/1042

View File

@ -4370,13 +4370,9 @@ static bool intrinsic_reflect(struct hlsl_ctx *ctx,
static bool intrinsic_refract(struct hlsl_ctx *ctx, static bool intrinsic_refract(struct hlsl_ctx *ctx,
const struct parse_initializer *params, const struct vkd3d_shader_location *loc) const struct parse_initializer *params, const struct vkd3d_shader_location *loc)
{ {
struct hlsl_type *r_type = params->args[0]->data_type; struct hlsl_type *type, *scalar_type;
struct hlsl_type *n_type = params->args[1]->data_type;
struct hlsl_type *i_type = params->args[2]->data_type;
struct hlsl_type *res_type, *idx_type, *scal_type;
struct parse_initializer mut_params;
struct hlsl_ir_function_decl *func; struct hlsl_ir_function_decl *func;
enum hlsl_base_type base; struct hlsl_ir_node *index;
char *body; char *body;
static const char template[] = static const char template[] =
@ -4388,28 +4384,34 @@ static bool intrinsic_refract(struct hlsl_ctx *ctx,
" return t >= 0.0 ? i.x * r - (i.x * d + sqrt(t)) * n : 0;\n" " return t >= 0.0 ? i.x * r - (i.x * d + sqrt(t)) * n : 0;\n"
"}"; "}";
if (r_type->class == HLSL_CLASS_MATRIX if (params->args[0]->data_type->class == HLSL_CLASS_MATRIX
|| n_type->class == HLSL_CLASS_MATRIX || params->args[1]->data_type->class == HLSL_CLASS_MATRIX
|| i_type->class == HLSL_CLASS_MATRIX) || params->args[2]->data_type->class == HLSL_CLASS_MATRIX)
{ {
hlsl_error(ctx, loc, VKD3D_SHADER_ERROR_HLSL_INVALID_TYPE, "Matrix arguments are not supported."); hlsl_error(ctx, loc, VKD3D_SHADER_ERROR_HLSL_INVALID_TYPE, "Matrix arguments are not supported.");
return false; return false;
} }
VKD3D_ASSERT(params->args_count == 3); /* This is technically not an elementwise intrinsic, but the first two
mut_params = *params; * arguments are.
mut_params.args_count = 2; * The third argument is a scalar, but can be passed as a vector,
if (!(res_type = elementwise_intrinsic_get_common_type(ctx, &mut_params, loc))) * which should generate an implicit truncation warning.
* Cast down to scalar explicitly, then we can just use
* elementwise_intrinsic_float_convert_args().
* This may result in casting the scalar back to a vector,
* which we will only use the first component of. */
scalar_type = hlsl_get_scalar_type(ctx, params->args[2]->data_type->e.numeric.type);
if (!(index = add_implicit_conversion(ctx, params->instrs, params->args[2], scalar_type, loc)))
return false; return false;
params->args[2] = index;
base = expr_common_base_type(res_type->e.numeric.type, i_type->e.numeric.type); if (!elementwise_intrinsic_float_convert_args(ctx, params, loc))
base = base == HLSL_TYPE_HALF ? HLSL_TYPE_HALF : HLSL_TYPE_FLOAT; return false;
res_type = convert_numeric_type(ctx, res_type, base); type = params->args[0]->data_type;
idx_type = convert_numeric_type(ctx, i_type, base);
scal_type = hlsl_get_scalar_type(ctx, base);
if (!(body = hlsl_sprintf_alloc(ctx, template, res_type->name, res_type->name, if (!(body = hlsl_sprintf_alloc(ctx, template, type->name, type->name,
res_type->name, idx_type->name, scal_type->name))) type->name, type->name, scalar_type->name)))
return false; return false;
func = hlsl_compile_internal_function(ctx, "refract", body); func = hlsl_compile_internal_function(ctx, "refract", body);