diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index 3a8f90990..832e35578 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -112,12 +112,12 @@ namespace Discord.WebSocket
}
///
- public async Task ConnectAsync(bool waitForGuilds = true)
+ public async Task ConnectAsync()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
- await ConnectInternalAsync(waitForGuilds).ConfigureAwait(false);
+ await ConnectInternalAsync().ConfigureAwait(false);
}
catch
{
@@ -126,10 +126,10 @@ namespace Discord.WebSocket
}
finally { _connectionLock.Release(); }
}
- private async Task ConnectInternalAsync(bool waitForGuilds)
+ private async Task ConnectInternalAsync()
{
await Task.WhenAll(
- _shards.Select(x => x.ConnectAsync(waitForGuilds))
+ _shards.Select(x => x.ConnectAsync())
).ConfigureAwait(false);
CurrentUser = _shards[0].CurrentUser;
@@ -253,12 +253,6 @@ namespace Discord.WebSocket
public RestVoiceRegion GetVoiceRegion(string id)
=> _shards[0].GetVoiceRegion(id);
- /// Downloads the users list for all large guilds.
- public async Task DownloadAllUsersAsync()
- {
- for (int i = 0; i < _shards.Length; i++)
- await _shards[i].DownloadAllUsersAsync().ConfigureAwait(false);
- }
/// Downloads the users list for the provided guilds, if they don't have a complete list.
public async Task DownloadUsersAsync(IEnumerable guilds)
{
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index 6f48a23e7..0641672d2 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -42,6 +42,7 @@ namespace Discord.WebSocket
private bool _canReconnect;
private DateTimeOffset? _statusSince;
private RestApplication _applicationInfo;
+ private ConcurrentHashSet _downloadUsersFor;
/// Gets the shard of of this client.
public int ShardId { get; }
@@ -61,7 +62,7 @@ namespace Discord.WebSocket
internal int ConnectionTimeout { get; private set; }
internal UdpSocketProvider UdpSocketProvider { get; private set; }
internal WebSocketProvider WebSocketProvider { get; private set; }
- internal bool DownloadUsersOnGuildAvailable { get; private set; }
+ internal bool AlwaysDownloadUsers { get; private set; }
internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;
public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } }
@@ -84,9 +85,10 @@ namespace Discord.WebSocket
AudioMode = config.AudioMode;
UdpSocketProvider = config.UdpSocketProvider;
WebSocketProvider = config.WebSocketProvider;
- DownloadUsersOnGuildAvailable = config.DownloadUsersOnGuildAvailable;
+ AlwaysDownloadUsers = config.AlwaysDownloadUsers;
ConnectionTimeout = config.ConnectionTimeout;
State = new ClientState(0, 0);
+ _downloadUsersFor = new ConcurrentHashSet();
_nextAudioId = 1;
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
@@ -119,14 +121,17 @@ namespace Discord.WebSocket
GuildUnavailable += async g => await _gatewayLogger.VerboseAsync($"Disconnected from {g.Name}").ConfigureAwait(false);
LatencyUpdated += async (old, val) => await _gatewayLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false);
- if (DownloadUsersOnGuildAvailable)
+ GuildAvailable += g =>
{
- GuildAvailable += g =>
+ if (ConnectionState == ConnectionState.Connected && (AlwaysDownloadUsers || _downloadUsersFor.ContainsKey(g.Id)))
{
- var _ = g.DownloadUsersAsync();
- return Task.Delay(0);
- };
- }
+ if (!g.HasAllMembers)
+ {
+ var _ = g.DownloadUsersAsync();
+ }
+ }
+ return Task.Delay(0);
+ };
_voiceRegions = ImmutableDictionary.Create();
_largeGuilds = new ConcurrentQueue();
@@ -151,10 +156,11 @@ namespace Discord.WebSocket
_applicationInfo = null;
_voiceRegions = ImmutableDictionary.Create();
+ _downloadUsersFor.Clear();
}
///
- public async Task ConnectAsync(bool waitForGuilds = true)
+ public async Task ConnectAsync()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
@@ -162,13 +168,6 @@ namespace Discord.WebSocket
await ConnectInternalAsync(false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
-
- if (waitForGuilds)
- {
- var downloadTask = _guildDownloadTask;
- if (downloadTask != null)
- await _guildDownloadTask.ConfigureAwait(false);
- }
}
private async Task ConnectInternalAsync(bool isReconnecting)
{
@@ -227,6 +226,8 @@ namespace Discord.WebSocket
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
+
+ await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x)).Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
}
catch (Exception)
{
@@ -442,31 +443,23 @@ namespace Discord.WebSocket
return null;
}
- /// Downloads the users list for all large guilds.
- public Task DownloadAllUsersAsync()
- => DownloadUsersAsync(State.Guilds.Where(x => !x.HasAllMembers));
/// Downloads the users list for the provided guilds, if they don't have a complete list.
- public async Task DownloadUsersAsync(IEnumerable guilds)
+ public async Task DownloadUsersAsync(IEnumerable guilds)
{
- var cachedGuilds = guilds.ToImmutableArray();
- if (cachedGuilds.Length == 0) return;
+ foreach (var guild in guilds)
+ _downloadUsersFor.TryAdd(guild.Id);
- //Wait for unsynced guilds to sync first.
- var unsyncedGuilds = guilds.Select(x => x.SyncPromise).Where(x => !x.IsCompleted).ToImmutableArray();
- if (unsyncedGuilds.Length > 0)
- await Task.WhenAll(unsyncedGuilds).ConfigureAwait(false);
-
- //Download offline members
- const short batchSize = 50;
-
- if (cachedGuilds.Length == 1)
+ if (ConnectionState == ConnectionState.Connected)
{
- if (!cachedGuilds[0].HasAllMembers)
- await ApiClient.SendRequestMembersAsync(new ulong[] { cachedGuilds[0].Id }).ConfigureAwait(false);
- await cachedGuilds[0].DownloaderPromise.ConfigureAwait(false);
- return;
+ //Race condition leads to guilds being requested twice, probably okay
+ await ProcessUserDownloadsAsync(guilds.Select(x => GetGuild(x.Id)).Where(x => x != null)).ConfigureAwait(false);
}
+ }
+ private async Task ProcessUserDownloadsAsync(IEnumerable guilds)
+ {
+ var cachedGuilds = guilds.ToImmutableArray();
+ const short batchSize = 50;
ulong[] batchIds = new ulong[Math.Min(batchSize, cachedGuilds.Length)];
Task[] batchTasks = new Task[batchIds.Length];
int batchCount = (cachedGuilds.Length + (batchSize - 1)) / batchSize;
@@ -795,6 +788,7 @@ namespace Discord.WebSocket
{
await _gatewayLogger.DebugAsync($"Received Dispatch (GUILD_DELETE)").ConfigureAwait(false);
+ _downloadUsersFor.TryRemove(data.Id);
var guild = RemoveGuild(data.Id);
if (guild != null)
{
@@ -1728,6 +1722,12 @@ namespace Discord.WebSocket
await logger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
}
+ public async Task WaitForGuildsAsync()
+ {
+ var downloadTask = _guildDownloadTask;
+ if (downloadTask != null)
+ await _guildDownloadTask.ConfigureAwait(false);
+ }
private async Task WaitForGuildsAsync(CancellationToken cancelToken, Logger logger)
{
//Wait for GUILD_AVAILABLEs
diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
index c1bec756c..22a4c2a71 100644
--- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
+++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
@@ -55,7 +55,7 @@ namespace Discord.WebSocket
public SocketTextChannel DefaultChannel => GetTextChannel(Id);
public string IconUrl => CDN.GetGuildIconUrl(Id, IconId);
public string SplashUrl => CDN.GetGuildSplashUrl(Id, SplashId);
- public bool HasAllMembers => _downloaderPromise.Task.IsCompleted;
+ public bool HasAllMembers => MemberCount == DownloadedMemberCount;// _downloaderPromise.Task.IsCompleted;
public bool IsSynced => _syncPromise.Task.IsCompleted;
public Task SyncPromise => _syncPromise.Task;
public Task DownloaderPromise => _downloaderPromise.Task;