From e16176672aa95026c06d36e16414e867aec797df Mon Sep 17 00:00:00 2001 From: Elizabeth Figura Date: Fri, 28 Mar 2025 15:35:44 -0500 Subject: [PATCH] vkd3d-shader/hlsl: Validate "numthreads" attribute values. --- libs/vkd3d-shader/hlsl_codegen.c | 21 +++++++++++---------- tests/hlsl/numthreads.shader_test | 8 ++++---- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index 4e7ea6bf6..d4b0e2cde 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -7154,7 +7154,7 @@ struct hlsl_reg hlsl_reg_from_deref(struct hlsl_ctx *ctx, const struct hlsl_dere } static bool get_integral_argument_value(struct hlsl_ctx *ctx, const struct hlsl_attribute *attr, - unsigned int i, enum hlsl_base_type *base_type, int *value) + unsigned int i, int *value) { const struct hlsl_ir_node *instr = attr->args[i].node; const struct hlsl_type *type = instr->data_type; @@ -7178,7 +7178,6 @@ static bool get_integral_argument_value(struct hlsl_ctx *ctx, const struct hlsl_ return false; } - *base_type = type->e.numeric.type; *value = hlsl_ir_constant(instr)->value.u[0].i; return true; } @@ -7205,6 +7204,7 @@ static const char *get_string_argument_value(struct hlsl_ctx *ctx, const struct static void parse_numthreads_attribute(struct hlsl_ctx *ctx, const struct hlsl_attribute *attr) { + static const unsigned int limits[3] = {1024, 1024, 64}; unsigned int i; ctx->found_numthreads = 1; @@ -7218,18 +7218,21 @@ static void parse_numthreads_attribute(struct hlsl_ctx *ctx, const struct hlsl_a for (i = 0; i < attr->args_count; ++i) { - enum hlsl_base_type base_type; int value; - if (!get_integral_argument_value(ctx, attr, i, &base_type, &value)) + if (!get_integral_argument_value(ctx, attr, i, &value)) return; - if ((base_type == HLSL_TYPE_INT && value <= 0) || (base_type == HLSL_TYPE_UINT && !value)) + if (value < 1 || value > limits[i]) hlsl_error(ctx, &attr->args[i].node->loc, VKD3D_SHADER_ERROR_HLSL_INVALID_THREAD_COUNT, - "Thread count must be a positive integer."); + "Dimension %u of the thread count must be between 1 and %u.", i, limits[i]); ctx->thread_count[i] = value; } + + if (ctx->thread_count[0] * ctx->thread_count[1] * ctx->thread_count[2] > 1024) + hlsl_error(ctx, &attr->loc, VKD3D_SHADER_ERROR_HLSL_INVALID_THREAD_COUNT, + "Product of thread count parameters cannot exceed 1024."); } static void parse_domain_attribute(struct hlsl_ctx *ctx, const struct hlsl_attribute *attr) @@ -7260,7 +7263,6 @@ static void parse_domain_attribute(struct hlsl_ctx *ctx, const struct hlsl_attri static void parse_outputcontrolpoints_attribute(struct hlsl_ctx *ctx, const struct hlsl_attribute *attr) { - enum hlsl_base_type base_type; int value; if (attr->args_count != 1) @@ -7270,7 +7272,7 @@ static void parse_outputcontrolpoints_attribute(struct hlsl_ctx *ctx, const stru return; } - if (!get_integral_argument_value(ctx, attr, 0, &base_type, &value)) + if (!get_integral_argument_value(ctx, attr, 0, &value)) return; if (value < 0 || value > 32) @@ -7373,7 +7375,6 @@ static void parse_patchconstantfunc_attribute(struct hlsl_ctx *ctx, const struct static void parse_maxvertexcount_attribute(struct hlsl_ctx *ctx, const struct hlsl_attribute *attr) { - enum hlsl_base_type base_type; int value; if (attr->args_count != 1) @@ -7383,7 +7384,7 @@ static void parse_maxvertexcount_attribute(struct hlsl_ctx *ctx, const struct hl return; } - if (!get_integral_argument_value(ctx, attr, 0, &base_type, &value)) + if (!get_integral_argument_value(ctx, attr, 0, &value)) return; if (value < 1 || value > 1024) diff --git a/tests/hlsl/numthreads.shader_test b/tests/hlsl/numthreads.shader_test index a812adf01..001158c5c 100644 --- a/tests/hlsl/numthreads.shader_test +++ b/tests/hlsl/numthreads.shader_test @@ -177,19 +177,19 @@ size (2d, 2, 2) % The product must not exceed 1024, and the third dimension cannot exceed 64. -[compute shader fail todo] +[compute shader fail] [numthreads(1025,1,1)] void main() {} -[compute shader fail todo] +[compute shader fail] [numthreads(1,1025,1)] void main() {} -[compute shader fail todo] +[compute shader fail] [numthreads(1,1,65)] void main() {} -[compute shader fail todo] +[compute shader fail] [numthreads(41,25,25)] void main() {}