From 203265cb652177b2c577d62f673efc8ee68f5f07 Mon Sep 17 00:00:00 2001 From: RogueException Date: Wed, 25 Jan 2017 12:21:58 -0400 Subject: [PATCH] Cleaned up ShardedClient, delayed connections --- ...ception.cs => WebSocketClosedException.cs} | 0 .../DiscordShardedClient.cs | 47 +++++--- .../DiscordSocketClient.cs | 106 ++++++++++-------- 3 files changed, 95 insertions(+), 58 deletions(-) rename src/Discord.Net.Core/Net/{WebSocketException.cs => WebSocketClosedException.cs} (100%) diff --git a/src/Discord.Net.Core/Net/WebSocketException.cs b/src/Discord.Net.Core/Net/WebSocketClosedException.cs similarity index 100% rename from src/Discord.Net.Core/Net/WebSocketException.cs rename to src/Discord.Net.Core/Net/WebSocketClosedException.cs diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 4bda2b479..a32c46f10 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -5,12 +5,14 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Threading.Tasks; +using System.Threading; namespace Discord.WebSocket { public partial class DiscordShardedClient : BaseDiscordClient, IDiscordClient { private readonly DiscordSocketConfig _baseConfig; + private readonly SemaphoreSlim _connectionGroupLock; private int[] _shardIds; private Dictionary _shardIdsToIndex; private DiscordSocketClient[] _shards; @@ -18,9 +20,9 @@ namespace Discord.WebSocket private bool _automaticShards; /// Gets the estimated round-trip latency, in milliseconds, to the gateway server. - public int Latency { get; private set; } - internal UserStatus Status => _shards[0].Status; - internal Game? Game => _shards[0].Game; + public int Latency => GetLatency(); + public UserStatus Status => _shards[0].Status; + public Game? Game => _shards[0].Game; internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient; public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } } @@ -48,6 +50,7 @@ namespace Discord.WebSocket _shardIdsToIndex = new Dictionary(); config.DisplayInitialLog = false; _baseConfig = config; + _connectionGroupLock = new SemaphoreSlim(1, 1); if (config.TotalShards == null) _automaticShards = true; @@ -61,7 +64,7 @@ namespace Discord.WebSocket _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - _shards[i] = new DiscordSocketClient(newConfig); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock); RegisterEvents(_shards[i]); } } @@ -83,7 +86,7 @@ namespace Discord.WebSocket var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - _shards[i] = new DiscordSocketClient(newConfig); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock); RegisterEvents(_shards[i]); } } @@ -125,12 +128,11 @@ namespace Discord.WebSocket } private async Task ConnectInternalAsync(bool waitForGuilds) { - for (int i = 0; i < _shards.Length; i++) - { - await _shards[i].ConnectAsync(waitForGuilds).ConfigureAwait(false); - if (i == 0) - CurrentUser = _shards[i].CurrentUser; - } + await Task.WhenAll( + _shards.Select(x => x.ConnectAsync(waitForGuilds)) + ).ConfigureAwait(false); + + CurrentUser = _shards[0].CurrentUser; } /// public async Task DisconnectAsync() @@ -156,11 +158,11 @@ namespace Discord.WebSocket } private int GetShardIdFor(ulong guildId) => (int)((guildId >> 22) % (uint)_totalShards); - private int GetShardIdFor(IGuild guild) + public int GetShardIdFor(IGuild guild) => GetShardIdFor(guild.Id); private DiscordSocketClient GetShardFor(ulong guildId) => GetShard(GetShardIdFor(guildId)); - private DiscordSocketClient GetShardFor(IGuild guild) + public DiscordSocketClient GetShardFor(IGuild guild) => GetShardFor(guild.Id); /// @@ -269,6 +271,14 @@ namespace Discord.WebSocket } } + private int GetLatency() + { + int total = 0; + for (int i = 0; i < _shards.Length; i++) + total += _shards[i].Latency; + return (int)Math.Round(total / (double)_shards.Length); + } + public async Task SetStatusAsync(UserStatus status) { for (int i = 0; i < _shards.Length; i++) @@ -283,6 +293,17 @@ namespace Discord.WebSocket private void RegisterEvents(DiscordSocketClient client) { client.Log += (msg) => _logEvent.InvokeAsync(msg); + client.LoggedOut += () => + { + var state = LoginState; + if (state == LoginState.LoggedIn || state == LoginState.LoggingIn) + { + //Should only happen if token is changed + var _ = LogoutAsync(); //Signal the logout, fire and forget + } + return Task.Delay(0); + }; + client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel); client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel); client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel); diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 58c27dccc..7591717b6 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -17,6 +17,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; using GameModel = Discord.API.Game; +using Discord.Net; namespace Discord.WebSocket { @@ -25,6 +26,7 @@ namespace Discord.WebSocket private readonly ConcurrentQueue _largeGuilds; private readonly Logger _gatewayLogger; private readonly JsonSerializer _serializer; + private readonly SemaphoreSlim _connectionGroupLock; private string _sessionId; private int _lastSeq; @@ -69,8 +71,9 @@ namespace Discord.WebSocket /// Creates a new REST/WebSocket discord client. public DiscordSocketClient() : this(new DiscordSocketConfig()) { } /// Creates a new REST/WebSocket discord client. - public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config)) { } - private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client) + public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null) { } + internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock) : this(config, CreateApiClient(config), groupLock) { } + private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock) : base(config, client) { ShardId = config.ShardId ?? 0; @@ -86,6 +89,7 @@ namespace Discord.WebSocket _nextAudioId = 1; _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId); + _connectionGroupLock = groupLock; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer.Error += (s, e) => @@ -171,53 +175,65 @@ namespace Discord.WebSocket if (state == ConnectionState.Connecting || state == ConnectionState.Connected) await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); - ConnectionState = ConnectionState.Connecting; - await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false); - + if (_connectionGroupLock != null) + await _connectionGroupLock.WaitAsync().ConfigureAwait(false); try { - var connectTask = new TaskCompletionSource(); - _connectTask = connectTask; - _cancelToken = new CancellationTokenSource(); - - //Abort connection on timeout - var _ = Task.Run(async () => + _canReconnect = true; + ConnectionState = ConnectionState.Connecting; + await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false); + + try { - await Task.Delay(ConnectionTimeout).ConfigureAwait(false); - connectTask.TrySetException(new TimeoutException()); - }); + var connectTask = new TaskCompletionSource(); + _connectTask = connectTask; + _cancelToken = new CancellationTokenSource(); - await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); - await ApiClient.ConnectAsync().ConfigureAwait(false); - await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); - await _connectedEvent.InvokeAsync().ConfigureAwait(false); + //Abort connection on timeout + var _ = Task.Run(async () => + { + await Task.Delay(ConnectionTimeout).ConfigureAwait(false); + connectTask.TrySetException(new TimeoutException()); + }); - if (_sessionId != null) - { - await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); - await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); - } - else - { - await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); - } + await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); + await ApiClient.ConnectAsync().ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); + await _connectedEvent.InvokeAsync().ConfigureAwait(false); - await _connectTask.Task.ConfigureAwait(false); + if (_sessionId != null) + { + await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); + await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); + } + else + { + await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); + } - await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); - await SendStatusAsync().ConfigureAwait(false); + await _connectTask.Task.ConfigureAwait(false); - await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); - if (!isReconnecting) - _canReconnect = true; - ConnectionState = ConnectionState.Connected; - await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); + await SendStatusAsync().ConfigureAwait(false); + + await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); + ConnectionState = ConnectionState.Connected; + await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false); + } + catch (Exception) + { + await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); + throw; + } } - catch (Exception) + finally { - await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); - throw; + if (_connectionGroupLock != null) + { + await Task.Delay(5000).ConfigureAwait(false); + _connectionGroupLock.Release(); + } } } /// @@ -290,13 +306,12 @@ namespace Discord.WebSocket private async Task StartReconnectAsync(Exception ex) { - if (ex == null) - { - if (_connectTask?.TrySetCanceled() ?? false) return; - } - else + if ((ex as WebSocketClosedException).CloseCode == 4004) //Bad Token { - if (_connectTask?.TrySetException(ex) ?? false) return; + _canReconnect = false; + _connectTask?.TrySetException(ex); + await LogoutAsync().ConfigureAwait(false); + return; } await _connectionLock.WaitAsync().ConfigureAwait(false); @@ -608,6 +623,7 @@ namespace Discord.WebSocket } catch (Exception ex) { + _canReconnect = false; _connectTask.TrySetException(new Exception("Processing READY failed", ex)); return; }