diff --git a/src/Discord.Net/DiscordSocketClient.Events.cs b/src/Discord.Net/DiscordSocketClient.Events.cs index 092b19674..698a940e4 100644 --- a/src/Discord.Net/DiscordSocketClient.Events.cs +++ b/src/Discord.Net/DiscordSocketClient.Events.cs @@ -167,12 +167,12 @@ namespace Discord remove { _userPresenceUpdatedEvent.Remove(value); } } private readonly AsyncEvent> _userPresenceUpdatedEvent = new AsyncEvent>(); - public event Func UserVoiceStateUpdated + public event Func UserVoiceStateUpdated { add { _userVoiceStateUpdatedEvent.Add(value); } remove { _userVoiceStateUpdatedEvent.Remove(value); } } - private readonly AsyncEvent> _userVoiceStateUpdatedEvent = new AsyncEvent>(); + private readonly AsyncEvent> _userVoiceStateUpdatedEvent = new AsyncEvent>(); public event Func CurrentUserUpdated { add { _selfUpdatedEvent.Add(value); } diff --git a/src/Discord.Net/DiscordSocketClient.cs b/src/Discord.Net/DiscordSocketClient.cs index 9ac32f122..7375d1088 100644 --- a/src/Discord.Net/DiscordSocketClient.cs +++ b/src/Discord.Net/DiscordSocketClient.cs @@ -1280,39 +1280,66 @@ namespace Discord var data = (payload as JToken).ToObject(_serializer); if (data.GuildId.HasValue) { - var guild = DataStore.GetGuild(data.GuildId.Value); - if (guild != null) + ICachedUser user; + VoiceState before, after; + if (data.GuildId != null) { - if (!guild.IsSynced) + var guild = DataStore.GetGuild(data.GuildId.Value); + if (guild != null) { - await _gatewayLogger.DebugAsync("Ignored VOICE_STATE_UPDATE, guild is not synced yet.").ConfigureAwait(false); - return; + if (!guild.IsSynced) + { + await _gatewayLogger.DebugAsync("Ignored VOICE_STATE_UPDATE, guild is not synced yet.").ConfigureAwait(false); + return; + } + + if (data.ChannelId != null) + { + before = guild.GetVoiceState(data.UserId)?.Clone() ?? new VoiceState(null, null, false, false, false); + after = guild.AddOrUpdateVoiceState(data, DataStore); + } + else + { + before = guild.RemoveVoiceState(data.UserId) ?? new VoiceState(null, null, false, false, false); + after = new VoiceState(null, data); + } + user = guild.GetUser(data.UserId); } - - VoiceState before, after; - if (data.ChannelId != null) + else { - before = guild.GetVoiceState(data.UserId)?.Clone() ?? new VoiceState(null, null, false, false, false); - after = guild.AddOrUpdateVoiceState(data, DataStore); + await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown guild.").ConfigureAwait(false); + return; } - else + } + else + { + var groupChannel = DataStore.GetChannel(data.ChannelId.Value) as CachedGroupChannel; + if (groupChannel != null) { - before = guild.RemoveVoiceState(data.UserId) ?? new VoiceState(null, null, false, false, false); - after = new VoiceState(null, data); + if (data.ChannelId != null) + { + before = groupChannel.GetVoiceState(data.UserId)?.Clone() ?? new VoiceState(null, null, false, false, false); + after = groupChannel.AddOrUpdateVoiceState(data, DataStore); + } + else + { + before = groupChannel.RemoveVoiceState(data.UserId) ?? new VoiceState(null, null, false, false, false); + after = new VoiceState(null, data); + } + user = groupChannel.GetUser(data.UserId); } - - var user = guild.GetUser(data.UserId); - if (user != null) - await _userVoiceStateUpdatedEvent.InvokeAsync(user, before, after).ConfigureAwait(false); else { - await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown user.").ConfigureAwait(false); + await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown channel.").ConfigureAwait(false); return; } } + + if (user != null) + await _userVoiceStateUpdatedEvent.InvokeAsync(user, before, after).ConfigureAwait(false); else { - await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown guild.").ConfigureAwait(false); + await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown user.").ConfigureAwait(false); return; } } diff --git a/src/Discord.Net/Entities/Channels/GroupChannel.cs b/src/Discord.Net/Entities/Channels/GroupChannel.cs index 65dad771c..a11b58800 100644 --- a/src/Discord.Net/Entities/Channels/GroupChannel.cs +++ b/src/Discord.Net/Entities/Channels/GroupChannel.cs @@ -30,10 +30,10 @@ namespace Discord { Discord = discord; _users = recipients; - + Update(model, UpdateSource.Creation); } - public void Update(Model model, UpdateSource source) + public virtual void Update(Model model, UpdateSource source) { if (source == UpdateSource.Rest && IsAttached) return; @@ -41,7 +41,7 @@ namespace Discord Name = model.Name.Value; if (model.Icon.IsSpecified) _iconId = model.Icon.Value; - + if (source != UpdateSource.Creation && model.Recipients.IsSpecified) UpdateUsers(model.Recipients.Value, source); } diff --git a/src/Discord.Net/Entities/WebSocket/CachedGroupChannel.cs b/src/Discord.Net/Entities/WebSocket/CachedGroupChannel.cs index da94a5b31..5320727fa 100644 --- a/src/Discord.Net/Entities/WebSocket/CachedGroupChannel.cs +++ b/src/Discord.Net/Entities/WebSocket/CachedGroupChannel.cs @@ -6,16 +6,18 @@ using System.Linq; using System.Threading.Tasks; using MessageModel = Discord.API.Message; using Model = Discord.API.Channel; +using VoiceStateModel = Discord.API.VoiceState; namespace Discord { internal class CachedGroupChannel : GroupChannel, IGroupChannel, ICachedChannel, ICachedMessageChannel, ICachedPrivateChannel { private readonly MessageManager _messages; + private ConcurrentDictionary _voiceStates; public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient; public IReadOnlyCollection Members - => _users.Select(x => x.Value).Concat(ImmutableArray.Create(Discord.CurrentUser)).Cast().ToReadOnlyCollection(_users, 1); + => _users.Select(x => x.Value).Concat(ImmutableArray.Create(Discord.CurrentUser)).Cast().ToReadOnlyCollection(() => _users.Count + 1); public new IReadOnlyCollection Recipients => _users.Cast().ToReadOnlyCollection(_users); public CachedGroupChannel(DiscordSocketClient discord, ConcurrentDictionary recipients, Model model) @@ -25,6 +27,13 @@ namespace Discord _messages = new MessageCache(Discord, this); else _messages = new MessageManager(Discord, this); + _voiceStates = new ConcurrentDictionary(1, 5); + } + public override void Update(Model model, UpdateSource source) + { + if (source == UpdateSource.Rest && IsAttached) return; + + base.Update(model, source); } protected override void UpdateUsers(API.User[] models, UpdateSource source) @@ -35,6 +44,38 @@ namespace Discord _users = users; } + public ICachedUser GetUser(ulong id) + { + IUser user; + if (_users.TryGetValue(id, out user)) + return user as ICachedUser; + if (id == Discord.CurrentUser.Id) + return Discord.CurrentUser; + return null; + } + + public VoiceState AddOrUpdateVoiceState(VoiceStateModel model, DataStore dataStore, ConcurrentDictionary voiceStates = null) + { + var voiceChannel = dataStore.GetChannel(model.ChannelId.Value) as CachedVoiceChannel; + var voiceState = new VoiceState(voiceChannel, model); + (voiceStates ?? _voiceStates)[model.UserId] = voiceState; + return voiceState; + } + public VoiceState? GetVoiceState(ulong id) + { + VoiceState voiceState; + if (_voiceStates.TryGetValue(id, out voiceState)) + return voiceState; + return null; + } + public VoiceState? RemoveVoiceState(ulong id) + { + VoiceState voiceState; + if (_voiceStates.TryRemove(id, out voiceState)) + return voiceState; + return null; + } + public override async Task GetMessageAsync(ulong id) { return await _messages.DownloadAsync(id).ConfigureAwait(false); diff --git a/src/Discord.Net/Extensions/CollectionExtensions.cs b/src/Discord.Net/Extensions/CollectionExtensions.cs index d2cef8a64..91c9f030f 100644 --- a/src/Discord.Net/Extensions/CollectionExtensions.cs +++ b/src/Discord.Net/Extensions/CollectionExtensions.cs @@ -10,8 +10,8 @@ namespace Discord.Extensions { public static IReadOnlyCollection ToReadOnlyCollection(this IReadOnlyDictionary source) => new ConcurrentDictionaryWrapper(source.Select(x => x.Value), () => source.Count); - public static IReadOnlyCollection ToReadOnlyCollection(this IEnumerable query, IReadOnlyCollection source, int countOffset = 0) - => new ConcurrentDictionaryWrapper(query, () => source.Count + countOffset); + public static IReadOnlyCollection ToReadOnlyCollection(this IEnumerable query, IReadOnlyCollection source) + => new ConcurrentDictionaryWrapper(query, () => source.Count); public static IReadOnlyCollection ToReadOnlyCollection(this IEnumerable query, Func countFunc) => new ConcurrentDictionaryWrapper(query, countFunc); }