// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using EpicGames.Redis;
using Horde.Build.Server;
using MongoDB.Bson;
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Attributes;
using MongoDB.Bson.Serialization.Conventions;
using MongoDB.Driver;
using StackExchange.Redis;
namespace Horde.Build.Utilities
{
/// <summary>
/// Exception indicating that a particular document couldn't be found
/// </summary>
/// <typeparam name="TId">Type of the identifier used to look up the document</typeparam>
public sealed class DocumentNotFoundException<TId> : Exception
{
    /// <summary>
    /// The document identifier
    /// </summary>
    public TId Id { get; }

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="id">Identifier of the document that could not be found; it is embedded in the exception message</param>
    public DocumentNotFoundException(TId id)
        : base($"Unable to find '{id}'")
    {
        Id = id;
    }
}
/// <summary>
/// Base class for a versioned MongoDB document
/// </summary>
/// <typeparam name="TId">Type of the unique identifier for each document</typeparam>
/// <typeparam name="TLatest">Type of the latest document</typeparam>
public abstract class VersionedDocument<TId, TLatest>
    where TId : notnull
    where TLatest : VersionedDocument<TId, TLatest>
{
    /// <summary>
    /// Unique id for this document
    /// </summary>
    [JsonPropertyOrder(0)]
    public TId Id { get; set; }

    /// <summary>
    /// Last time that the document was updated. This field is checked and updated as part of updates to ensure atomicity.
    /// Stored under the BSON element name "_u".
    /// </summary>
    [BsonElement("_u")]
    [JsonPropertyOrder(1000)]
    public DateTime LastUpdateTime { get; set; }

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="id">Unique id for the document</param>
    protected VersionedDocument(TId id)
    {
        Id = id;

        // A new document starts at the minimum timestamp; it is overwritten when the document is stored.
        LastUpdateTime = DateTime.MinValue;
    }

    /// <summary>
    /// Perform any transformations necessary to upgrade this document to the latest version
    /// </summary>
    /// <returns>Upgraded copy of the document</returns>
    public abstract TLatest UpgradeToLatest();
}
/// <summary>
/// Discriminator convention which selects the concrete document type from an integer version
/// number stored in the "_v" element of each document.
/// </summary>
[BsonDiscriminator()]
class VersionedDocumentDiscriminator : IDiscriminatorConvention
{
    /// <summary>
    /// Name of the element holding the version number
    /// </summary>
    public string ElementName => "_v";

    readonly IReadOnlyDictionary<int, Type> _versionToType;
    readonly IReadOnlyDictionary<Type, int> _typeToVersion;

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="types">Map from version number to the document type for that version. Each type may
    /// appear only once, since the map is inverted to look up the version for a type.</param>
    public VersionedDocumentDiscriminator(IReadOnlyDictionary<int, Type> types)
    {
        _versionToType = types;
        _typeToVersion = _versionToType.ToDictionary(x => x.Value, x => x.Key);
    }

    /// <summary>
    /// Scans the document at the reader's current position for the version element and maps it to a type.
    /// The reader is always returned to its starting position, whether or not the lookup succeeds.
    /// </summary>
    /// <param name="bsonReader">Reader positioned at the start of the document</param>
    /// <param name="nominalType">Nominal type being deserialized (unused)</param>
    /// <returns>The type registered for the document's version</returns>
    /// <exception cref="InvalidOperationException">The document does not contain a version element</exception>
    /// <exception cref="KeyNotFoundException">The version number is not registered</exception>
    public Type GetActualType(IBsonReader bsonReader, Type nominalType)
    {
        BsonReaderBookmark bookmark = bsonReader.GetBookmark();
        try
        {
            bsonReader.ReadStartDocument();

            // Stop at the end of the document rather than reading past it if the version element is missing
            while (bsonReader.ReadBsonType() != BsonType.EndOfDocument)
            {
                string name = bsonReader.ReadName();
                if (name.Equals(ElementName, StringComparison.Ordinal))
                {
                    int version = bsonReader.ReadInt32();
                    if (!_versionToType.TryGetValue(version, out Type? actualType))
                    {
                        throw new KeyNotFoundException($"No document type is registered for version {version}");
                    }
                    return actualType;
                }
                bsonReader.SkipValue();
            }

            throw new InvalidOperationException($"Document does not contain a '{ElementName}' version element");
        }
        finally
        {
            // The deserializer needs to re-read the document from the beginning
            bsonReader.ReturnToBookmark(bookmark);
        }
    }

