diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 70ae9b1e..af2100da 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -3391,10 +3391,19 @@ static uint32_t spirv_compiler_get_buffer_parameter(struct spirv_compiler *compi } static uint32_t spirv_compiler_emit_shader_parameter(struct spirv_compiler *compiler, - enum vkd3d_shader_parameter_name name) + enum vkd3d_shader_parameter_name name, enum vkd3d_data_type type) { const struct vkd3d_shader_parameter1 *parameter; - enum vkd3d_data_type type = VKD3D_DATA_UINT; + + static const struct + { + enum vkd3d_data_type type; + } + type_map[] = + { + [VKD3D_SHADER_PARAMETER_DATA_TYPE_FLOAT32] = {VKD3D_DATA_FLOAT}, + [VKD3D_SHADER_PARAMETER_DATA_TYPE_UINT32] = {VKD3D_DATA_UINT}, + }; if (!(parameter = vsir_program_get_parameter(compiler->program, name))) { @@ -3402,6 +3411,9 @@ static uint32_t spirv_compiler_emit_shader_parameter(struct spirv_compiler *comp goto default_parameter; } + if (type_map[parameter->data_type].type != type) + ERR("Expected data type %#x for parameter %#x, got %#x.\n", type, name, parameter->data_type); + if (parameter->type == VKD3D_SHADER_PARAMETER_TYPE_IMMEDIATE_CONSTANT) { if (parameter->data_type == VKD3D_SHADER_PARAMETER_DATA_TYPE_FLOAT32) @@ -3410,11 +3422,6 @@ static uint32_t spirv_compiler_emit_shader_parameter(struct spirv_compiler *comp return spirv_compiler_get_constant_uint(compiler, parameter->u.immediate_constant.u.u32); } - if (parameter->data_type == VKD3D_SHADER_PARAMETER_DATA_TYPE_FLOAT32) - type = VKD3D_DATA_FLOAT; - else - type = VKD3D_DATA_UINT; - if (parameter->type == VKD3D_SHADER_PARAMETER_TYPE_SPECIALIZATION_CONSTANT) return spirv_compiler_get_spec_constant(compiler, name, parameter->u.specialization_constant.id, type); if (parameter->type == VKD3D_SHADER_PARAMETER_TYPE_BUFFER) @@ -4200,7 +4207,7 @@ static uint32_t spirv_compiler_emit_load_reg(struct spirv_compiler *compiler, else if (reg->type == VKD3DSPR_UNDEF) return spirv_compiler_emit_load_undef(compiler, reg, write_mask); else if (reg->type == VKD3DSPR_PARAMETER) - return spirv_compiler_emit_shader_parameter(compiler, reg->idx[0].offset); + return spirv_compiler_emit_shader_parameter(compiler, reg->idx[0].offset, reg->data_type); component_count = vsir_write_mask_component_count(write_mask); component_type = vkd3d_component_type_from_data_type(reg->data_type); @@ -9541,7 +9548,7 @@ static uint32_t spirv_compiler_emit_query_sample_count(struct spirv_compiler *co if (src->reg.type == VKD3DSPR_RASTERIZER) { val_id = spirv_compiler_emit_shader_parameter(compiler, - VKD3D_SHADER_PARAMETER_NAME_RASTERIZER_SAMPLE_COUNT); + VKD3D_SHADER_PARAMETER_NAME_RASTERIZER_SAMPLE_COUNT, VKD3D_DATA_UINT); } else {