// Copyright Epic Games, Inc. All Rights Reserved. #include "StorageServerConnection.h" #include "IPAddress.h" #include "SocketSubsystem.h" #include "Sockets.h" #include "Misc/App.h" #include "Misc/Paths.h" #include "IO/IoDispatcher.h" #include "Misc/StringBuilder.h" #include "Misc/ScopeLock.h" #include "Serialization/CompactBinary.h" #include "Serialization/CompactBinarySerialization.h" #if !UE_BUILD_SHIPPING DEFINE_LOG_CATEGORY_STATIC(LogStorageServerConnection, Log, All); static TArray> GetAddressFromString(ISocketSubsystem& SocketSubsystem, TArrayView HostAddresses, const int32 Port) { TArray> InterntAddresses; for (const FString& HostAddr : HostAddresses) { TSharedPtr Addr = SocketSubsystem.GetAddressFromString(HostAddr); if (!Addr.IsValid() || !Addr->IsValid()) { FAddressInfoResult GAIRequest = SocketSubsystem.GetAddressInfo(*HostAddr, nullptr, EAddressInfoFlags::Default, NAME_None); if (GAIRequest.ReturnCode == SE_NO_ERROR && GAIRequest.Results.Num() > 0) { Addr = GAIRequest.Results[0].Address; } } if (Addr.IsValid() && Addr->IsValid()) { Addr->SetPort(Port); InterntAddresses.Emplace(MoveTemp(Addr)); } } return InterntAddresses; } FStorageServerRequest::FStorageServerRequest(FAnsiStringView Verb, FAnsiStringView Resource, FAnsiStringView Hostname) { SetIsSaving(true); HeaderBuffer << Verb << " " << Resource << " HTTP/1.1\r\n" << "Host: " << Hostname << "\r\n" << "Connection: Keep-Alive\r\n"; } FSocket* FStorageServerRequest::Send(FStorageServerConnection& Owner) { if (BodyBuffer.Num()) { HeaderBuffer.Append("Content-Length: ").Appendf("%d\r\n", BodyBuffer.Num()); } HeaderBuffer << "\r\n"; int32 BytesLeft = HeaderBuffer.Len(); auto Send = [](FSocket* Socket, const uint8* Data, int32 Length) { int32 BytesLeft = Length; while (BytesLeft > 0) { int32 BytesSent; if (!Socket->Send(Data, BytesLeft, BytesSent)) { return false; } check(BytesSent >= 0); BytesLeft -= BytesSent; Data += BytesSent; } return true; }; int32 Attempts = 0; while (Attempts++ < 10) { FSocket* Socket = Owner.AcquireSocket(); if (Send(Socket, reinterpret_cast(HeaderBuffer.GetData()), HeaderBuffer.Len()) && Send(Socket, BodyBuffer.GetData(), BodyBuffer.Num())) { return Socket; } UE_LOG(LogStorageServerConnection, Warning, TEXT("Failed to send request to storage server. Retrying...")); Owner.ReleaseSocket(Socket, false); } return nullptr; } void FStorageServerRequest::Serialize(void* V, int64 Length) { int32 Index = BodyBuffer.AddUninitialized(Length); uint8* Dest = BodyBuffer.GetData() + Index; FMemory::Memcpy(Dest, V, Length); } FStorageServerResponse::FStorageServerResponse(FStorageServerConnection& InOwner, FSocket& InSocket) : Owner(InOwner) , Socket(&InSocket) { SetIsLoading(true); uint8 Buffer[1024]; int32 TotalReadFromSocket = 0; auto ReadResponseLine = [&Buffer, &InSocket, &TotalReadFromSocket]() -> FAnsiStringView { for (;;) { int32 BytesRead; InSocket.Recv(Buffer, 1024, BytesRead, ESocketReceiveFlags::Peek); FAnsiStringView ResponseView(reinterpret_cast(Buffer), BytesRead); int32 LineEndIndex; if (ResponseView.FindChar('\r', LineEndIndex) && BytesRead >= LineEndIndex + 2) { check(ResponseView[LineEndIndex + 1] == '\n'); InSocket.Recv(Buffer, LineEndIndex + 2, BytesRead, ESocketReceiveFlags::None); check(BytesRead == LineEndIndex + 2); TotalReadFromSocket += BytesRead; return ResponseView.Left(LineEndIndex); } } }; FAnsiStringView ResponseLine = ReadResponseLine(); if (ResponseLine == "HTTP/1.1 200 OK") { bIsOk = true; } else if (ResponseLine.StartsWith("HTTP/1.1 ")) { ErrorCode = TCString::Atoi64(ResponseLine.GetData() + 9); } while (!ResponseLine.IsEmpty()) { ResponseLine = ReadResponseLine(); if (ResponseLine.StartsWith("Content-Length: ")) { ContentLength = TCString::Atoi64(ResponseLine.GetData() + 16); } } if (!bIsOk && ContentLength) { TArray ErrorBuffer; ErrorBuffer.SetNumUninitialized(ContentLength); int32 BytesRead; InSocket.Recv(ErrorBuffer.GetData(), ContentLength, BytesRead, ESocketReceiveFlags::WaitAll); ErrorMessage = FString(ContentLength, ANSI_TO_TCHAR(reinterpret_cast(ErrorBuffer.GetData()))); ContentLength = 0; } if (ContentLength == 0) { ReleaseSocket(true); } } FStorageServerChunkBatchRequest::FStorageServerChunkBatchRequest(FStorageServerConnection& InOwner, FAnsiStringView Resource, FAnsiStringView Hostname) : FStorageServerRequest("POST", Resource, Hostname) , Owner(InOwner) { uint32 Magic = 0xAAAA'77AC; uint32 ChunkCountPlaceHolder = 0; uint32 Reserved1 = 0; uint32 Reserved2 = 0; *this << Magic; ChunkCountOffset = BodyBuffer.Num(); *this << ChunkCountPlaceHolder << Reserved1 << Reserved2; } FStorageServerChunkBatchRequest& FStorageServerChunkBatchRequest::AddChunk(const FIoChunkId& ChunkId, int64 Offset, int64 Size) { uint32* ChunkCount = reinterpret_cast(BodyBuffer.GetData() + ChunkCountOffset); *this << const_cast(ChunkId) << *ChunkCount << Offset << Size; ++(*ChunkCount); return *this; } bool FStorageServerChunkBatchRequest::Issue(TFunctionRef OnResponse) { FSocket* Socket = Send(Owner); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send chunk batch request to storage server.")); return false; } FStorageServerResponse Response(Owner, *Socket); if (!Response.IsOk()) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to read chunk batch from storage server. '%s'"), *Response.GetErrorMessage()); return false; } uint32 Magic; uint32 ChunkCount; uint32 Reserved1; uint32 Reserved2; Response << Magic; if (Magic != 0xbada'b00f) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Invalid magic in chunk batch response from storage server.")); return false; } Response << ChunkCount; Response << Reserved1; Response << Reserved2; TArray> ChunkIndices; ChunkIndices.Reserve(ChunkCount); TArray> ChunkSizes; ChunkSizes.Reserve(ChunkCount); for (uint32 Index = 0; Index < ChunkCount; ++Index) { uint32 ChunkIndex; uint32 Flags; int64 ChunkSize; Response << ChunkIndex; Response << Flags; Response << ChunkSize; ChunkIndices.Add(ChunkIndex); ChunkSizes.Emplace(ChunkSize); } OnResponse(ChunkCount, ChunkIndices.GetData(), ChunkSizes.GetData(), Response); return true; } void FStorageServerResponse::ReleaseSocket(bool bKeepAlive) { Owner.ReleaseSocket(Socket, bKeepAlive); Socket = nullptr; } void FStorageServerResponse::Serialize(void* V, int64 Length) { if (Length == 0) { return; } if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Trying to read %lld bytes from released socket"), Length); return; } if (Position + Length > ContentLength) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Trying to read %lld bytes from socket with only %lld available"), Length, ContentLength - Position); return; } uint64 RemainingBytesToRead = Length; uint8* Destination = reinterpret_cast(V); while (RemainingBytesToRead) { uint64 BytesToRead32 = FMath::Min(RemainingBytesToRead, static_cast(INT32_MAX)); int32 BytesRead; if (!Socket->Recv(Destination, static_cast(BytesToRead32), BytesRead, ESocketReceiveFlags::WaitAll)) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed reading %d bytes from socket"), BytesToRead32); return; } RemainingBytesToRead -= BytesRead; Destination += BytesRead; Position += BytesRead; } if (Position == ContentLength) { ReleaseSocket(true); } } FStorageServerConnection::FStorageServerConnection() : SocketSubsystem(*ISocketSubsystem::Get()) { } FStorageServerConnection::~FStorageServerConnection() { for (FSocket* Socket : SocketPool) { Socket->Close(); delete Socket; } } bool FStorageServerConnection::Initialize(TArrayView InHostAddresses, int32 InPort, const TCHAR* InProjectNameOverride, const TCHAR* InPlatformNameOverride) { TArray> HostAddresses = GetAddressFromString(SocketSubsystem, InHostAddresses, InPort); if (!HostAddresses.Num()) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("No valid Zen store host address specified")); return false; } OplogPath.Append("/prj/"); if (InProjectNameOverride) { OplogPath.Append(TCHAR_TO_ANSI(InProjectNameOverride)); } else { OplogPath.Append(TCHAR_TO_ANSI(FApp::GetProjectName())); } OplogPath.Append("/oplog/"); if (InPlatformNameOverride) { OplogPath.Append(TCHAR_TO_ANSI(InPlatformNameOverride)); } else { OplogPath.Append(FPlatformProperties::PlatformName()); } const int32 ServerVersion = HandshakeRequest(HostAddresses); if (ServerVersion != 1) { return false; } UE_LOG(LogStorageServerConnection, Display, TEXT("Connected to Zen storage server at '%s'"), *ServerAddr->ToString(true)); return true; } int32 FStorageServerConnection::HandshakeRequest(TArrayView> HostAddresses) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath); for (const TSharedPtr& Addr : HostAddresses) { Hostname.Reset(); Hostname.Append(TCHAR_TO_ANSI(*Addr->ToString(false))); ServerAddr = Addr; UE_LOG(LogStorageServerConnection, Display, TEXT("Trying to handshake with Zen at '%s'"), *Addr->ToString(true)); FStorageServerRequest Request("GET", *ResourceBuilder, Hostname); if (FSocket* Socket = Request.Send(*this)) { FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { FCbObject ResponseObj = Response.GetResponseObject(); // we currently don't have any concept of protocol versioning, if // we succeed in communicating with the endpoint we're good since // any breaking API change would need to be done in a backward // compatible manner return 1; } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to handshake with Zen at %s. '%s'"), *ServerAddr->ToString(true), *Response.GetErrorMessage()); } } } Hostname.Reset(); ServerAddr.Reset(); return -1; } void FStorageServerConnection::FileManifestRequest(TFunctionRef Callback) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath).Append("/files?filter=client"); FStorageServerRequest Request("GET", *ResourceBuilder, Hostname); FSocket* Socket = Request.Send(*this); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send file manifest request to storage server at %s."), *ServerAddr->ToString(true)); return; } FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { FCbObject ResponseObj = Response.GetResponseObject(); for (FCbField& FileArrayEntry : ResponseObj["files"].AsArray()) { FCbObject Entry = FileArrayEntry.AsObject(); FCbObjectId Id = Entry["id"].AsObjectId(); TStringBuilder<128> WidePath; WidePath.Append(FUTF8ToTCHAR(Entry["clientpath"].AsString())); FIoChunkId ChunkId; ChunkId.Set(Id.GetView()); Callback(ChunkId, WidePath); } } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to read file manifest from storage server at %s. '%s'"), *ServerAddr->ToString(true), *Response.GetErrorMessage()); } } void FileIndexToChunkId(FAnsiStringBuilderBase& OutString, int32 FileIndex) { FIoChunkId Chunk = CreateExternalFileChunkId(0, FileIndex); OutString << Chunk; } int64 FStorageServerConnection::FileSizeRequest(int32 FileIndex) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath); ResourceBuilder << "/"; FileIndexToChunkId(ResourceBuilder, FileIndex); FStorageServerRequest Request("HEAD", *ResourceBuilder, Hostname); FSocket* Socket = Request.Send(*this); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send file size request to storage server at %s."), *ServerAddr->ToString(true)); return -1; } FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { return Response.ContentLength; } else if (Response.GetErrorCode() == 404) { return -1; } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to get file size from storage server at %s. '%s'"), *ServerAddr->ToString(true), *Response.GetErrorMessage()); } return -1; } bool FStorageServerConnection::ReadFileRequest(int32 FileIndex, uint64 Offset, uint64 Size, TFunctionRef OnResponse) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath); ResourceBuilder << "/"; FileIndexToChunkId(ResourceBuilder, FileIndex); ResourceBuilder.Appendf("?offset=%" UINT64_FMT "&size=%" UINT64_FMT, Offset, Size); FStorageServerRequest Request = FStorageServerRequest("GET", *ResourceBuilder, Hostname); FSocket* Socket = Request.Send(*this); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send file read request to storage server at %s."), *ServerAddr->ToString(true)); return false; } FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { OnResponse(Response); return true; } else if (Response.GetErrorCode() == 404) { return false; } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to read file from storage server at %s."), *ServerAddr->ToString(true), *Response.GetErrorMessage()); } return false; } int64 FStorageServerConnection::ChunkSizeRequest(const FIoChunkId& ChunkId) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath); ResourceBuilder << "/" << ChunkId << "/info"; FStorageServerRequest Request("GET", *ResourceBuilder, Hostname); FSocket* Socket = Request.Send(*this); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send chunk size request to storage server at %s."), *ServerAddr->ToString(true)); return -1; } FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { FCbObject ResponseObj = Response.GetResponseObject(); const int64 ChunkSize = ResponseObj["size"].AsInt64(0); return ChunkSize; } else if (Response.GetErrorCode() == 404) { return -1; } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to get chunk size from storage server at %s. '%s'"), *ServerAddr->ToString(true), *Response.GetErrorMessage()); } return -1; } bool FStorageServerConnection::ReadChunkRequest(const FIoChunkId& ChunkId, uint64 Offset, uint64 Size, TFunctionRef OnResponse) { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath) << "/" << ChunkId; bool HaveQuery = false; auto AppendQueryDelimiter = [&] { if (HaveQuery) { ResourceBuilder.Append("&"_ASV); } else { ResourceBuilder.Append("?"_ASV); HaveQuery = true; } }; if (Offset) { AppendQueryDelimiter(); ResourceBuilder.Appendf("offset=%" UINT64_FMT, Offset); } if (Size != ~uint64(0)) { AppendQueryDelimiter(); ResourceBuilder.Appendf("size=%" UINT64_FMT, Size); } FStorageServerRequest Request("GET", *ResourceBuilder, Hostname); FSocket* Socket = Request.Send(*this); if (!Socket) { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to send chunk read request to storage server at %s."), *ServerAddr->ToString(true)); return false; } FStorageServerResponse Response(*this, *Socket); if (Response.IsOk()) { OnResponse(Response); return true; } else if (Response.GetErrorCode() == 404) { return false; } else { UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to read chunk from storage server at %s. '%s'"), *ServerAddr->ToString(true), *Response.GetErrorMessage()); return false; } } FStorageServerChunkBatchRequest FStorageServerConnection::NewChunkBatchRequest() { TAnsiStringBuilder<256> ResourceBuilder; ResourceBuilder.Append(OplogPath).Append("/batch"); return FStorageServerChunkBatchRequest(*this, *ResourceBuilder, Hostname); } FSocket* FStorageServerConnection::AcquireSocket() { { FScopeLock Lock(&SocketPoolCritical); if (!SocketPool.IsEmpty()) { return SocketPool.Pop(false); } } for (int32 Attempt = 0, MaxAttempts = 10; Attempt < MaxAttempts; Attempt++) { FSocket* Socket = SocketSubsystem.CreateSocket(NAME_Stream, TEXT("StorageServer"), ServerAddr->GetProtocolType()); check(Socket); if (Socket->Connect(*ServerAddr)) { return Socket; } delete Socket; } UE_LOG(LogStorageServerConnection, Fatal, TEXT("Failed to connect to storage server at %s."), *ServerAddr->ToString(true)); return nullptr; } FString FStorageServerConnection::GetHostAddr() const { return ServerAddr.IsValid() ? ServerAddr->ToString(false) : FString(); } void FStorageServerConnection::ReleaseSocket(FSocket* Socket, bool bKeepAlive) { if (bKeepAlive) { uint32 PendingDataSize; if (!Socket->HasPendingData(PendingDataSize)) { FScopeLock Lock(&SocketPoolCritical); SocketPool.Push(Socket); return; } UE_LOG(LogStorageServerConnection, Fatal, TEXT("Socket was not fully drained")); } Socket->Close(); delete Socket; } #endif