    /// <summary>
    /// Gets the version number to store for a given document type
    /// </summary>
    /// <param name="nominalType">Nominal type being serialized (unused)</param>
    /// <param name="actualType">Concrete type of the document</param>
    /// <returns>The version number registered for the type</returns>
    /// <exception cref="KeyNotFoundException">The type is not registered</exception>
    public BsonValue GetDiscriminator(Type nominalType, Type actualType)
    {
        if (!_typeToVersion.TryGetValue(actualType, out int version))
        {
            throw new KeyNotFoundException($"No version is registered for document type {actualType}");
        }
        return version;
    }
}
/// <summary>
/// Collection of types derived from <see cref="VersionedDocument{TId, TLatest}"/>. Documents are stored in
/// a MongoDB collection and mirrored into Redis, with optimistic concurrency based on the last update time.
/// </summary>
/// <typeparam name="TId">Type of the unique identifier for each document</typeparam>
/// <typeparam name="TLatest">Type of the latest document</typeparam>
public sealed class VersionedCollection<TId, TLatest>
    where TId : notnull
    where TLatest : VersionedDocument<TId, TLatest>
{
    // Suffix appended to a document's cache key to form the key holding its last update time (in ticks)
    private static readonly RedisKey s_timeSuffix = new RedisKey("/ts");

    private static FilterDefinitionBuilder<VersionedDocument<TId, TLatest>> FilterBuilder { get; } = Builders<VersionedDocument<TId, TLatest>>.Filter;

    private readonly RedisConnectionPool _redisConnectionPool;
    private readonly RedisKey _baseKey;

    // Base document types for which a discriminator convention has already been registered
    private static readonly HashSet<Type> s_registeredTypes = new HashSet<Type>();

    /// <summary>
    /// Collection of versioned documents
    /// </summary>
    public IMongoCollection<VersionedDocument<TId, TLatest>> BaseCollection { get; }

    /// <summary>
    /// Collection of versioned documents, filtered to the latest revision
    /// </summary>
    public IMongoCollection<TLatest> LatestCollection { get; }

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="mongoService">The database service</param>
    /// <param name="collectionName">Name of the collection of documents to manage</param>
    /// <param name="redisService">Instance of the redis service</param>
    /// <param name="baseKey">Prefix for key types</param>
    /// <param name="types">Types to serialize from the collection, keyed by version number. Only used the
    /// first time a collection is created for this base document type; later instances reuse the
    /// discriminator that was registered then.</param>
    public VersionedCollection(MongoService mongoService, string collectionName, RedisService redisService, RedisKey baseKey, IReadOnlyDictionary<int, Type> types)
    {
        _redisConnectionPool = redisService.ConnectionPool;
        _baseKey = baseKey;

        lock (s_registeredTypes)
        {
            Type type = typeof(VersionedDocument<TId, TLatest>);
            if (s_registeredTypes.Add(type))
            {
                VersionedDocumentDiscriminator discriminator = new VersionedDocumentDiscriminator(types);
                BsonSerializer.RegisterDiscriminatorConvention(type, discriminator);
            }
        }

        BaseCollection = mongoService.Database.GetCollection<VersionedDocument<TId, TLatest>>(collectionName);
        LatestCollection = BaseCollection.OfType<TLatest>();
    }

    /// <summary>
    /// Gets the cache key for a document: the base key followed by the string form of the id
    /// </summary>
    private RedisKey GetDocKey(TId id) => _baseKey.Append(id.ToString());

    /// <summary>
    /// Creates a filter matching any version of the document with the given id
    /// </summary>
    private static FilterDefinition<VersionedDocument<TId, TLatest>> GetFilter(TId id)
    {
        return FilterBuilder.Eq(x => x.Id, id);
    }

    /// <summary>
    /// Creates a filter matching the given document only if it has not been modified since it was read
    /// </summary>
    private static FilterDefinition<VersionedDocument<TId, TLatest>> GetFilter(VersionedDocument<TId, TLatest> doc)
    {
        return FilterBuilder.Eq(x => x.Id, doc.Id) & FilterBuilder.Eq(x => x.LastUpdateTime, doc.LastUpdateTime);
    }

    /// <summary>
    /// Creates a filter on the latest-version collection matching the given document only if it has not
    /// been modified since it was read
    /// </summary>
    private static FilterDefinition<TLatest> GetFilter(TLatest doc)
    {
        FilterDefinitionBuilder<TLatest> builder = Builders<TLatest>.Filter;
        return builder.Eq(x => x.Id, doc.Id) & builder.Eq(x => x.LastUpdateTime, doc.LastUpdateTime);
    }

    /// <summary>
    /// Computes the timestamp to store for a modification of a document last updated at the given time.
    /// </summary>
    /// <remarks>
    /// BSON dates have millisecond precision, so the result is aligned to a millisecond and is at least
    /// one millisecond later than the previous value; a smaller increment could be truncated away when
    /// stored and leave the timestamp unchanged. The result is explicitly UTC so that it is not
    /// reinterpreted as local time during serialization.
    /// </remarks>
    /// <param name="prevTime">Last update time of the document being modified</param>
    /// <returns>New last update time</returns>
    private static DateTime GetNextUpdateTime(DateTime prevTime)
    {
        long ticks = Math.Max(prevTime.Ticks + TimeSpan.TicksPerMillisecond, DateTime.UtcNow.Ticks);
        ticks -= ticks % TimeSpan.TicksPerMillisecond;
        return new DateTime(ticks, DateTimeKind.Utc);
    }

    /// <summary>
    /// Adds a new document to the collection
    /// </summary>
    /// <param name="doc">The document to add. Its last update time is set to the current time.</param>
    /// <returns>True if the document was added, false if it already exists</returns>
    public async Task<bool> AddAsync(TLatest doc)
    {
        doc.LastUpdateTime = MongoExtensions.RoundToBsonDateTime(DateTime.UtcNow);
        if (!await BaseCollection.InsertOneIgnoreDuplicatesAsync(doc))
        {
            return false;
        }

        AddCachedValue(GetDocKey(doc.Id), doc);
        return true;
    }

    /// <summary>
    /// Creates the pair of cache entries for a document: the serialized document under its key, and its
    /// last update time (in ticks) under the key with the time suffix.
    /// </summary>
    private static KeyValuePair<RedisKey, RedisValue>[] CreateCacheEntries(RedisKey docKey, TLatest doc)
    {
        KeyValuePair<RedisKey, RedisValue>[] pairs = new KeyValuePair<RedisKey, RedisValue>[2];
        pairs[0] = new KeyValuePair<RedisKey, RedisValue>(docKey, doc.ToBson(typeof(VersionedDocument<TId, TLatest>)));
        pairs[1] = new KeyValuePair<RedisKey, RedisValue>(docKey.Append(s_timeSuffix), doc.LastUpdateTime.Ticks);
        return pairs;
    }

    /// <summary>
    /// Adds a document to the cache if there is no entry for it yet. The write is fire-and-forget.
    /// </summary>
    private void AddCachedValue(RedisKey docKey, TLatest doc)
    {
        KeyValuePair<RedisKey, RedisValue>[] pairs = CreateCacheEntries(docKey, doc);
        _ = _redisConnectionPool.GetDatabase().StringSetAsync(pairs, When.NotExists, flags: CommandFlags.FireAndForget);
    }

    /// <summary>
    /// Reads a document from the cache
    /// </summary>
    /// <param name="docKey">Cache key for the document</param>
    /// <returns>The cached document, or null if there is no entry or the entry could not be deserialized</returns>
    private async ValueTask<VersionedDocument<TId, TLatest>?> GetCachedValueAsync(RedisKey docKey)
    {
        RedisValue cacheValue = await _redisConnectionPool.GetDatabase().StringGetAsync(docKey);
        if (!cacheValue.IsNull)
        {
            try
            {
                return BsonSerializer.Deserialize<VersionedDocument<TId, TLatest>>((byte[])cacheValue!);
            }
            catch
            {
                // Deliberately best-effort: an unreadable entry is discarded and treated as a cache miss
                await DeleteCachedValueAsync(docKey);
            }
        }
        return null;
    }

    /// <summary>
    /// Replaces the cached copy of a document, provided the cache still holds the previous version
    /// </summary>
    /// <param name="docKey">Cache key for the document</param>
    /// <param name="prevDoc">The version of the document that was modified</param>
    /// <param name="doc">The new version of the document</param>
    /// <returns>True if the cache was updated; false if it did not hold the previous version, in which
    /// case the cache entry is deleted instead</returns>
    private async ValueTask<bool> UpdateCachedValueAsync(RedisKey docKey, VersionedDocument<TId, TLatest> prevDoc, TLatest doc)
    {
        ITransaction transaction = _redisConnectionPool.GetDatabase().CreateTransaction();
        transaction.AddCondition(Condition.StringEqual(docKey.Append(s_timeSuffix), prevDoc.LastUpdateTime.Ticks));
        _ = transaction.StringSetAsync(CreateCacheEntries(docKey, doc), flags: CommandFlags.FireAndForget);

        if (!await transaction.ExecuteAsync())
        {
            await DeleteCachedValueAsync(docKey);
            return false;
        }
        return true;
    }

    /// <summary>
    /// Removes a document and its timestamp from the cache
    /// </summary>
    private async ValueTask DeleteCachedValueAsync(RedisKey docKey)
    {
        await _redisConnectionPool.GetDatabase().KeyDeleteAsync(new[] { docKey, docKey.Append(s_timeSuffix) });
    }

    /// <summary>
    /// Gets a document with the given id. Documents stored as an older version are upgraded and written
    /// back before being returned.
    /// </summary>
    /// <param name="id">The document id to look for</param>
    /// <returns>The matching document, or null if it does not exist</returns>
    public async Task<TLatest?> GetAsync(TId id)
    {
        RedisKey docKey = GetDocKey(id);
        for (; ; )
        {
            // Attempt to get the cached value for this key
            VersionedDocument<TId, TLatest>? doc = await GetCachedValueAsync(docKey);
            if (doc == null)
            {
                // Read the value from the database
                doc = await BaseCollection.Find(GetFilter(id)).FirstOrDefaultAsync();
                if (doc == null)
                {
                    return null;
                }
                if (doc is TLatest latestDoc)
                {
                    AddCachedValue(docKey, latestDoc);
                    return latestDoc;
                }
            }
            else if (doc is TLatest latestCachedDoc)
            {
                return latestCachedDoc;
            }

            // The document is an older version. Upgrade it and try to store the result; if the stored
            // document has changed in the meantime, ReplaceAsync drops the cache entry and we start over.
            TLatest upgradedDoc = doc.UpgradeToLatest();
            if (await ReplaceAsync(doc, upgradedDoc))
            {
                return upgradedDoc;
            }
        }
    }

    /// <summary>
    /// Gets an existing document or creates one with the given callback
    /// </summary>
    /// <param name="id">The document id to look for</param>
    /// <param name="factory">Factory method used to create a new document if need be. Called at most once.</param>
    /// <returns>The existing document, or the document that was inserted</returns>
    public async Task<TLatest> FindOrAddAsync(TId id, Func<TLatest> factory)
    {
        TLatest? newDoc = null;
        for (; ; )
        {
            // Try to get an existing document
            TLatest? latest = await GetAsync(id);
            if (latest != null)
            {
                return latest;
            }

            // Create a new document and try to add it. If another writer added one first, loop and fetch it.
            newDoc ??= factory();
            if (await AddAsync(newDoc))
            {
                return newDoc;
            }
        }
    }

    /// <summary>
    /// Attempt to update the given document. Fails if the document has been modified from the version presented.
    /// </summary>
    /// <param name="doc">The current document version</param>
    /// <param name="update">Update to be applied</param>
    /// <returns>The updated document, or null if the stored document no longer matches the version presented</returns>
    public async Task<TLatest?> UpdateAsync(TLatest doc, UpdateDefinition<TLatest> update)
    {
        update = update.Set(x => x.LastUpdateTime, GetNextUpdateTime(doc.LastUpdateTime));

        TLatest? newDoc = await LatestCollection.FindOneAndUpdateAsync(GetFilter(doc), update, new FindOneAndUpdateOptions<TLatest> { ReturnDocument = ReturnDocument.After });
        if (newDoc != null)
        {
            await UpdateCachedValueAsync(GetDocKey(doc.Id), doc, newDoc);
        }
        else
        {
            // The version presented is out of date. If it came from the cache, the cache is stale too;
            // drop the entry so that a retry reads the current document from the database.
            await DeleteCachedValueAsync(GetDocKey(doc.Id));
        }
        return newDoc;
    }

    /// <summary>
    /// Replaces an existing document with a new one
    /// </summary>
    /// <param name="oldDoc">The old document</param>
    /// <param name="newDoc">The new document. Its last update time is overwritten.</param>
    /// <returns>True if the document was replaced</returns>
    /// <exception cref="InvalidOperationException">The two documents have different ids</exception>
    public async Task<bool> ReplaceAsync(VersionedDocument<TId, TLatest> oldDoc, TLatest newDoc)
    {
        if (!oldDoc.Id.Equals(newDoc.Id))
        {
            throw new InvalidOperationException("Id for new document must match old document");
        }

        newDoc.LastUpdateTime = GetNextUpdateTime(oldDoc.LastUpdateTime);

        RedisKey docKey = GetDocKey(oldDoc.Id);

        ReplaceOneResult result = await BaseCollection.ReplaceOneAsync(GetFilter(oldDoc), newDoc);
        if (result.ModifiedCount == 0)
        {
            // The old document is out of date. If it came from the cache, the cache is stale too;
            // drop the entry so that a retry reads the current document from the database.
            await DeleteCachedValueAsync(docKey);
            return false;
        }

        await UpdateCachedValueAsync(docKey, oldDoc, newDoc);
        return true;
    }

    /// <summary>
    /// Delete the given document, provided it has not been modified from the version presented
    /// </summary>
    /// <param name="doc">The document to delete</param>
    /// <returns>True if the document was deleted, or false if no matching document could be found</returns>
    public async Task<bool> DeleteAsync(TLatest doc)
    {
        RedisKey docKey = GetDocKey(doc.Id);
        DeleteResult result = await BaseCollection.DeleteOneAsync(GetFilter((VersionedDocument<TId, TLatest>)doc));

        // Drop the cache entry in either case: on failure the version presented is out of date, so any
        // cached copy it came from is stale as well.
        await DeleteCachedValueAsync(docKey);
        return result.DeletedCount > 0;
    }

    /// <summary>
    /// Delete a document with the given identifier
    /// </summary>
    /// <param name="id">Id of the document to delete</param>
    /// <returns>True if the document was deleted, or false if it could not be found</returns>
    public async Task<bool> DeleteAsync(TId id)
    {
        RedisKey docKey = GetDocKey(id);
        DeleteResult result = await BaseCollection.DeleteOneAsync(GetFilter(id));
        await DeleteCachedValueAsync(docKey);
        return result.DeletedCount > 0;
    }
}
}