vkd3d-shader/hlsl: Define smoothstep() in HLSL.

This commit is contained in:
Zebediah Figura 2023-08-07 17:20:10 -05:00 committed by Alexandre Julliard
parent d396c4ce27
commit 9ab77658f2
Notes: Alexandre Julliard 2023-08-30 23:19:43 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Approved-by: Alexandre Julliard (@julliard)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/310
3 changed files with 79 additions and 49 deletions

View File

@ -2988,6 +2988,16 @@ void hlsl_add_function(struct hlsl_ctx *ctx, char *name, struct hlsl_ir_function
struct hlsl_ir_function *func;
struct rb_entry *func_entry;
if (ctx->internal_func_name)
{
char *internal_name;
if (!(internal_name = hlsl_strdup(ctx, ctx->internal_func_name)))
return;
vkd3d_free(name);
name = internal_name;
}
func_entry = rb_get(&ctx->functions, name);
if (func_entry)
{
@ -3519,3 +3529,44 @@ int hlsl_compile_shader(const struct vkd3d_shader_code *hlsl, const struct vkd3d
hlsl_ctx_cleanup(&ctx);
return ret;
}
struct hlsl_ir_function_decl *hlsl_compile_internal_function(struct hlsl_ctx *ctx, const char *name, const char *hlsl)
{
const struct hlsl_ir_function_decl *saved_cur_function = ctx->cur_function;
struct vkd3d_shader_code code = {.code = hlsl, .size = strlen(hlsl)};
const char *saved_internal_func_name = ctx->internal_func_name;
struct vkd3d_string_buffer *internal_name;
struct hlsl_ir_function_decl *func;
void *saved_scanner = ctx->scanner;
int ret;
TRACE("name %s, hlsl %s.\n", debugstr_a(name), debugstr_a(hlsl));
/* The actual name of the function is mangled with a unique prefix, both to
* allow defining multiple variants of a function with the same name, and to
* avoid polluting the user name space. */
if (!(internal_name = hlsl_get_string_buffer(ctx)))
return NULL;
vkd3d_string_buffer_printf(internal_name, "<%s-%u>", name, ctx->internal_name_counter++);
/* Save and restore everything that matters.
* Note that saving the scope stack is hard, and shouldn't be necessary. */
ctx->scanner = NULL;
ctx->internal_func_name = internal_name->buffer;
ctx->cur_function = NULL;
ret = hlsl_lexer_compile(ctx, &code);
ctx->scanner = saved_scanner;
ctx->internal_func_name = saved_internal_func_name;
ctx->cur_function = saved_cur_function;
if (ret)
{
ERR("Failed to compile intrinsic, error %u.\n", ret);
hlsl_release_string_buffer(ctx, internal_name);
return NULL;
}
func = hlsl_get_func_decl(ctx, internal_name->buffer);
hlsl_release_string_buffer(ctx, internal_name);
return func;
}

View File

@ -837,6 +837,12 @@ struct hlsl_ctx
* compute shader profiles. It is set using the numthreads() attribute in the entry point. */
uint32_t thread_count[3];
/* In some cases we generate opcodes by parsing an HLSL function and then
* invoking it. If not NULL, this field is the name of the function that we
* are currently parsing, "mangled" with an internal prefix to avoid
* polluting the user namespace. */
const char *internal_func_name;
/* Whether the parser is inside a state block (effects' metadata) inside a variable declaration. */
uint32_t in_state_block : 1;
/* Whether the numthreads() attribute has been provided in the entry-point function. */
@ -1263,6 +1269,8 @@ bool hlsl_sm4_register_from_semantic(struct hlsl_ctx *ctx, const struct hlsl_sem
bool output, enum vkd3d_shader_register_type *type, enum vkd3d_sm4_swizzle_type *swizzle_type, bool *has_idx);
int hlsl_sm4_write(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry_func, struct vkd3d_shader_code *out);
struct hlsl_ir_function_decl *hlsl_compile_internal_function(struct hlsl_ctx *ctx, const char *name, const char *hlsl);
int hlsl_lexer_compile(struct hlsl_ctx *ctx, const struct vkd3d_shader_code *hlsl);
#endif

View File

@ -3422,58 +3422,29 @@ static bool intrinsic_sin(struct hlsl_ctx *ctx,
static bool intrinsic_smoothstep(struct hlsl_ctx *ctx,
const struct parse_initializer *params, const struct vkd3d_shader_location *loc)
{
struct hlsl_ir_node *min_arg, *max_arg, *x_arg, *p, *p_num, *p_denom, *res, *one, *minus_two, *three;
struct hlsl_ir_function_decl *func;
struct hlsl_type *type;
char *body;
if (!elementwise_intrinsic_float_convert_args(ctx, params, loc))
static const char template[] =
"%s smoothstep(%s low, %s high, %s x)\n"
"{\n"
" %s p = saturate((x - low) / (high - low));\n"
" return (p * p) * (3 - 2 * p);\n"
"}";
if (!(type = elementwise_intrinsic_get_common_type(ctx, params, loc)))
return false;
type = hlsl_get_numeric_type(ctx, type->class, HLSL_TYPE_FLOAT, type->dimx, type->dimy);
if (!(body = hlsl_sprintf_alloc(ctx, template, type->name, type->name, type->name, type->name, type->name)))
return false;
func = hlsl_compile_internal_function(ctx, "smoothstep", body);
vkd3d_free(body);
if (!func)
return false;
min_arg = params->args[0];
max_arg = params->args[1];
x_arg = params->args[2];
if (!(min_arg = add_unary_arithmetic_expr(ctx, params->instrs, HLSL_OP1_NEG, min_arg, loc)))
return false;
if (!(p_num = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_ADD, x_arg, min_arg, loc)))
return false;
if (!(p_denom = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_ADD, max_arg, min_arg, loc)))
return false;
if (!(one = hlsl_new_float_constant(ctx, 1.0, loc)))
return false;
hlsl_block_add_instr(params->instrs, one);
if (!(p_denom = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_DIV, one, p_denom, loc)))
return false;
if (!(p = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_MUL, p_num, p_denom, loc)))
return false;
if (!(p = add_unary_arithmetic_expr(ctx, params->instrs, HLSL_OP1_SAT, p, loc)))
return false;
if (!(minus_two = hlsl_new_float_constant(ctx, -2.0, loc)))
return false;
hlsl_block_add_instr(params->instrs, minus_two);
if (!(three = hlsl_new_float_constant(ctx, 3.0, loc)))
return false;
hlsl_block_add_instr(params->instrs, three);
if (!(res = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_MUL, minus_two, p, loc)))
return false;
if (!(res = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_ADD, three, res, loc)))
return false;
if (!(p = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_MUL, p, p, loc)))
return false;
if (!(res = add_binary_arithmetic_expr(ctx, params->instrs, HLSL_OP2_MUL, p, res, loc)))
return false;
return true;
return add_user_call(ctx, func, params, loc);
}
static bool intrinsic_sqrt(struct hlsl_ctx *ctx,