vkd3d-shader/spirv: Implement the WAVE_ACTIVE_BIT_* instructions.

This commit is contained in:
Conor McCarthy 2024-04-23 21:14:22 +10:00 committed by Alexandre Julliard
parent af208135f3
commit fef5760af0
Notes: Alexandre Julliard 2024-05-06 22:37:06 +02:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Approved-by: Alexandre Julliard (@julliard)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/827
2 changed files with 47 additions and 3 deletions

View File

@ -9816,6 +9816,45 @@ static void spirv_compiler_emit_wave_active_ballot(struct spirv_compiler *compil
spirv_compiler_emit_store_dst(compiler, dst, val_id);
}
static SpvOp map_wave_alu_op(enum vkd3d_shader_opcode handler_idx, bool is_float)
{
switch (handler_idx)
{
case VKD3DSIH_WAVE_ACTIVE_BIT_AND:
return SpvOpGroupNonUniformBitwiseAnd;
case VKD3DSIH_WAVE_ACTIVE_BIT_OR:
return SpvOpGroupNonUniformBitwiseOr;
case VKD3DSIH_WAVE_ACTIVE_BIT_XOR:
return SpvOpGroupNonUniformBitwiseXor;
default:
vkd3d_unreachable();
}
}
static void spirv_compiler_emit_wave_alu_op(struct spirv_compiler *compiler,
const struct vkd3d_shader_instruction *instruction)
{
struct vkd3d_spirv_builder *builder = &compiler->spirv_builder;
const struct vkd3d_shader_dst_param *dst = instruction->dst;
const struct vkd3d_shader_src_param *src = instruction->src;
uint32_t type_id, val_id;
SpvOp op;
op = map_wave_alu_op(instruction->handler_idx, data_type_is_floating_point(src->reg.data_type));
type_id = vkd3d_spirv_get_type_id_for_data_type(builder, dst->reg.data_type,
vsir_write_mask_component_count(dst->write_mask));
val_id = spirv_compiler_emit_load_src(compiler, &src[0], dst->write_mask);
vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformArithmetic);
val_id = vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, op, type_id,
vkd3d_spirv_get_op_scope_subgroup(builder),
SpvGroupOperationReduce,
val_id);
spirv_compiler_emit_store_dst(compiler, dst, val_id);
}
/* This function is called after declarations are processed. */
static void spirv_compiler_emit_main_prolog(struct spirv_compiler *compiler)
{
@ -10168,6 +10207,11 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler,
case VKD3DSIH_WAVE_ACTIVE_BALLOT:
spirv_compiler_emit_wave_active_ballot(compiler, instruction);
break;
case VKD3DSIH_WAVE_ACTIVE_BIT_AND:
case VKD3DSIH_WAVE_ACTIVE_BIT_OR:
case VKD3DSIH_WAVE_ACTIVE_BIT_XOR:
spirv_compiler_emit_wave_alu_op(compiler, instruction);
break;
case VKD3DSIH_DCL:
case VKD3DSIH_DCL_HS_MAX_TESSFACTOR:
case VKD3DSIH_DCL_INPUT_CONTROL_POINT_COUNT:

View File

@ -232,7 +232,7 @@ void main(uint id : SV_GroupIndex)
}
[test]
todo dispatch 4 1 1
dispatch 4 1 1
probe uav 1 (0) rui (8)
probe uav 1 (1) rui (8)
probe uav 1 (2) rui (8)
@ -250,7 +250,7 @@ void main(uint id : SV_GroupIndex)
}
[test]
todo dispatch 4 1 1
dispatch 4 1 1
probe uav 1 (0) rui (15)
probe uav 1 (1) rui (15)
probe uav 1 (2) rui (15)
@ -268,7 +268,7 @@ void main(uint id : SV_GroupIndex)
}
[test]
todo dispatch 4 1 1
dispatch 4 1 1
probe uav 1 (0) rui (5)
probe uav 1 (1) rui (5)
probe uav 1 (2) rui (5)