vkd3d: Use struct d3d12_pipeline_state_desc for compute pipelines.

This commit is contained in:
Conor McCarthy 2023-11-14 13:03:38 +10:00 committed by Alexandre Julliard
parent 39afbb8e32
commit 6196199a89
Notes: Alexandre Julliard 2023-11-15 22:58:50 +01:00
Approved-by: Giovanni Mascellani (@giomasce)
Approved-by: Henri Verbeet (@hverbeet)
Approved-by: Alexandre Julliard (@julliard)
Merge-Request: https://gitlab.winehq.org/wine/vkd3d/-/merge_requests/461

View File

@ -1763,6 +1763,17 @@ static void pipeline_state_desc_from_d3d12_graphics_desc(struct d3d12_pipeline_s
desc->flags = d3d12_desc->Flags; desc->flags = d3d12_desc->Flags;
} }
static void pipeline_state_desc_from_d3d12_compute_desc(struct d3d12_pipeline_state_desc *desc,
const D3D12_COMPUTE_PIPELINE_STATE_DESC *d3d12_desc)
{
memset(desc, 0, sizeof(*desc));
desc->root_signature = d3d12_desc->pRootSignature;
desc->cs = d3d12_desc->CS;
desc->node_mask = d3d12_desc->NodeMask;
desc->cached_pso = d3d12_desc->CachedPSO;
desc->flags = d3d12_desc->Flags;
}
struct vkd3d_pipeline_key struct vkd3d_pipeline_key
{ {
D3D12_PRIMITIVE_TOPOLOGY topology; D3D12_PRIMITIVE_TOPOLOGY topology;
@ -2220,7 +2231,7 @@ static HRESULT d3d12_pipeline_state_find_and_init_uav_counters(struct d3d12_pipe
} }
static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *state, static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *state,
struct d3d12_device *device, const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc) struct d3d12_device *device, const struct d3d12_pipeline_state_desc *desc)
{ {
const struct vkd3d_vk_device_procs *vk_procs = &device->vk_procs; const struct vkd3d_vk_device_procs *vk_procs = &device->vk_procs;
struct vkd3d_shader_interface_info shader_interface; struct vkd3d_shader_interface_info shader_interface;
@ -2235,14 +2246,14 @@ static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *st
memset(&state->uav_counters, 0, sizeof(state->uav_counters)); memset(&state->uav_counters, 0, sizeof(state->uav_counters));
if (!(root_signature = unsafe_impl_from_ID3D12RootSignature(desc->pRootSignature))) if (!(root_signature = unsafe_impl_from_ID3D12RootSignature(desc->root_signature)))
{ {
WARN("Root signature is NULL.\n"); WARN("Root signature is NULL.\n");
return E_INVALIDARG; return E_INVALIDARG;
} }
if (FAILED(hr = d3d12_pipeline_state_find_and_init_uav_counters(state, device, root_signature, if (FAILED(hr = d3d12_pipeline_state_find_and_init_uav_counters(state, device, root_signature,
&desc->CS, VK_SHADER_STAGE_COMPUTE_BIT))) &desc->cs, VK_SHADER_STAGE_COMPUTE_BIT)))
return hr; return hr;
memset(&target_info, 0, sizeof(target_info)); memset(&target_info, 0, sizeof(target_info));
@ -2283,7 +2294,7 @@ static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *st
vk_pipeline_layout = state->uav_counters.vk_pipeline_layout vk_pipeline_layout = state->uav_counters.vk_pipeline_layout
? state->uav_counters.vk_pipeline_layout : root_signature->vk_pipeline_layout; ? state->uav_counters.vk_pipeline_layout : root_signature->vk_pipeline_layout;
if (FAILED(hr = vkd3d_create_compute_pipeline(device, &desc->CS, &shader_interface, if (FAILED(hr = vkd3d_create_compute_pipeline(device, &desc->cs, &shader_interface,
vk_pipeline_layout, &state->u.compute.vk_pipeline))) vk_pipeline_layout, &state->u.compute.vk_pipeline)))
{ {
WARN("Failed to create Vulkan compute pipeline, hr %#x.\n", hr); WARN("Failed to create Vulkan compute pipeline, hr %#x.\n", hr);
@ -2307,13 +2318,16 @@ static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *st
HRESULT d3d12_pipeline_state_create_compute(struct d3d12_device *device, HRESULT d3d12_pipeline_state_create_compute(struct d3d12_device *device,
const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc, struct d3d12_pipeline_state **state) const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc, struct d3d12_pipeline_state **state)
{ {
struct d3d12_pipeline_state_desc pipeline_desc;
struct d3d12_pipeline_state *object; struct d3d12_pipeline_state *object;
HRESULT hr; HRESULT hr;
pipeline_state_desc_from_d3d12_compute_desc(&pipeline_desc, desc);
if (!(object = vkd3d_malloc(sizeof(*object)))) if (!(object = vkd3d_malloc(sizeof(*object))))
return E_OUTOFMEMORY; return E_OUTOFMEMORY;
if (FAILED(hr = d3d12_pipeline_state_init_compute(object, device, desc))) if (FAILED(hr = d3d12_pipeline_state_init_compute(object, device, &pipeline_desc)))
{ {
vkd3d_free(object); vkd3d_free(object);
return hr; return hr;