// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Horde.Agents;
using EpicGames.Horde.Agents.Leases;
using EpicGames.Horde.Compute;
using Google.Protobuf.WellKnownTypes;
using Horde.Server.Agents;
using Horde.Server.Agents.Leases;
using Horde.Server.Agents.Relay;
using Horde.Server.Server;
using Horde.Server.Tasks;
using HordeCommon.Rpc.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Horde.Server.Compute
{
/// <summary>
/// Describes one TCP/IP port used by a compute resource: the port clients see externally and the
/// port on the agent that it is mapped onto.
/// </summary>
public class ComputeResourcePort
{
	/// <summary>
	/// Port visible from outside, which is mapped onto <see cref="AgentPort"/>.
	/// In direct connection mode the two values are identical.
	/// </summary>
	public int Port { get; }

	/// <summary>
	/// Port that the local process on the agent listens on.
	/// </summary>
	public int AgentPort { get; }

	/// <summary>
	/// Creates a port mapping.
	/// </summary>
	/// <param name="port">Externally visible port</param>
	/// <param name="agentPort">Port the process on the agent listens on</param>
	public ComputeResourcePort(int port, int agentPort)
	{
		(Port, AgentPort) = (port, agentPort);
	}

	/// <summary>
	/// Value comparison: two mappings are equal when both the external and the agent port match.
	/// </summary>
	/// <param name="other">Mapping to compare against</param>
	protected bool Equals(ComputeResourcePort other)
		=> other.Port == Port && other.AgentPort == AgentPort;

	/// <inheritdoc/>
	public override bool Equals(object? obj) => obj switch
	{
		null => false,
		_ when ReferenceEquals(this, obj) => true,
		// Exact runtime type match is required, so a derived type never equals a base instance.
		ComputeResourcePort other => other.GetType() == GetType() && Equals(other),
		_ => false,
	};

	/// <inheritdoc/>
	public override int GetHashCode() => HashCode.Combine(Port, AgentPort);

	/// <inheritdoc/>
	public override string ToString() => $"Port={Port}, AgentPort={AgentPort}";
}
/// <summary>
/// Information about an allocated compute resource: how to connect to it, the compute task it was
/// created for, and the agent and lease that back it.
/// </summary>
public class ComputeResource
{
	/// <summary>
	/// Connection mode for this resource.
	/// </summary>
	public ConnectionMode ConnectionMode { get; }

	/// <summary>
	/// IP address of the agent.
	/// </summary>
	public IPAddress Ip { get; }

	/// <summary>
	/// Optional connection address; null when none was supplied.
	/// </summary>
	public string? ConnectionAddress { get; }

	/// <summary>
	/// Ports exposed by this resource.
	/// </summary>
	public IReadOnlyDictionary Ports { get; }

	/// <summary>
	/// The compute task this resource was allocated for.
	/// </summary>
	public ComputeTask Task { get; }

	/// <summary>
	/// Properties of the assigned agent.
	/// </summary>
	public IReadOnlyList Properties { get; }

	/// <summary>
	/// Id of the agent on the remote machine.
	/// </summary>
	public AgentId AgentId { get; }

	/// <summary>
	/// Id of the lease on the remote machine.
	/// </summary>
	public LeaseId LeaseId { get; }

	/// <summary>
	/// Creates a description of a compute resource. All values are stored as given.
	/// </summary>
	/// <param name="connectionMode">Connection mode for the resource</param>
	/// <param name="ip">IP address of the agent</param>
	/// <param name="connectionAddress">Optional connection address</param>
	/// <param name="ports">Ports exposed by the resource</param>
	/// <param name="task">Compute task the resource was allocated for</param>
	/// <param name="properties">Properties of the assigned agent</param>
	/// <param name="agentId">Id of the assigned agent</param>
	/// <param name="leaseId">Id of the lease on the agent</param>
	public ComputeResource(ConnectionMode connectionMode, IPAddress ip, string? connectionAddress, IReadOnlyDictionary ports, ComputeTask task, IReadOnlyList properties, AgentId agentId, LeaseId leaseId)
	{
		(ConnectionMode, Ip, ConnectionAddress, Ports) = (connectionMode, ip, connectionAddress, ports);
		(Task, Properties, AgentId, LeaseId) = (task, properties, agentId, leaseId);
	}
}
/// <summary>
/// Dispatches requests for compute resources.
/// </summary>
/// <remarks>
/// Agents that advertise a compute endpoint through their "ComputeIp" and "ComputePort" properties are queued as
/// waiters in every compute cluster whose condition they satisfy, until their wait completes or is cancelled.
/// </remarks>
public class ComputeTaskSource : TaskSourceBase
{
	/// <summary>
	/// An agent waiting for a compute lease, together with the compute endpoint it advertised.
	/// </summary>
	class Waiter
	{
		public IAgent Agent { get; }
		public IPAddress Ip { get; }
		public int Port { get; }

		// Completed with the lease for this agent, or with null when the wait is cancelled.
		// RunContinuationsAsynchronously keeps the awaiting continuation off the thread that completes it.
		// NOTE(review): within this class only the cancellation callback completes it (with null); confirm where
		// non-null leases are assigned.
		public TaskCompletionSource Lease { get; } = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

		public Waiter(IAgent agent, IPAddress ip, int port)
		{
			Agent = agent;
			Ip = ip;
			Port = port;
		}
	}

	/// <inheritdoc/>
	public override string Type => "Compute";

	/// <inheritdoc/>
	public override TaskSourceFlags Flags => TaskSourceFlags.None;

	readonly AgentRelayService _agentRelay;
	readonly ILeaseCollection _leaseCollection;
	readonly IOptionsMonitor _globalConfig;
	readonly ILogger _logger;

	// Guards every access to _waiters.
	readonly object _lockObject = new object();

	// Waiting agents keyed by compute cluster id. A single waiter is added to the list of every cluster whose
	// condition its agent satisfies.
	readonly Dictionary> _waiters = new Dictionary>();

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="agentRelay">Relay service whose port mappings are removed when a compute lease finishes</param>
	/// <param name="leaseCollection">Lease collection used to look up parent leases</param>
	/// <param name="globalConfig">Global configuration containing the compute cluster definitions</param>
	/// <param name="logger">Logger for this task source</param>
	public ComputeTaskSource(AgentRelayService agentRelay, ILeaseCollection leaseCollection, IOptionsMonitor globalConfig, ILogger logger)
	{
		_agentRelay = agentRelay;
		_leaseCollection = leaseCollection;
		_globalConfig = globalConfig;
		_logger = logger;
	}

	/// <inheritdoc/>
	public override Task> AssignLeaseAsync(IAgent agent, CancellationToken cancellationToken)
	{
		// The outer task completes immediately; the inner task completes once the agent's wait finishes.
		return Task.FromResult(WaitInternalAsync(agent, cancellationToken));
	}

	/// <inheritdoc/>
	public override async Task OnLeaseFinishedAsync(IAgent agent, LeaseId leaseId, ComputeTask payload, LeaseOutcome outcome, ReadOnlyMemory output, ILogger logger, CancellationToken cancellationToken)
	{
		try
		{
			await base.OnLeaseFinishedAsync(agent, leaseId, payload, outcome, output, logger, cancellationToken);
		}
		finally
		{
			// Remove any port mapping associated with this lease ID (as of now, only compute tasks can be relayed).
			// This runs in a finally block so a failure in the base handler cannot leave a stale relay mapping behind.
			await _agentRelay.RemovePortMappingAsync(leaseId);
		}
	}

	/// <summary>
	/// Queues the agent as a waiter in every compute cluster whose condition it satisfies, then waits for the
	/// waiter to complete. The waiter is always removed from the queues before returning.
	/// </summary>
	/// <param name="agent">Agent requesting compute work</param>
	/// <param name="cancellationToken">Cancels the wait; the method then returns null</param>
	/// <returns>
	/// The lease the waiter was completed with. Returns null if the agent has no valid "ComputeIp"/"ComputePort"
	/// property, matches no cluster, or the wait was cancelled.
	/// </returns>
	async Task WaitInternalAsync(IAgent agent, CancellationToken cancellationToken)
	{
		string? ipStr = agent.GetPropertyValues("ComputeIp").FirstOrDefault();
		if (ipStr == null || !IPAddress.TryParse(ipStr, out IPAddress? ip))
		{
			return null;
		}

		string? portStr = agent.GetPropertyValues("ComputePort").FirstOrDefault();
		if (portStr == null || !Int32.TryParse(portStr, out int port) || port < IPEndPoint.MinPort || port > IPEndPoint.MaxPort)
		{
			// A value outside the TCP port range cannot describe a usable endpoint.
			return null;
		}

		// Take one snapshot of the config so that registration and removal operate on the same set of clusters.
		GlobalConfig globalConfig = _globalConfig.CurrentValue;

		Waiter? waiter = null;
		try
		{
			// Add it to the wait queue of every matching cluster
			lock (_lockObject)
			{
				foreach (ComputeClusterConfig clusterConfig in globalConfig.Compute)
				{
					if (clusterConfig.Condition == null || agent.SatisfiesCondition(clusterConfig.Condition))
					{
						LinkedList? list;
						if (!_waiters.TryGetValue(clusterConfig.Id, out list))
						{
							list = new LinkedList();
							_waiters.Add(clusterConfig.Id, list);
						}

						// The same waiter instance is shared across all matching clusters.
						waiter ??= new Waiter(agent, ip, port);
						list.AddFirst(waiter);
					}
				}
			}

			if (waiter != null)
			{
				// Cancellation completes the wait with null instead of throwing.
				using CancellationTokenRegistration registration = cancellationToken.Register(() => waiter.Lease.TrySetResult(null));

				AgentLease? lease = await waiter.Lease.Task;
				if (lease != null)
				{
					_logger.LogInformation("Created compute lease for agent {AgentId}", agent.Id);
					return lease;
				}
			}
		}
		finally
		{
			// Always unregister the waiter, whether it completed, was cancelled, or an exception was thrown.
			if (waiter != null)
			{
				lock (_lockObject)
				{
					foreach (ComputeClusterConfig clusterConfig in globalConfig.Compute)
					{
						if (_waiters.TryGetValue(clusterConfig.Id, out LinkedList? list))
						{
							list.Remove(waiter);
						}
					}
				}
			}
		}

		return null;
	}

	/// <inheritdoc/>
	/// <remarks>
	/// When the compute task names a parent lease that parses and can be found, its log id is added under "parentLogId".
	/// </remarks>
	public override async ValueTask GetLeaseDetailsAsync(Any payload, Dictionary details, CancellationToken cancellationToken)
	{
		await base.GetLeaseDetailsAsync(payload, details, cancellationToken);

		ComputeTask message = payload.Unpack();
		if (!String.IsNullOrEmpty(message.ParentLeaseId) && LeaseId.TryParse(message.ParentLeaseId, out LeaseId parentLeaseId))
		{
			ILease? lease = await _leaseCollection.GetAsync(parentLeaseId, cancellationToken);
			if (lease != null)
			{
				details["parentLogId"] = lease.LogId.ToString() ?? String.Empty;
			}
		}
	}
}
}