From 5bbd9bba8254e3b4ba21cf359955fd091a94d8a6 Mon Sep 17 00:00:00 2001 From: ObsidianMinor Date: Sat, 6 Jan 2018 21:43:11 -0600 Subject: [PATCH] Renamed existing Flatten method to FlattenAsync and added new Flatten method. Also fixed ClientHelper using incorrect guild batch count. (#744) --- .../Readers/UserTypeReader.cs | 14 ++--- .../Extensions/AsyncEnumerableExtensions.cs | 54 ++++++++++++++++++- src/Discord.Net.Rest/ClientHelper.cs | 4 +- .../Entities/Guilds/RestGuild.cs | 2 +- 4 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/Discord.Net.Commands/Readers/UserTypeReader.cs b/src/Discord.Net.Commands/Readers/UserTypeReader.cs index ca337aaf6..8fc330d4c 100644 --- a/src/Discord.Net.Commands/Readers/UserTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/UserTypeReader.cs @@ -13,7 +13,7 @@ namespace Discord.Commands public override async Task ReadAsync(ICommandContext context, string input, IServiceProvider services) { var results = new Dictionary(); - IReadOnlyCollection channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way? + IAsyncEnumerable channelUsers = context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten(); // it's better IReadOnlyCollection guildUsers = ImmutableArray.Create(); ulong id; @@ -45,7 +45,7 @@ namespace Discord.Commands string username = input.Substring(0, index); if (ushort.TryParse(input.Substring(index + 1), out ushort discriminator)) { - var channelUser = channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && + var channelUser = await channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)); AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f); @@ -57,8 +57,9 @@ namespace Discord.Commands //By Username (0.5-0.6) { - foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) - AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f); + await channelUsers + .Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)) + .ForEachAsync(channelUser => AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f)); foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); @@ -66,8 +67,9 @@ namespace Discord.Commands //By Nickname (0.5-0.6) { - foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase))) - AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f); + await channelUsers + .Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase)) + .ForEachAsync(channelUser => AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f)); foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); diff --git a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs index f52edd719..345154f1d 100644 --- a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs +++ b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs @@ -1,14 +1,64 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace Discord { public static class AsyncEnumerableExtensions { - public static async Task> Flatten(this IAsyncEnumerable> source) + /// + /// Flattens the specified pages into one asynchronously + /// + /// + /// + /// + public static async Task> FlattenAsync(this IAsyncEnumerable> source) { - return (await source.ToArray().ConfigureAwait(false)).SelectMany(x => x); + return await source.Flatten().ToArray().ConfigureAwait(false); + } + + public static IAsyncEnumerable Flatten(this IAsyncEnumerable> source) + { + return new PagedCollectionEnumerator(source); + } + + internal class PagedCollectionEnumerator : IAsyncEnumerator, IAsyncEnumerable + { + readonly IAsyncEnumerator> _source; + IEnumerator _enumerator; + + public IAsyncEnumerator GetEnumerator() => this; + + internal PagedCollectionEnumerator(IAsyncEnumerable> source) + { + _source = source.GetEnumerator(); + } + + public T Current => _enumerator.Current; + + public void Dispose() + { + _enumerator?.Dispose(); + _source.Dispose(); + } + + public async Task MoveNext(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if(!_enumerator?.MoveNext() ?? true) + { + if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false)) + return false; + + _enumerator?.Dispose(); + _enumerator = _source.Current.GetEnumerator(); + return _enumerator.MoveNext(); + } + + return true; + } } } } diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index 26d8c720e..5c9e26433 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -79,7 +79,7 @@ namespace Discord.Rest ulong? fromGuildId, int? limit, RequestOptions options) { return new PagedAsyncEnumerable( - DiscordConfig.MaxUsersPerBatch, + DiscordConfig.MaxGuildsPerBatch, async (info, ct) => { var args = new GetGuildSummariesParams @@ -106,7 +106,7 @@ namespace Discord.Rest } public static async Task> GetGuildsAsync(BaseDiscordClient client, RequestOptions options) { - var summaryModels = await GetGuildSummariesAsync(client, null, null, options).Flatten(); + var summaryModels = await GetGuildSummariesAsync(client, null, null, options).FlattenAsync().ConfigureAwait(false); var guilds = ImmutableArray.CreateBuilder(); foreach (var summaryModel in summaryModels) { diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index 76ddc07ca..5d12731a6 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -413,7 +413,7 @@ namespace Discord.Rest async Task> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options) { if (mode == CacheMode.AllowDownload) - return (await GetUsersAsync(options).Flatten().ConfigureAwait(false)).ToImmutableArray(); + return (await GetUsersAsync(options).FlattenAsync().ConfigureAwait(false)).ToImmutableArray(); else return ImmutableArray.Create(); }