From b00b69234f4078a10f23194079a9016cce61fad4 Mon Sep 17 00:00:00 2001 From: RogueException Date: Sun, 29 Jan 2017 23:15:48 -0400 Subject: [PATCH] Users can no longer directly request user downloads. --- .../DiscordShardedClient.cs | 14 ++-- .../DiscordSocketClient.cs | 70 +++++++++---------- .../Entities/Guilds/SocketGuild.cs | 2 +- 3 files changed, 40 insertions(+), 46 deletions(-) diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 3a8f90990..832e35578 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -112,12 +112,12 @@ namespace Discord.WebSocket } /// - public async Task ConnectAsync(bool waitForGuilds = true) + public async Task ConnectAsync() { await _connectionLock.WaitAsync().ConfigureAwait(false); try { - await ConnectInternalAsync(waitForGuilds).ConfigureAwait(false); + await ConnectInternalAsync().ConfigureAwait(false); } catch { @@ -126,10 +126,10 @@ namespace Discord.WebSocket } finally { _connectionLock.Release(); } } - private async Task ConnectInternalAsync(bool waitForGuilds) + private async Task ConnectInternalAsync() { await Task.WhenAll( - _shards.Select(x => x.ConnectAsync(waitForGuilds)) + _shards.Select(x => x.ConnectAsync()) ).ConfigureAwait(false); CurrentUser = _shards[0].CurrentUser; @@ -253,12 +253,6 @@ namespace Discord.WebSocket public RestVoiceRegion GetVoiceRegion(string id) => _shards[0].GetVoiceRegion(id); - /// Downloads the users list for all large guilds. - public async Task DownloadAllUsersAsync() - { - for (int i = 0; i < _shards.Length; i++) - await _shards[i].DownloadAllUsersAsync().ConfigureAwait(false); - } /// Downloads the users list for the provided guilds, if they don't have a complete list. public async Task DownloadUsersAsync(IEnumerable guilds) { diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 6f48a23e7..0641672d2 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -42,6 +42,7 @@ namespace Discord.WebSocket private bool _canReconnect; private DateTimeOffset? _statusSince; private RestApplication _applicationInfo; + private ConcurrentHashSet _downloadUsersFor; /// Gets the shard of of this client. public int ShardId { get; } @@ -61,7 +62,7 @@ namespace Discord.WebSocket internal int ConnectionTimeout { get; private set; } internal UdpSocketProvider UdpSocketProvider { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; } - internal bool DownloadUsersOnGuildAvailable { get; private set; } + internal bool AlwaysDownloadUsers { get; private set; } internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient; public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } } @@ -84,9 +85,10 @@ namespace Discord.WebSocket AudioMode = config.AudioMode; UdpSocketProvider = config.UdpSocketProvider; WebSocketProvider = config.WebSocketProvider; - DownloadUsersOnGuildAvailable = config.DownloadUsersOnGuildAvailable; + AlwaysDownloadUsers = config.AlwaysDownloadUsers; ConnectionTimeout = config.ConnectionTimeout; State = new ClientState(0, 0); + _downloadUsersFor = new ConcurrentHashSet(); _nextAudioId = 1; _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId); @@ -119,14 +121,17 @@ namespace Discord.WebSocket GuildUnavailable += async g => await _gatewayLogger.VerboseAsync($"Disconnected from {g.Name}").ConfigureAwait(false); LatencyUpdated += async (old, val) => await _gatewayLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false); - if (DownloadUsersOnGuildAvailable) + GuildAvailable += g => { - GuildAvailable += g => + if (ConnectionState == ConnectionState.Connected && (AlwaysDownloadUsers || _downloadUsersFor.ContainsKey(g.Id))) { - var _ = g.DownloadUsersAsync(); - return Task.Delay(0); - }; - } + if (!g.HasAllMembers) + { + var _ = g.DownloadUsersAsync(); + } + } + return Task.Delay(0); + }; _voiceRegions = ImmutableDictionary.Create(); _largeGuilds = new ConcurrentQueue(); @@ -151,10 +156,11 @@ namespace Discord.WebSocket _applicationInfo = null; _voiceRegions = ImmutableDictionary.Create(); + _downloadUsersFor.Clear(); } /// - public async Task ConnectAsync(bool waitForGuilds = true) + public async Task ConnectAsync() { await _connectionLock.WaitAsync().ConfigureAwait(false); try @@ -162,13 +168,6 @@ namespace Discord.WebSocket await ConnectInternalAsync(false).ConfigureAwait(false); } finally { _connectionLock.Release(); } - - if (waitForGuilds) - { - var downloadTask = _guildDownloadTask; - if (downloadTask != null) - await _guildDownloadTask.ConfigureAwait(false); - } } private async Task ConnectInternalAsync(bool isReconnecting) { @@ -227,6 +226,8 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); ConnectionState = ConnectionState.Connected; await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false); + + await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x)).Where(x => x != null).ToImmutableArray()).ConfigureAwait(false); } catch (Exception) { @@ -442,31 +443,23 @@ namespace Discord.WebSocket return null; } - /// Downloads the users list for all large guilds. - public Task DownloadAllUsersAsync() - => DownloadUsersAsync(State.Guilds.Where(x => !x.HasAllMembers)); /// Downloads the users list for the provided guilds, if they don't have a complete list. - public async Task DownloadUsersAsync(IEnumerable guilds) + public async Task DownloadUsersAsync(IEnumerable guilds) { - var cachedGuilds = guilds.ToImmutableArray(); - if (cachedGuilds.Length == 0) return; + foreach (var guild in guilds) + _downloadUsersFor.TryAdd(guild.Id); - //Wait for unsynced guilds to sync first. - var unsyncedGuilds = guilds.Select(x => x.SyncPromise).Where(x => !x.IsCompleted).ToImmutableArray(); - if (unsyncedGuilds.Length > 0) - await Task.WhenAll(unsyncedGuilds).ConfigureAwait(false); - - //Download offline members - const short batchSize = 50; - - if (cachedGuilds.Length == 1) + if (ConnectionState == ConnectionState.Connected) { - if (!cachedGuilds[0].HasAllMembers) - await ApiClient.SendRequestMembersAsync(new ulong[] { cachedGuilds[0].Id }).ConfigureAwait(false); - await cachedGuilds[0].DownloaderPromise.ConfigureAwait(false); - return; + //Race condition leads to guilds being requested twice, probably okay + await ProcessUserDownloadsAsync(guilds.Select(x => GetGuild(x.Id)).Where(x => x != null)).ConfigureAwait(false); } + } + private async Task ProcessUserDownloadsAsync(IEnumerable guilds) + { + var cachedGuilds = guilds.ToImmutableArray(); + const short batchSize = 50; ulong[] batchIds = new ulong[Math.Min(batchSize, cachedGuilds.Length)]; Task[] batchTasks = new Task[batchIds.Length]; int batchCount = (cachedGuilds.Length + (batchSize - 1)) / batchSize; @@ -795,6 +788,7 @@ namespace Discord.WebSocket { await _gatewayLogger.DebugAsync($"Received Dispatch (GUILD_DELETE)").ConfigureAwait(false); + _downloadUsersFor.TryRemove(data.Id); var guild = RemoveGuild(data.Id); if (guild != null) { @@ -1728,6 +1722,12 @@ namespace Discord.WebSocket await logger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false); } } + public async Task WaitForGuildsAsync() + { + var downloadTask = _guildDownloadTask; + if (downloadTask != null) + await _guildDownloadTask.ConfigureAwait(false); + } private async Task WaitForGuildsAsync(CancellationToken cancelToken, Logger logger) { //Wait for GUILD_AVAILABLEs diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index c1bec756c..22a4c2a71 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -55,7 +55,7 @@ namespace Discord.WebSocket public SocketTextChannel DefaultChannel => GetTextChannel(Id); public string IconUrl => CDN.GetGuildIconUrl(Id, IconId); public string SplashUrl => CDN.GetGuildSplashUrl(Id, SplashId); - public bool HasAllMembers => _downloaderPromise.Task.IsCompleted; + public bool HasAllMembers => MemberCount == DownloadedMemberCount;// _downloaderPromise.Task.IsCompleted; public bool IsSynced => _syncPromise.Task.IsCompleted; public Task SyncPromise => _syncPromise.Task; public Task DownloaderPromise => _downloaderPromise.Task;