diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index 3c7d4da84..b5793fd49 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -95,7 +95,12 @@ namespace Discord _api = new DiscordAPIClient(_config.LogLevel, _config.APITimeout); _dataSocket = new DataWebSocket(this); _dataSocket.Connected += (s, e) => { if (_state == (int)DiscordClientState.Connecting) CompleteConnect(); }; - _dataSocket.Disconnected += async (s, e) => { RaiseDisconnected(e); if (e.WasUnexpected) await _dataSocket.Login(_token); }; + _dataSocket.Disconnected += async (s, e) => + { + RaiseDisconnected(e); + if (e.WasUnexpected) + await _dataSocket.Reconnect(_token); + }; if (_config.EnableVoice) { _voiceSocket = new VoiceWebSocket(this); diff --git a/src/Discord.Net/Net/WebSockets/DataWebSocket.cs b/src/Discord.Net/Net/WebSockets/DataWebSocket.cs index 49da7bdb0..59e8ba9b8 100644 --- a/src/Discord.Net/Net/WebSockets/DataWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/DataWebSocket.cs @@ -20,7 +20,7 @@ namespace Discord.Net.WebSockets public async Task Login(string token) { - await Connect(); + await Connect().ConfigureAwait(false); Commands.Login msg = new Commands.Login(); msg.Payload.Token = token; @@ -29,14 +29,38 @@ namespace Discord.Net.WebSockets } private async Task Redirect(string server) { - await DisconnectInternal(isUnexpected: false); - await Connect(); + await DisconnectInternal(isUnexpected: false).ConfigureAwait(false); + await Connect().ConfigureAwait(false); var resumeMsg = new Commands.Resume(); resumeMsg.Payload.SessionId = _sessionId; resumeMsg.Payload.Sequence = _lastSeq; QueueMessage(resumeMsg); } + public async Task Reconnect(string token) + { + try + { + var cancelToken = ParentCancelToken; + await Task.Delay(_client.Config.ReconnectDelay, cancelToken).ConfigureAwait(false); + while (!cancelToken.IsCancellationRequested) + { + try + { + await Login(token).ConfigureAwait(false); + break; + } + catch (OperationCanceledException) { throw; } + catch (Exception ex) + { + RaiseOnLog(LogMessageSeverity.Error, $"Reconnect failed: {ex.GetBaseException().Message}"); + //Net is down? We can keep trying to reconnect until the user runs Disconnect() + await Task.Delay(_client.Config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); + } + } + } + catch (OperationCanceledException) { } + } protected override async Task ProcessMessage(string json) { diff --git a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs index 51b3c8679..b39a9fc37 100644 --- a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs @@ -67,7 +67,7 @@ namespace Discord.Net.WebSockets _sessionId = sessionId; _token = token; - await Connect(); + await Connect().ConfigureAwait(false); } public async Task Reconnect() { @@ -85,7 +85,7 @@ namespace Discord.Net.WebSockets catch (OperationCanceledException) { throw; } catch (Exception ex) { - RaiseOnLog(LogMessageSeverity.Error, $"DataSocket reconnect failed: {ex.GetBaseException().Message}"); + RaiseOnLog(LogMessageSeverity.Error, $"Reconnect failed: {ex.GetBaseException().Message}"); //Net is down? We can keep trying to reconnect until the user runs Disconnect() await Task.Delay(_client.Config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); } @@ -125,7 +125,7 @@ namespace Discord.Net.WebSockets #endif }.Concat(base.Run()).ToArray(); } - protected override Task Cleanup(bool wasUnexpected) + protected override Task Cleanup() { #if USE_THREAD _sendThread.Join(); @@ -133,7 +133,7 @@ namespace Discord.Net.WebSockets #endif ClearPCMFrames(); - if (!wasUnexpected) + if (!_wasDisconnectUnexpected) { _serverId = null; _userId = null; @@ -142,7 +142,7 @@ namespace Discord.Net.WebSockets } _udp = null; - return base.Cleanup(wasUnexpected); + return base.Cleanup(); } private async Task ReceiveVoiceAsync() diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index 6fc5e1731..6728dab4e 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -47,8 +47,9 @@ namespace Discord.Net.WebSockets public WebSocketState State => (WebSocketState)_state; protected int _state; - protected ExceptionDispatchInfo _disconnectReason; - private bool _wasDisconnectUnexpected; + protected ExceptionDispatchInfo _disconnectReason; + protected bool _wasDisconnectUnexpected; + protected WebSocketState _disconnectState; public CancellationToken ParentCancelToken { get; set; } public CancellationToken CancelToken => _cancelToken; @@ -78,9 +79,7 @@ namespace Discord.Net.WebSockets try { - await Disconnect().ConfigureAwait(false); - - _state = (int)WebSocketState.Connecting; + await Disconnect().ConfigureAwait(false); _cancelTokenSource = new CancellationTokenSource(); if (ParentCancelToken != null) @@ -91,12 +90,13 @@ namespace Discord.Net.WebSockets await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); _lastHeartbeat = DateTime.UtcNow; + _state = (int)WebSocketState.Connecting; _runTask = RunTasks(); } - catch + catch (Exception ex) { - await Disconnect().ConfigureAwait(false); - throw; + await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); + throw; //Dont handle this exception internally, send up it upwards } } protected void CompleteConnect() @@ -108,33 +108,40 @@ namespace Discord.Net.WebSockets => Connect(_host, _cancelToken);*/ public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); - protected Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false) + protected async Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false) { int oldState; bool hasWriterLock; //If in either connecting or connected state, get a lock by being the first to switch to disconnecting oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); - if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected + if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change if (!hasWriterLock) { oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected); - if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected + if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change } if (hasWriterLock) { _wasDisconnectUnexpected = isUnexpected; + _disconnectState = (WebSocketState)oldState; _disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null; + + if (_disconnectState == WebSocketState.Connecting) //_runTask was never made + await Cleanup(); _cancelTokenSource.Cancel(); } if (!skipAwait) - return _runTask ?? TaskHelper.CompletedTask; + { + Task task = _runTask ?? TaskHelper.CompletedTask; + await task; + } else - return TaskHelper.CompletedTask; + await TaskHelper.CompletedTask; } protected virtual async Task RunTasks() @@ -143,19 +150,19 @@ namespace Discord.Net.WebSockets Task firstTask = Task.WhenAny(tasks); Task allTasks = Task.WhenAll(tasks); + //Wait until the first task ends/errors and capture the error try { await firstTask.ConfigureAwait(false); } catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); } - //When the first task ends, make sure the rest do too + //Ensure all other tasks are signaled to end. await DisconnectInternal(skipAwait: true); + + //Wait for the remaining tasks to complete try { await allTasks.ConfigureAwait(false); } catch { } - - bool wasUnexpected = _wasDisconnectUnexpected; - _wasDisconnectUnexpected = false; - await Cleanup(wasUnexpected).ConfigureAwait(false); - _runTask = null; + //Clean up state variables and raise disconnect event + await Cleanup().ConfigureAwait(false); } protected virtual Task[] Run() { @@ -164,12 +171,22 @@ namespace Discord.Net.WebSockets .Concat(new Task[] { HeartbeatAsync(cancelToken) }) .ToArray(); } - protected virtual Task Cleanup(bool wasUnexpected) + protected virtual async Task Cleanup() { + var disconnectState = _disconnectState; + _disconnectState = WebSocketState.Disconnected; + var wasDisconnectUnexpected = _wasDisconnectUnexpected; + _wasDisconnectUnexpected = false; + //Dont reset disconnectReason, we may called ThrowError() later + + await _engine.Disconnect(); _cancelTokenSource = null; - _state = (int)WebSocketState.Disconnected; - RaiseDisconnected(wasUnexpected, _disconnectReason?.SourceException); - return _engine.Disconnect(); + var oldState = _state; + _state = (int)WebSocketState.Disconnected; + _runTask = null; + + if (disconnectState == WebSocketState.Connected) + RaiseDisconnected(wasDisconnectUnexpected, _disconnectReason?.SourceException); } protected abstract Task ProcessMessage(string json);