From 9507520e40223356541e7e4b37eabdfd2c16d171 Mon Sep 17 00:00:00 2001 From: Pieter-Jan Briers Date: Wed, 26 Jul 2023 02:03:41 +0200 Subject: [PATCH] Better synchronous IAsyncEnumerable handling. (#18296) --- Content.Server/Database/ServerDbManager.cs | 53 ++++++++-------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/Content.Server/Database/ServerDbManager.cs b/Content.Server/Database/ServerDbManager.cs index 42d0bc153b..e8bb38edb4 100644 --- a/Content.Server/Database/ServerDbManager.cs +++ b/Content.Server/Database/ServerDbManager.cs @@ -890,32 +890,10 @@ namespace Content.Server.Database private IAsyncEnumerable RunDbCommand(Func> command) { var enumerable = command(); - if (!_synchronous) - return enumerable; + if (_synchronous) + return new SyncAsyncEnumerable(enumerable); - // IAsyncEnumerable must be drained synchronously and returned as a fake async enumerable. - // If we were to let it go through like normal, it'd do a bunch of bad async stuff and break everything. - - var results = new List(); - var enumerator = enumerable.GetAsyncEnumerator(); - - while (true) - { - var result = enumerator.MoveNextAsync(); - if (!result.IsCompleted) - { - throw new InvalidOperationException( - "Database async enumerable is running asynchronously. " + - $"This should be impossible when the database is set to synchronous. Count: {results.Count}"); - } - - if (!result.Result) - break; - - results.Add(enumerator.Current); - } - - return new FakeAsyncEnumerable(results); + return enumerable; } private DbContextOptions CreatePostgresOptions() @@ -1048,38 +1026,45 @@ namespace Content.Server.Database public sealed record PlayTimeUpdate(NetUserId User, string Tracker, TimeSpan Time); - internal sealed class FakeAsyncEnumerable : IAsyncEnumerable + internal sealed class SyncAsyncEnumerable : IAsyncEnumerable { - private readonly IEnumerable _enumerable; + private readonly IAsyncEnumerable _enumerable; - public FakeAsyncEnumerable(IEnumerable enumerable) + public SyncAsyncEnumerable(IAsyncEnumerable enumerable) { _enumerable = enumerable; } public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new Enumerator(_enumerable.GetEnumerator()); + return new Enumerator(_enumerable.GetAsyncEnumerator(cancellationToken)); } private sealed class Enumerator : IAsyncEnumerator { - private readonly IEnumerator _enumerator; + private readonly IAsyncEnumerator _enumerator; - public Enumerator(IEnumerator enumerator) + public Enumerator(IAsyncEnumerator enumerator) { _enumerator = enumerator; } public ValueTask DisposeAsync() { - _enumerator.Dispose(); - return ValueTask.CompletedTask; + var task = _enumerator.DisposeAsync(); + if (!task.IsCompleted) + throw new InvalidOperationException("DisposeAsync did not complete synchronously."); + + return task; } public ValueTask MoveNextAsync() { - return new ValueTask(_enumerator.MoveNext()); + var task = _enumerator.MoveNextAsync(); + if (!task.IsCompleted) + throw new InvalidOperationException("MoveNextAsync did not complete synchronously."); + + return task; } public T Current => _enumerator.Current;