diff --git a/libs/vkd3d-shader/hlsl.y b/libs/vkd3d-shader/hlsl.y index 3c20ebd2..eabf072b 100644 --- a/libs/vkd3d-shader/hlsl.y +++ b/libs/vkd3d-shader/hlsl.y @@ -4370,13 +4370,9 @@ static bool intrinsic_reflect(struct hlsl_ctx *ctx, static bool intrinsic_refract(struct hlsl_ctx *ctx, const struct parse_initializer *params, const struct vkd3d_shader_location *loc) { - struct hlsl_type *r_type = params->args[0]->data_type; - struct hlsl_type *n_type = params->args[1]->data_type; - struct hlsl_type *i_type = params->args[2]->data_type; - struct hlsl_type *res_type, *idx_type, *scal_type; - struct parse_initializer mut_params; + struct hlsl_type *type, *scalar_type; struct hlsl_ir_function_decl *func; - enum hlsl_base_type base; + struct hlsl_ir_node *index; char *body; static const char template[] = @@ -4388,28 +4384,34 @@ static bool intrinsic_refract(struct hlsl_ctx *ctx, " return t >= 0.0 ? i.x * r - (i.x * d + sqrt(t)) * n : 0;\n" "}"; - if (r_type->class == HLSL_CLASS_MATRIX - || n_type->class == HLSL_CLASS_MATRIX - || i_type->class == HLSL_CLASS_MATRIX) + if (params->args[0]->data_type->class == HLSL_CLASS_MATRIX + || params->args[1]->data_type->class == HLSL_CLASS_MATRIX + || params->args[2]->data_type->class == HLSL_CLASS_MATRIX) { hlsl_error(ctx, loc, VKD3D_SHADER_ERROR_HLSL_INVALID_TYPE, "Matrix arguments are not supported."); return false; } - VKD3D_ASSERT(params->args_count == 3); - mut_params = *params; - mut_params.args_count = 2; - if (!(res_type = elementwise_intrinsic_get_common_type(ctx, &mut_params, loc))) + /* This is technically not an elementwise intrinsic, but the first two + * arguments are. + * The third argument is a scalar, but can be passed as a vector, + * which should generate an implicit truncation warning. + * Cast down to scalar explicitly, then we can just use + * elementwise_intrinsic_float_convert_args(). + * This may result in casting the scalar back to a vector, + * which we will only use the first component of. */ + + scalar_type = hlsl_get_scalar_type(ctx, params->args[2]->data_type->e.numeric.type); + if (!(index = add_implicit_conversion(ctx, params->instrs, params->args[2], scalar_type, loc))) return false; + params->args[2] = index; - base = expr_common_base_type(res_type->e.numeric.type, i_type->e.numeric.type); - base = base == HLSL_TYPE_HALF ? HLSL_TYPE_HALF : HLSL_TYPE_FLOAT; - res_type = convert_numeric_type(ctx, res_type, base); - idx_type = convert_numeric_type(ctx, i_type, base); - scal_type = hlsl_get_scalar_type(ctx, base); + if (!elementwise_intrinsic_float_convert_args(ctx, params, loc)) + return false; + type = params->args[0]->data_type; - if (!(body = hlsl_sprintf_alloc(ctx, template, res_type->name, res_type->name, - res_type->name, idx_type->name, scal_type->name))) + if (!(body = hlsl_sprintf_alloc(ctx, template, type->name, type->name, + type->name, type->name, scalar_type->name))) return false; func = hlsl_compile_internal_function(ctx, "refract", body);