diff --git a/src/Discord.Net.Core/Utils/AsyncEvent.cs b/src/Discord.Net.Core/Utils/AsyncEvent.cs index 12a1fba9c..731489dea 100644 --- a/src/Discord.Net.Core/Utils/AsyncEvent.cs +++ b/src/Discord.Net.Core/Utils/AsyncEvent.cs @@ -11,6 +11,7 @@ namespace Discord private readonly object _subLock = new object(); internal ImmutableArray _subscriptions; + public bool HasSubscribers => _subscriptions.Length != 0; public IReadOnlyList Subscriptions => _subscriptions; public AsyncEvent() diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 093339028..76e943ce4 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -21,6 +21,8 @@ namespace Discord.WebSocket { public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient { + private const int HandlerTimeoutMillis = 3000; + private readonly ConcurrentQueue _largeGuilds; private readonly JsonSerializer _serializer; private readonly SemaphoreSlim _connectionGroupLock; @@ -57,6 +59,7 @@ namespace Discord.WebSocket internal UdpSocketProvider UdpSocketProvider { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; } internal bool AlwaysDownloadUsers { get; private set; } + internal bool EnableHandlerTimeouts { get; private set; } internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient; public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } } @@ -83,6 +86,7 @@ namespace Discord.WebSocket UdpSocketProvider = config.UdpSocketProvider; WebSocketProvider = config.WebSocketProvider; AlwaysDownloadUsers = config.AlwaysDownloadUsers; + EnableHandlerTimeouts = config.EnableHandlerTimeouts; State = new ClientState(0, 0); _heartbeatTimes = new ConcurrentQueue(); @@ -90,8 +94,8 @@ namespace Discord.WebSocket _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : $"Shard #{ShardId}"); _connection = new ConnectionManager(_stateLock, _gatewayLogger, config.ConnectionTimeout, OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x); - _connection.Connected += () => _connectedEvent.InvokeAsync(); - _connection.Disconnected += (ex, recon) => _disconnectedEvent.InvokeAsync(ex); + _connection.Connected += () => TimedInvokeAsync(_connectedEvent, nameof(Connected)); + _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); _nextAudioId = 1; _connectionGroupLock = groupLock; @@ -422,7 +426,7 @@ namespace Discord.WebSocket int before = Latency; Latency = latency; - await _latencyUpdatedEvent.InvokeAsync(before, latency).ConfigureAwait(false); + await TimedInvokeAsync(_latencyUpdatedEvent, nameof(LatencyUpdated), before, latency).ConfigureAwait(false); } } break; @@ -500,7 +504,7 @@ namespace Discord.WebSocket else if (_connection.CancelToken.IsCancellationRequested) return; - await _readyEvent.InvokeAsync().ConfigureAwait(false); + await TimedInvokeAsync(_readyEvent, nameof(Ready)).ConfigureAwait(false); await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false); }); var _ = _connection.CompleteAsync(); @@ -559,7 +563,7 @@ namespace Discord.WebSocket { if (ApiClient.AuthTokenType == TokenType.User) await SyncGuildsAsync().ConfigureAwait(false); - await _joinedGuildEvent.InvokeAsync(guild).ConfigureAwait(false); + await TimedInvokeAsync(_joinedGuildEvent, nameof(JoinedGuild), guild).ConfigureAwait(false); } else { @@ -579,7 +583,7 @@ namespace Discord.WebSocket { var before = guild.Clone(); guild.Update(State, data); - await _guildUpdatedEvent.InvokeAsync(before, guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildUpdatedEvent, nameof(GuildUpdated), before, guild).ConfigureAwait(false); } else { @@ -598,7 +602,7 @@ namespace Discord.WebSocket { var before = guild.Clone(); guild.Update(State, data); - await _guildUpdatedEvent.InvokeAsync(before, guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildUpdatedEvent, nameof(GuildUpdated), before, guild).ConfigureAwait(false); } else { @@ -620,7 +624,7 @@ namespace Discord.WebSocket _unavailableGuilds--; _lastGuildAvailableTime = Environment.TickCount; await GuildAvailableAsync(guild).ConfigureAwait(false); - await _guildUpdatedEvent.InvokeAsync(before, guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildUpdatedEvent, nameof(GuildUpdated), before, guild).ConfigureAwait(false); } else { @@ -657,7 +661,7 @@ namespace Discord.WebSocket if (guild != null) { await GuildUnavailableAsync(guild).ConfigureAwait(false); - await _leftGuildEvent.InvokeAsync(guild).ConfigureAwait(false); + await TimedInvokeAsync(_leftGuildEvent, nameof(LeftGuild), guild).ConfigureAwait(false); } else { @@ -698,7 +702,7 @@ namespace Discord.WebSocket channel = AddPrivateChannel(data, State) as SocketChannel; if (channel != null) - await _channelCreatedEvent.InvokeAsync(channel).ConfigureAwait(false); + await TimedInvokeAsync(_channelCreatedEvent, nameof(ChannelCreated), channel).ConfigureAwait(false); } break; case "CHANNEL_UPDATE": @@ -718,7 +722,7 @@ namespace Discord.WebSocket return; } - await _channelUpdatedEvent.InvokeAsync(before, channel).ConfigureAwait(false); + await TimedInvokeAsync(_channelUpdatedEvent, nameof(ChannelUpdated), before, channel).ConfigureAwait(false); } else { @@ -756,7 +760,7 @@ namespace Discord.WebSocket channel = RemovePrivateChannel(data.Id) as SocketChannel; if (channel != null) - await _channelDestroyedEvent.InvokeAsync(channel).ConfigureAwait(false); + await TimedInvokeAsync(_channelDestroyedEvent, nameof(ChannelDestroyed), channel).ConfigureAwait(false); else { await _gatewayLogger.WarningAsync("CHANNEL_DELETE referenced an unknown channel.").ConfigureAwait(false); @@ -783,7 +787,7 @@ namespace Discord.WebSocket return; } - await _userJoinedEvent.InvokeAsync(user).ConfigureAwait(false); + await TimedInvokeAsync(_userJoinedEvent, nameof(UserJoined), user).ConfigureAwait(false); } else { @@ -812,7 +816,7 @@ namespace Discord.WebSocket { var before = user.Clone(); user.Update(State, data); - await _guildMemberUpdatedEvent.InvokeAsync(before, user).ConfigureAwait(false); + await TimedInvokeAsync(_guildMemberUpdatedEvent, nameof(GuildMemberUpdated), before, user).ConfigureAwait(false); } else { @@ -851,7 +855,7 @@ namespace Discord.WebSocket } if (user != null) - await _userLeftEvent.InvokeAsync(user).ConfigureAwait(false); + await TimedInvokeAsync(_userLeftEvent, nameof(UserLeft), user).ConfigureAwait(false); else { if (!guild.HasAllMembers) @@ -885,7 +889,7 @@ namespace Discord.WebSocket if (guild.DownloadedMemberCount >= guild.MemberCount) //Finished downloading for there { guild.CompleteDownloadUsers(); - await _guildMembersDownloadedEvent.InvokeAsync(guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildMembersDownloadedEvent, nameof(GuildMembersDownloaded), guild).ConfigureAwait(false); } } else @@ -904,7 +908,7 @@ namespace Discord.WebSocket if (channel != null) { var user = channel.AddUser(data.User); - await _recipientAddedEvent.InvokeAsync(user).ConfigureAwait(false); + await TimedInvokeAsync(_recipientAddedEvent, nameof(RecipientAdded), user).ConfigureAwait(false); } else { @@ -923,7 +927,7 @@ namespace Discord.WebSocket { var user = channel.RemoveUser(data.User.Id); if (user != null) - await _recipientRemovedEvent.InvokeAsync(user).ConfigureAwait(false); + await TimedInvokeAsync(_recipientRemovedEvent, nameof(RecipientRemoved), user).ConfigureAwait(false); else { await _gatewayLogger.WarningAsync("CHANNEL_RECIPIENT_REMOVE referenced an unknown user.").ConfigureAwait(false); @@ -954,7 +958,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Ignored GUILD_ROLE_CREATE, guild is not synced yet.").ConfigureAwait(false); return; } - await _roleCreatedEvent.InvokeAsync(role).ConfigureAwait(false); + await TimedInvokeAsync(_roleCreatedEvent, nameof(RoleCreated), role).ConfigureAwait(false); } else { @@ -983,7 +987,7 @@ namespace Discord.WebSocket return; } - await _roleUpdatedEvent.InvokeAsync(before, role).ConfigureAwait(false); + await TimedInvokeAsync(_roleUpdatedEvent, nameof(RoleUpdated), before, role).ConfigureAwait(false); } else { @@ -1015,7 +1019,7 @@ namespace Discord.WebSocket return; } - await _roleDeletedEvent.InvokeAsync(role).ConfigureAwait(false); + await TimedInvokeAsync(_roleDeletedEvent, nameof(RoleDeleted), role).ConfigureAwait(false); } else { @@ -1049,7 +1053,7 @@ namespace Discord.WebSocket SocketUser user = guild.GetUser(data.User.Id); if (user == null) user = SocketUnknownUser.Create(this, State, data.User); - await _userBannedEvent.InvokeAsync(user, guild).ConfigureAwait(false); + await TimedInvokeAsync(_userBannedEvent, nameof(UserBanned), user, guild).ConfigureAwait(false); } else { @@ -1075,7 +1079,7 @@ namespace Discord.WebSocket SocketUser user = State.GetUser(data.User.Id); if (user == null) user = SocketUnknownUser.Create(this, State, data.User); - await _userUnbannedEvent.InvokeAsync(user, guild).ConfigureAwait(false); + await TimedInvokeAsync(_userUnbannedEvent, nameof(UserUnbanned), user, guild).ConfigureAwait(false); } else { @@ -1116,7 +1120,7 @@ namespace Discord.WebSocket { var msg = SocketMessage.Create(this, State, author, channel, data); SocketChannelHelper.AddMessage(channel, this, msg); - await _messageReceivedEvent.InvokeAsync(msg).ConfigureAwait(false); + await TimedInvokeAsync(_messageReceivedEvent, nameof(MessageReceived), msg).ConfigureAwait(false); } else { @@ -1170,7 +1174,7 @@ namespace Discord.WebSocket } var cacheableBefore = new Cacheable(before, data.Id, isCached , async () => await channel.GetMessageAsync(data.Id)); - await _messageUpdatedEvent.InvokeAsync(cacheableBefore, after, channel).ConfigureAwait(false); + await TimedInvokeAsync(_messageUpdatedEvent, nameof(MessageUpdated), cacheableBefore, after, channel).ConfigureAwait(false); } else { @@ -1197,7 +1201,7 @@ namespace Discord.WebSocket bool isCached = msg != null; var cacheable = new Cacheable(msg, data.Id, isCached, async () => await channel.GetMessageAsync(data.Id)); - await _messageDeletedEvent.InvokeAsync(cacheable, channel).ConfigureAwait(false); + await TimedInvokeAsync(_messageDeletedEvent, nameof(MessageDeleted), cacheable, channel).ConfigureAwait(false); } else { @@ -1222,7 +1226,7 @@ namespace Discord.WebSocket cachedMsg?.AddReaction(reaction); - await _reactionAddedEvent.InvokeAsync(cacheable, channel, reaction).ConfigureAwait(false); + await TimedInvokeAsync(_reactionAddedEvent, nameof(ReactionAdded), cacheable, channel, reaction).ConfigureAwait(false); } else { @@ -1247,7 +1251,7 @@ namespace Discord.WebSocket cachedMsg?.RemoveReaction(reaction); - await _reactionRemovedEvent.InvokeAsync(cacheable, channel, reaction).ConfigureAwait(false); + await TimedInvokeAsync(_reactionRemovedEvent, nameof(ReactionRemoved), cacheable, channel, reaction).ConfigureAwait(false); } else { @@ -1270,7 +1274,7 @@ namespace Discord.WebSocket cachedMsg?.ClearReactions(); - await _reactionsClearedEvent.InvokeAsync(cacheable, channel).ConfigureAwait(false); + await TimedInvokeAsync(_reactionsClearedEvent, nameof(ReactionsCleared), cacheable, channel).ConfigureAwait(false); } else { @@ -1298,7 +1302,7 @@ namespace Discord.WebSocket var msg = SocketChannelHelper.RemoveMessage(channel, this, id); bool isCached = msg != null; var cacheable = new Cacheable(msg, id, isCached, async () => await channel.GetMessageAsync(id)); - await _messageDeletedEvent.InvokeAsync(cacheable, channel).ConfigureAwait(false); + await TimedInvokeAsync(_messageDeletedEvent, nameof(MessageDeleted), cacheable, channel).ConfigureAwait(false); } } else @@ -1339,7 +1343,7 @@ namespace Discord.WebSocket var before = globalUser.Clone(); globalUser.Update(State, data); - await _userUpdatedEvent.InvokeAsync(before, globalUser).ConfigureAwait(false); + await TimedInvokeAsync(_userUpdatedEvent, nameof(UserUpdated), before, globalUser).ConfigureAwait(false); } break; case "TYPING_START": @@ -1358,7 +1362,7 @@ namespace Discord.WebSocket var user = (channel as SocketChannel).GetUser(data.UserId); if (user != null) - await _userIsTypingEvent.InvokeAsync(user, channel).ConfigureAwait(false); + await TimedInvokeAsync(_userIsTypingEvent, nameof(UserIsTyping), user, channel).ConfigureAwait(false); } } break; @@ -1373,7 +1377,7 @@ namespace Discord.WebSocket { var before = CurrentUser.Clone(); CurrentUser.Update(State, data); - await _selfUpdatedEvent.InvokeAsync(before, CurrentUser).ConfigureAwait(false); + await TimedInvokeAsync(_selfUpdatedEvent, nameof(CurrentUserUpdated), before, CurrentUser).ConfigureAwait(false); } else { @@ -1449,7 +1453,7 @@ namespace Discord.WebSocket } if (user != null) - await _userVoiceStateUpdatedEvent.InvokeAsync(user, before, after).ConfigureAwait(false); + await TimedInvokeAsync(_userVoiceStateUpdatedEvent, nameof(UserVoiceStateUpdated), user, before, after).ConfigureAwait(false); else { await _gatewayLogger.WarningAsync("VOICE_STATE_UPDATE referenced an unknown user.").ConfigureAwait(false); @@ -1631,7 +1635,7 @@ namespace Discord.WebSocket if (!guild.IsConnected) { guild.IsConnected = true; - await _guildAvailableEvent.InvokeAsync(guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildAvailableEvent, nameof(GuildAvailable), guild).ConfigureAwait(false); } } private async Task GuildUnavailableAsync(SocketGuild guild) @@ -1639,7 +1643,85 @@ namespace Discord.WebSocket if (guild.IsConnected) { guild.IsConnected = false; - await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false); + await TimedInvokeAsync(_guildUnavailableEvent, nameof(GuildUnavailable), guild).ConfigureAwait(false); + } + } + + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync()).ConfigureAwait(false); + else + await eventHandler.InvokeAsync().ConfigureAwait(false); + } + } + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name, T arg) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync(arg)).ConfigureAwait(false); + else + await eventHandler.InvokeAsync(arg).ConfigureAwait(false); + } + } + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name, T1 arg1, T2 arg2) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync(arg1, arg2)).ConfigureAwait(false); + else + await eventHandler.InvokeAsync(arg1, arg2).ConfigureAwait(false); + } + } + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name, T1 arg1, T2 arg2, T3 arg3) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync(arg1, arg2, arg3)).ConfigureAwait(false); + else + await eventHandler.InvokeAsync(arg1, arg2, arg3).ConfigureAwait(false); + } + } + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name, T1 arg1, T2 arg2, T3 arg3, T4 arg4) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync(arg1, arg2, arg3, arg4)).ConfigureAwait(false); + else + await eventHandler.InvokeAsync(arg1, arg2, arg3, arg4).ConfigureAwait(false); + } + } + private async Task TimedInvokeAsync(AsyncEvent> eventHandler, string name, T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5) + { + if (eventHandler.HasSubscribers) + { + if (EnableHandlerTimeouts) + await TimeoutWrap(name, () => eventHandler.InvokeAsync(arg1, arg2, arg3, arg4, arg5)).ConfigureAwait(false); + else + await eventHandler.InvokeAsync(arg1, arg2, arg3, arg4, arg5).ConfigureAwait(false); + } + } + private async Task TimeoutWrap(string name, Func action) + { + try + { + var timeoutTask = Task.Delay(HandlerTimeoutMillis); + var handlersTask = action(); + if (await Task.WhenAny(timeoutTask, handlersTask).ConfigureAwait(false) == timeoutTask) + { + await _gatewayLogger.WarningAsync($"A {name} handler is blocking the gateway task.").ConfigureAwait(false); + await handlersTask.ConfigureAwait(false); //Ensure the handler completes + } + } + catch (Exception ex) + { + await _gatewayLogger.WarningAsync($"A {name} handler has thrown an unhandled exception.", ex).ConfigureAwait(false); } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 9ef030d72..add42ce80 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -33,6 +33,8 @@ namespace Discord.WebSocket /// Gets or sets whether or not all users should be downloaded as guilds come available. public bool AlwaysDownloadUsers { get; set; } = false; + /// Gets or sets whether or not warnings should be logged if an event handler is taking too long to execute. + public bool EnableHandlerTimeouts { get; set; } = true; public DiscordSocketConfig() {