diff --git a/libs/vkd3d/device.c b/libs/vkd3d/device.c index 5173de7a..149a8201 100644 --- a/libs/vkd3d/device.c +++ b/libs/vkd3d/device.c @@ -571,7 +571,7 @@ static HRESULT STDMETHODCALLTYPE d3d12_device_CreateComputePipelineState(ID3D12D struct d3d12_pipeline_state *object; HRESULT hr; - FIXME("iface %p, desc %p, riid %s, pipeline_state %p stub!\n", + TRACE("iface %p, desc %p, riid %s, pipeline_state %p.\n", iface, desc, debugstr_guid(riid), pipeline_state); if (FAILED(hr = d3d12_pipeline_state_create_compute(device, desc, &object))) diff --git a/libs/vkd3d/state.c b/libs/vkd3d/state.c index a908c1ee..6a29444a 100644 --- a/libs/vkd3d/state.c +++ b/libs/vkd3d/state.c @@ -137,6 +137,14 @@ static const struct ID3D12RootSignatureVtbl d3d12_root_signature_vtbl = d3d12_root_signature_GetDevice, }; +static struct d3d12_root_signature *unsafe_impl_from_ID3D12RootSignature(ID3D12RootSignature *iface) +{ + if (!iface) + return NULL; + assert(iface->lpVtbl == &d3d12_root_signature_vtbl); + return impl_from_ID3D12RootSignature(iface); +} + static HRESULT d3d12_root_signature_init(struct d3d12_root_signature *root_signature, struct d3d12_device *device, const D3D12_ROOT_SIGNATURE_DESC *desc) { @@ -245,6 +253,9 @@ static ULONG STDMETHODCALLTYPE d3d12_pipeline_state_Release(ID3D12PipelineState if (!refcount) { struct d3d12_device *device = state->device; + const struct vkd3d_vk_device_procs *vk_procs = &device->vk_procs; + + VK_CALL(vkDestroyPipeline(device->vk_device, state->vk_pipeline, NULL)); vkd3d_free(state); @@ -320,25 +331,80 @@ static const struct ID3D12PipelineStateVtbl d3d12_pipeline_state_vtbl = d3d12_pipeline_state_GetCachedBlob, }; -static void d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *state, +static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *state, struct d3d12_device *device, const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc) { + const struct vkd3d_vk_device_procs *vk_procs = &device->vk_procs; + struct d3d12_root_signature *root_signature; + VkComputePipelineCreateInfo pipeline_info; + VkShaderModuleCreateInfo shader_info; + VkShaderModule shader; + VkResult vr; + state->ID3D12PipelineState_iface.lpVtbl = &d3d12_pipeline_state_vtbl; state->refcount = 1; + if (!(root_signature = unsafe_impl_from_ID3D12RootSignature(desc->pRootSignature))) + { + WARN("Root signature is NULL.\n"); + return E_INVALIDARG; + } + + shader_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_info.pNext = NULL; + shader_info.flags = 0; + shader_info.codeSize = desc->CS.BytecodeLength; + shader_info.pCode = desc->CS.pShaderBytecode; + + if ((vr = VK_CALL(vkCreateShaderModule(device->vk_device, &shader_info, NULL, &shader)))) + { + WARN("Failed to create Vulkan shader module, vr %d.\n", vr); + return hresult_from_vk_result(vr); + } + + pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_info.pNext = NULL; + pipeline_info.flags = 0; + pipeline_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeline_info.stage.pNext = NULL; + pipeline_info.stage.flags = 0; + pipeline_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeline_info.stage.module = shader; + pipeline_info.stage.pName = "main"; + pipeline_info.stage.pSpecializationInfo = NULL; + pipeline_info.layout = root_signature->vk_pipeline_layout; + pipeline_info.basePipelineHandle = VK_NULL_HANDLE; + pipeline_info.basePipelineIndex = -1; + + vr = VK_CALL(vkCreateComputePipelines(device->vk_device, VK_NULL_HANDLE, + 1, &pipeline_info, NULL, &state->vk_pipeline)); + VK_CALL(vkDestroyShaderModule(device->vk_device, shader, NULL)); + if (vr) + { + WARN("Failed to create Vulkan compute pipeline, vr %d.\n", vr); + return hresult_from_vk_result(vr); + } + state->device = device; ID3D12Device_AddRef(&device->ID3D12Device_iface); + + return S_OK; } HRESULT d3d12_pipeline_state_create_compute(struct d3d12_device *device, const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc, struct d3d12_pipeline_state **state) { struct d3d12_pipeline_state *object; + HRESULT hr; if (!(object = vkd3d_malloc(sizeof(*object)))) return E_OUTOFMEMORY; - d3d12_pipeline_state_init_compute(object, device, desc); + if (FAILED(hr = d3d12_pipeline_state_init_compute(object, device, desc))) + { + vkd3d_free(object); + return hr; + } TRACE("Created compute pipeline state %p.\n", object); diff --git a/libs/vkd3d/vkd3d_private.h b/libs/vkd3d/vkd3d_private.h index 0e47ee78..1b55a5d9 100644 --- a/libs/vkd3d/vkd3d_private.h +++ b/libs/vkd3d/vkd3d_private.h @@ -90,6 +90,8 @@ struct d3d12_pipeline_state ID3D12PipelineState ID3D12PipelineState_iface; ULONG refcount; + VkPipeline vk_pipeline; + struct d3d12_device *device; };