diff --git a/src/Discord.Net/API/DiscordAPIClient.cs b/src/Discord.Net/API/DiscordAPIClient.cs index 128c15474..c2a439b06 100644 --- a/src/Discord.Net/API/DiscordAPIClient.cs +++ b/src/Discord.Net/API/DiscordAPIClient.cs @@ -409,6 +409,10 @@ namespace Discord.API }; await SendGatewayAsync(GatewayOpCode.VoiceStateUpdate, payload, options: options).ConfigureAwait(false); } + public async Task SendGuildSyncAsync(IEnumerable guildIds, RequestOptions options = null) + { + await SendGatewayAsync(GatewayOpCode.GuildSync, guildIds, options: options).ConfigureAwait(false); + } //Channels public async Task GetChannelAsync(ulong channelId, RequestOptions options = null) diff --git a/src/Discord.Net/API/Gateway/GatewayOpCode.cs b/src/Discord.Net/API/Gateway/GatewayOpCode.cs index dc69d073c..3c2a3382b 100644 --- a/src/Discord.Net/API/Gateway/GatewayOpCode.cs +++ b/src/Discord.Net/API/Gateway/GatewayOpCode.cs @@ -18,13 +18,15 @@ Resume = 6, /// C←S - Used to notify a client that they must reconnect to another gateway. Reconnect = 7, - /// C→S - Used to request all members that were withheld by large_threshold + /// C→S - Used to request members that were withheld by large_threshold RequestGuildMembers = 8, /// C←S - Used to notify the client that their session has expired and cannot be resumed. InvalidSession = 9, /// C←S - Used to provide information to the client immediately on connection. Hello = 10, /// C←S - Used to reply to a client's heartbeat. - HeartbeatAck = 11 + HeartbeatAck = 11, + /// C→S - Used to request presence updates from particular guilds. + GuildSync = 12 } } diff --git a/src/Discord.Net/API/Gateway/GuildSyncEvent.cs b/src/Discord.Net/API/Gateway/GuildSyncEvent.cs new file mode 100644 index 000000000..ff290f23a --- /dev/null +++ b/src/Discord.Net/API/Gateway/GuildSyncEvent.cs @@ -0,0 +1,17 @@ +using Newtonsoft.Json; + +namespace Discord.API.Gateway +{ + public class GuildSyncEvent + { + [JsonProperty("id")] + public ulong Id { get; set; } + [JsonProperty("large")] + public bool Large { get; set; } + + [JsonProperty("presences")] + public Presence[] Presences { get; set; } + [JsonProperty("members")] + public GuildMember[] Members { get; set; } + } +} diff --git a/src/Discord.Net/DiscordSocketClient.cs b/src/Discord.Net/DiscordSocketClient.cs index 82289fe95..3b2ac8137 100644 --- a/src/Discord.Net/DiscordSocketClient.cs +++ b/src/Discord.Net/DiscordSocketClient.cs @@ -29,6 +29,8 @@ namespace Discord private int _lastSeq; private ImmutableDictionary _voiceRegions; private TaskCompletionSource _connectTask; + private ConcurrentHashSet _syncedGuilds; + private SemaphoreSlim _syncedGuildsLock; private CancellationTokenSource _cancelToken; private Task _heartbeatTask, _guildDownloadTask, _reconnectTask; private long _heartbeatTime; @@ -102,6 +104,8 @@ namespace Discord _voiceRegions = ImmutableDictionary.Create(); _largeGuilds = new ConcurrentQueue(); + _syncedGuilds = new ConcurrentHashSet(); + _syncedGuildsLock = new SemaphoreSlim(1, 1); } protected override async Task OnLoginAsync() @@ -295,7 +299,7 @@ namespace Discord { return Task.FromResult>(Guilds); } - internal CachedGuild AddGuild(API.Gateway.ExtendedGuild model, DataStore dataStore) + internal CachedGuild AddGuild(ExtendedGuild model, DataStore dataStore) { var guild = new CachedGuild(this, model, dataStore); dataStore.AddGuild(guild); @@ -305,6 +309,7 @@ namespace Discord } internal CachedGuild RemoveGuild(ulong id) { + _syncedGuilds.TryRemove(id); var guild = DataStore.RemoveGuild(id); foreach (var channel in guild.Channels) guild.RemoveChannel(channel.Id); @@ -363,18 +368,47 @@ namespace Discord } /// Downloads the users list for all large guilds. - public Task DownloadAllUsersAsync() + public Task DownloadAllUsersAsync() => DownloadUsersAsync(DataStore.Guilds.Where(x => !x.HasAllMembers)); /// Downloads the users list for the provided guilds, if they don't have a complete list. - public async Task DownloadUsersAsync(IEnumerable guilds) + public Task DownloadUsersAsync(IEnumerable guilds) + => DownloadUsersAsync(guilds.Select(x => x as CachedGuild).Where(x => x != null)); + public Task DownloadUsersAsync(params IGuild[] guilds) + => DownloadUsersAsync(guilds.Select(x => x as CachedGuild).Where(x => x != null)); + private async Task DownloadUsersAsync(IEnumerable guilds) { + var cachedGuilds = guilds.ToArray(); + if (cachedGuilds.Length == 0) return; + + //Sync guilds + if (ApiClient.AuthTokenType == TokenType.User) + { + await _syncedGuildsLock.WaitAsync().ConfigureAwait(false); + try + { + foreach (var guild in cachedGuilds) + _syncedGuilds.TryAdd(guild.Id); + await ApiClient.SendGuildSyncAsync(_syncedGuilds).ConfigureAwait(false); + await Task.WhenAll(cachedGuilds.Select(x => x.SyncPromise)); + + //Reduce the list only to those with members left to download + cachedGuilds = cachedGuilds.Where(x => !x.HasAllMembers).ToArray(); + if (cachedGuilds.Length == 0) return; + } + finally + { + _syncedGuildsLock.Release(); + } + } + + //Download offline members const short batchSize = 50; - var cachedGuilds = guilds.Select(x => x as CachedGuild).ToArray(); - if (cachedGuilds.Length == 0) - return; - else if (cachedGuilds.Length == 1) + + if (cachedGuilds.Length == 1) { - await cachedGuilds[0].DownloadUsersAsync().ConfigureAwait(false); + if (!cachedGuilds[0].HasAllMembers) + await ApiClient.SendRequestMembersAsync(new ulong[] { cachedGuilds[0].Id }).ConfigureAwait(false); + await cachedGuilds[0].DownloaderPromise.ConfigureAwait(false); return; } @@ -502,6 +536,15 @@ namespace Discord _currentUser = currentUser; _unavailableGuilds = unavailableGuilds; _lastGuildAvailableTime = Environment.TickCount; + await _syncedGuildsLock.WaitAsync().ConfigureAwait(false); + try + { + _syncedGuilds = new ConcurrentHashSet(); + } + finally + { + _syncedGuildsLock.Release(); + } DataStore = dataStore; _guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token); @@ -513,9 +556,11 @@ namespace Discord } break; case "RESUMED": - await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false); + { + await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false); - await _gatewayLogger.InfoAsync("Resume").ConfigureAwait(false); + await _gatewayLogger.InfoAsync("Resumed previous session").ConfigureAwait(false); + } return; //Guilds @@ -579,9 +624,9 @@ namespace Discord } } break; - case "GUILD_EMOJI_UPDATE": //TODO: Add + case "GUILD_EMOJIS_UPDATE": //TODO: Add { - await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_EMOJI_UPDATE)").ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_EMOJIS_UPDATE)").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); var guild = DataStore.GetGuild(data.GuildId); @@ -593,13 +638,33 @@ namespace Discord } else { - await _gatewayLogger.WarningAsync("GUILD_EMOJI_UPDATE referenced an unknown guild.").ConfigureAwait(false); + await _gatewayLogger.WarningAsync("GUILD_EMOJIS_UPDATE referenced an unknown guild.").ConfigureAwait(false); return; } } return; case "GUILD_INTEGRATIONS_UPDATE": - await _gatewayLogger.DebugAsync("Ignored Dispatch (GUILD_INTEGRATIONS_UPDATE)").ConfigureAwait(false); + { + await _gatewayLogger.DebugAsync("Ignored Dispatch (GUILD_INTEGRATIONS_UPDATE)").ConfigureAwait(false); + } + return; + case "GUILD_SYNC": + { + await _gatewayLogger.DebugAsync("Received Dispatch (GUILD_SYNC)").ConfigureAwait(false); + var data = (payload as JToken).ToObject(_serializer); + var guild = DataStore.GetGuild(data.Id); + if (guild != null) + { + var before = guild.Clone(); + guild.Update(data, UpdateSource.WebSocket, DataStore); + await _guildUpdatedEvent.InvokeAsync(before, guild).ConfigureAwait(false); + } + else + { + await _gatewayLogger.WarningAsync("GUILD_SYNC referenced an unknown guild.").ConfigureAwait(false); + return; + } + } return; case "GUILD_DELETE": { diff --git a/src/Discord.Net/Entities/WebSocket/CachedGuild.cs b/src/Discord.Net/Entities/WebSocket/CachedGuild.cs index e1641be03..88db46763 100644 --- a/src/Discord.Net/Entities/WebSocket/CachedGuild.cs +++ b/src/Discord.Net/Entities/WebSocket/CachedGuild.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using ChannelModel = Discord.API.Channel; using EmojiUpdateModel = Discord.API.Gateway.GuildEmojiUpdateEvent; using ExtendedModel = Discord.API.Gateway.ExtendedGuild; +using GuildSyncModel = Discord.API.Gateway.GuildSyncEvent; using MemberModel = Discord.API.GuildMember; using Model = Discord.API.Guild; using PresenceModel = Discord.API.Presence; @@ -18,10 +19,16 @@ using VoiceStateModel = Discord.API.VoiceState; namespace Discord { + internal enum MemberDownloadState + { + Incomplete, + Synced, + Complete + } internal class CachedGuild : Guild, ICachedEntity, IGuild, IUserGuild { private readonly SemaphoreSlim _audioLock; - private TaskCompletionSource _downloaderPromise; + private TaskCompletionSource _syncPromise, _downloaderPromise; private ConcurrentHashSet _channels; private ConcurrentDictionary _members; private ConcurrentDictionary _voiceStates; @@ -29,9 +36,11 @@ namespace Discord public bool Available { get; private set; } public int MemberCount { get; private set; } public int DownloadedMemberCount { get; private set; } - public AudioClient AudioClient { get; private set; } + public AudioClient AudioClient { get; private set; } + public MemberDownloadState MemberDownloadState { get; private set; } public bool HasAllMembers => _downloaderPromise.Task.IsCompleted; + public Task SyncPromise => _syncPromise.Task; public Task DownloaderPromise => _downloaderPromise.Task; public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient; @@ -51,6 +60,7 @@ namespace Discord public CachedGuild(DiscordSocketClient discord, ExtendedModel model, DataStore dataStore) : base(discord, model) { _audioLock = new SemaphoreSlim(1, 1); + _syncPromise = new TaskCompletionSource(); _downloaderPromise = new TaskCompletionSource(); Update(model, UpdateSource.Creation, dataStore); } @@ -91,9 +101,12 @@ namespace Discord DownloadedMemberCount = 0; for (int i = 0; i < model.Members.Length; i++) AddUser(model.Members[i], dataStore, members); - _downloaderPromise = new TaskCompletionSource(); - if (!model.Large) - _downloaderPromise.SetResult(true); + if (Discord.ApiClient.AuthTokenType != TokenType.User) + { + _syncPromise.TrySetResult(true); + if (!model.Large) + _downloaderPromise.TrySetResult(true); + } for (int i = 0; i < model.Presences.Length; i++) AddOrUpdateUser(model.Presences[i], dataStore, members); @@ -107,6 +120,24 @@ namespace Discord } _voiceStates = voiceStates; } + public void Update(GuildSyncModel model, UpdateSource source, DataStore dataStore) + { + if (source == UpdateSource.Rest && IsAttached) return; + + var members = new ConcurrentDictionary(1, (int)(model.Presences.Length * 1.05)); + { + DownloadedMemberCount = 0; + for (int i = 0; i < model.Members.Length; i++) + AddUser(model.Members[i], dataStore, members); + _syncPromise.TrySetResult(true); + if (!model.Large) + _downloaderPromise.TrySetResult(true); + + for (int i = 0; i < model.Presences.Length; i++) + AddOrUpdateUser(model.Presences[i], dataStore, members); + } + _members = members; + } public void Update(EmojiUpdateModel model, UpdateSource source) { @@ -208,9 +239,7 @@ namespace Discord } public override async Task DownloadUsersAsync() { - if (!HasAllMembers) - await Discord.ApiClient.SendRequestMembersAsync(new ulong[] { Id }).ConfigureAwait(false); - await _downloaderPromise.Task.ConfigureAwait(false); + await Discord.DownloadUsersAsync(new [] { this }); } public void CompleteDownloadMembers() {