diff --git a/src/Discord.Net/API/DiscordAPIClient.cs b/src/Discord.Net/API/DiscordAPIClient.cs index 119de17d0..82fc625c1 100644 --- a/src/Discord.Net/API/DiscordAPIClient.cs +++ b/src/Discord.Net/API/DiscordAPIClient.cs @@ -7,7 +7,6 @@ using Discord.Net.Queue; using Discord.Net.Rest; using Discord.Net.WebSockets; using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -28,6 +27,7 @@ namespace Discord.API public event Func SentRequest; public event Func SentGatewayMessage; public event Func ReceivedGatewayEvent; + public event Func Disconnected; private readonly RequestQueue _requestQueue; private readonly JsonSerializer _serializer; @@ -75,6 +75,11 @@ namespace Discord.API var msg = JsonConvert.DeserializeObject(text); await ReceivedGatewayEvent.RaiseAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); }; + _gatewayClient.Closed += async ex => + { + await DisconnectAsync().ConfigureAwait(false); + await Disconnected.RaiseAsync(ex).ConfigureAwait(false); + }; } _serializer = serializer ?? new JsonSerializer { ContractResolver = new DiscordContractResolver() }; @@ -363,6 +368,15 @@ namespace Discord.API }; await SendGatewayAsync(GatewayOpCode.Identify, msg, options: options).ConfigureAwait(false); } + public async Task SendResumeAsync(string sessionId, int lastSeq, RequestOptions options = null) + { + var msg = new ResumeParams() + { + SessionId = sessionId, + Sequence = lastSeq + }; + await SendGatewayAsync(GatewayOpCode.Resume, msg, options: options).ConfigureAwait(false); + } public async Task SendHeartbeatAsync(int lastSeq, RequestOptions options = null) { await SendGatewayAsync(GatewayOpCode.Heartbeat, lastSeq, options: options).ConfigureAwait(false); diff --git a/src/Discord.Net/API/Gateway/GatewayOpCode.cs b/src/Discord.Net/API/Gateway/GatewayOpCode.cs index ac1a21e1d..8b983383f 100644 --- a/src/Discord.Net/API/Gateway/GatewayOpCode.cs +++ b/src/Discord.Net/API/Gateway/GatewayOpCode.cs @@ -2,7 +2,7 @@ { public enum GatewayOpCode : byte { - /// C←S - Used to send most events. + /// S→C - Used to send most events. Dispatch = 0, /// C↔S - Used to keep the connection alive and measure latency. Heartbeat = 1, @@ -16,7 +16,7 @@ VoiceServerPing = 5, /// C→S - Used to resume a connection after a redirect occurs. Resume = 6, - /// C←S - Used to notify a client that they must reconnect to another gateway. + /// S→C - Used to notify a client that they must reconnect to another gateway. Reconnect = 7, /// C→S - Used to request all members that were withheld by large_threshold RequestGuildMembers = 8, diff --git a/src/Discord.Net/API/Gateway/ResumeParams.cs b/src/Discord.Net/API/Gateway/ResumeParams.cs index ba4489336..b10e312f2 100644 --- a/src/Discord.Net/API/Gateway/ResumeParams.cs +++ b/src/Discord.Net/API/Gateway/ResumeParams.cs @@ -7,6 +7,6 @@ namespace Discord.API.Gateway [JsonProperty("session_id")] public string SessionId { get; set; } [JsonProperty("seq")] - public uint Sequence { get; set; } + public int Sequence { get; set; } } } diff --git a/src/Discord.Net/DiscordSocketClient.cs b/src/Discord.Net/DiscordSocketClient.cs index 351dad850..923290821 100644 --- a/src/Discord.Net/DiscordSocketClient.cs +++ b/src/Discord.Net/DiscordSocketClient.cs @@ -55,8 +55,9 @@ namespace Discord private ImmutableDictionary _voiceRegions; private TaskCompletionSource _connectTask; private CancellationTokenSource _heartbeatCancelToken; - private Task _heartbeatTask; + private Task _heartbeatTask, _reconnectTask; private long _heartbeatTime; + private bool _isReconnecting; /// Gets the shard if of this client. public int ShardId { get; } @@ -64,9 +65,9 @@ namespace Discord public ConnectionState ConnectionState { get; private set; } /// Gets the estimated round-trip latency, in milliseconds, to the gateway server. public int Latency { get; private set; } + internal IWebSocketClient GatewaySocket { get; private set; } internal int MessageCacheSize { get; private set; } - //internal bool UsePermissionCache { get; private set; } internal DataStore DataStore { get; private set; } internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser; @@ -104,7 +105,6 @@ namespace Discord _dataStoreProvider = config.DataStoreProvider; MessageCacheSize = config.MessageCacheSize; - //UsePermissionCache = config.UsePermissionsCache; _enablePreUpdateEvents = config.EnablePreUpdateEvents; _largeThreshold = config.LargeThreshold; @@ -122,6 +122,16 @@ namespace Discord ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {(GatewayOpCode)opCode}").ConfigureAwait(false); ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; + ApiClient.Disconnected += async ex => + { + if (ex != null) + { + await _gatewayLogger.WarningAsync($"Connection Closed: {ex.Message}").ConfigureAwait(false); + await StartReconnectAsync().ConfigureAwait(false); + } + else + await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false); + }; GatewaySocket = config.WebSocketProvider(); _voiceRegions = ImmutableDictionary.Create(); @@ -147,6 +157,7 @@ namespace Discord await _connectionLock.WaitAsync().ConfigureAwait(false); try { + _isReconnecting = false; await ConnectInternalAsync().ConfigureAwait(false); } finally { _connectionLock.Release(); } @@ -157,6 +168,7 @@ namespace Discord throw new InvalidOperationException("You must log in before connecting."); ConnectionState = ConnectionState.Connecting; + await _gatewayLogger.InfoAsync("Connecting"); try { _connectTask = new TaskCompletionSource(); @@ -165,6 +177,7 @@ namespace Discord await _connectTask.Task.ConfigureAwait(false); ConnectionState = ConnectionState.Connected; + await _gatewayLogger.InfoAsync("Connected"); } catch (Exception) { @@ -180,6 +193,7 @@ namespace Discord await _connectionLock.WaitAsync().ConfigureAwait(false); try { + _isReconnecting = false; await DisconnectInternalAsync().ConfigureAwait(false); } finally { _connectionLock.Release(); } @@ -190,15 +204,62 @@ namespace Discord if (ConnectionState == ConnectionState.Disconnected) return; ConnectionState = ConnectionState.Disconnecting; + await _gatewayLogger.InfoAsync("Disconnecting"); + try { _heartbeatCancelToken.Cancel(); } catch { } await ApiClient.DisconnectAsync().ConfigureAwait(false); await _heartbeatTask.ConfigureAwait(false); while (_largeGuilds.TryDequeue(out guildId)) { } ConnectionState = ConnectionState.Disconnected; + await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false); await Disconnected.RaiseAsync().ConfigureAwait(false); } + private async Task StartReconnectAsync() + { + //TODO: Is this thread-safe? + while (true) + { + if (_reconnectTask != null) return; + + await _connectionLock.WaitAsync().ConfigureAwait(false); + try + { + if (_reconnectTask != null) return; + _isReconnecting = true; + _reconnectTask = ReconnectInternalAsync(); + } + finally { _connectionLock.Release(); } + } + } + private async Task ReconnectInternalAsync() + { + int nextReconnectDelay = 1000; + while (_isReconnecting) + { + try + { + await Task.Delay(nextReconnectDelay).ConfigureAwait(false); + nextReconnectDelay *= 2; + if (nextReconnectDelay > 30000) + nextReconnectDelay = 30000; + + await _connectionLock.WaitAsync().ConfigureAwait(false); + try + { + await ConnectInternalAsync().ConfigureAwait(false); + } + finally { _connectionLock.Release(); } + return; + } + catch (Exception ex) + { + await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false); + } + } + _reconnectTask = null; + } /// public override Task GetVoiceRegionAsync(string id) @@ -332,7 +393,10 @@ namespace Discord await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - await ApiClient.SendIdentifyAsync().ConfigureAwait(false); + if (_sessionId != null) + await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); + else + await ApiClient.SendIdentifyAsync().ConfigureAwait(false); _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _heartbeatCancelToken.Token); } break; @@ -354,6 +418,24 @@ namespace Discord await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false); } break; + case GatewayOpCode.InvalidSession: + { + await _gatewayLogger.DebugAsync("Received InvalidSession").ConfigureAwait(false); + await _gatewayLogger.WarningAsync("Failed to resume previous session"); + + _sessionId = null; + _lastSeq = 0; + await ApiClient.SendIdentifyAsync().ConfigureAwait(false); + } + break; + case GatewayOpCode.Reconnect: + { + await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false); + await _gatewayLogger.WarningAsync("Server requested a reconnect"); + + await StartReconnectAsync().ConfigureAwait(false); + } + break; case GatewayOpCode.Dispatch: switch (type) { @@ -380,6 +462,7 @@ namespace Discord await Ready.RaiseAsync().ConfigureAwait(false); _connectTask.TrySetResult(true); //Signal the .Connect() call to complete + await _gatewayLogger.InfoAsync("Ready"); } break; @@ -410,7 +493,11 @@ namespace Discord } } - await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false); + if (data.Unavailable != true) + { + await _gatewayLogger.InfoAsync($"Connected to {data.Name}").ConfigureAwait(false); + await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false); + } } break; case "GUILD_UPDATE": @@ -442,11 +529,17 @@ namespace Discord var guild = RemoveGuild(data.Id); if (guild != null) { + foreach (var member in guild.Members) + member.User.RemoveRef(); + await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false); + await _gatewayLogger.InfoAsync($"Disconnected from {data.Name}").ConfigureAwait(false); if (data.Unavailable != true) + { await LeftGuild.RaiseAsync(guild).ConfigureAwait(false); - foreach (var member in guild.Members) - member.User.RemoveRef(); + await _gatewayLogger.InfoAsync($"Left {data.Name}").ConfigureAwait(false); + } + } else { @@ -987,11 +1080,16 @@ namespace Discord var state = ConnectionState; while (state == ConnectionState.Connecting || state == ConnectionState.Connected) { - //if (_heartbeatTime != 0) //TODO: Connection lost, reconnect + await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); + if (_heartbeatTime != 0) //Server never responded to our last heartbeat + { + await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false); + await StartReconnectAsync().ConfigureAwait(false); + return; + } _heartbeatTime = Environment.TickCount; await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false); - await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); } } catch (OperationCanceledException) { } diff --git a/src/Discord.Net/Net/WebSocketException.cs b/src/Discord.Net/Net/WebSocketException.cs new file mode 100644 index 000000000..d647b6c8c --- /dev/null +++ b/src/Discord.Net/Net/WebSocketException.cs @@ -0,0 +1,16 @@ +using System; +namespace Discord.Net +{ + public class WebSocketClosedException : Exception + { + public int CloseCode { get; } + public string Reason { get; } + + public WebSocketClosedException(int closeCode, string reason = null) + : base($"The server sent close {closeCode}{(reason != null ? $": \"{reason}\"" : "")}") + { + CloseCode = closeCode; + Reason = reason; + } + } +} diff --git a/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs b/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs index f8f8731d9..28d108cb3 100644 --- a/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs +++ b/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs @@ -1,5 +1,6 @@ using Discord.Extensions; using System; +using System.Collections.Generic; using System.ComponentModel; using System.IO; using System.Net.WebSockets; @@ -17,9 +18,11 @@ namespace Discord.Net.WebSockets public event Func BinaryMessage; public event Func TextMessage; - - private readonly ClientWebSocket _client; + public event Func Closed; + private readonly SemaphoreSlim _sendLock; + private readonly Dictionary _headers; + private ClientWebSocket _client; private Task _task; private CancellationTokenSource _cancelTokenSource; private CancellationToken _cancelToken, _parentToken; @@ -27,14 +30,11 @@ namespace Discord.Net.WebSockets public DefaultWebSocketClient() { - _client = new ClientWebSocket(); - _client.Options.Proxy = null; - _client.Options.KeepAliveInterval = TimeSpan.Zero; - _sendLock = new SemaphoreSlim(1, 1); _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationToken.None; _parentToken = CancellationToken.None; + _headers = new Dictionary(); } private void Dispose(bool disposing) { @@ -58,6 +58,15 @@ namespace Discord.Net.WebSockets _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; + _client = new ClientWebSocket(); + _client.Options.Proxy = null; + _client.Options.KeepAliveInterval = TimeSpan.Zero; + foreach (var header in _headers) + { + if (header.Value != null) + _client.Options.SetRequestHeader(header.Key, header.Value); + } + await _client.ConnectAsync(new Uri(host), _cancelToken).ConfigureAwait(false); _task = RunAsync(_cancelToken); } @@ -66,7 +75,7 @@ namespace Discord.Net.WebSockets //Assume locked _cancelTokenSource.Cancel(); - if (_client.State == WebSocketState.Open) + if (_client != null && _client.State == WebSocketState.Open) { try { @@ -82,7 +91,7 @@ namespace Discord.Net.WebSockets public void SetHeader(string key, string value) { - _client.Options.SetRequestHeader(key, value); + _headers[key] = value; } public void SetCancelToken(CancellationToken cancelToken) { @@ -148,28 +157,36 @@ namespace Discord.Net.WebSockets throw new Exception("Connection timed out."); } - if (result.MessageType == WebSocketMessageType.Close) - throw new WebSocketException((int)result.CloseStatus.Value, result.CloseStatusDescription); - else + if (result.Count > 0) stream.Write(buffer.Array, 0, result.Count); - } while (result == null || !result.EndOfMessage); var array = stream.ToArray(); - if (result.MessageType == WebSocketMessageType.Binary) - await BinaryMessage.RaiseAsync(array, 0, array.Length).ConfigureAwait(false); - else if (result.MessageType == WebSocketMessageType.Text) - { - string text = Encoding.UTF8.GetString(array, 0, array.Length); - await TextMessage.RaiseAsync(text).ConfigureAwait(false); - } - stream.Position = 0; stream.SetLength(0); + + switch (result.MessageType) + { + case WebSocketMessageType.Binary: + await BinaryMessage(array, 0, array.Length).ConfigureAwait(false); + break; + case WebSocketMessageType.Text: + string text = Encoding.UTF8.GetString(array, 0, array.Length); + await TextMessage(text).ConfigureAwait(false); + break; + case WebSocketMessageType.Close: + var _ = Closed(new WebSocketClosedException((int)result.CloseStatus, result.CloseStatusDescription)); + return; + } } } catch (OperationCanceledException) { } + catch (Exception ex) + { + //This cannot be awaited otherwise we'll deadlock when DiscordApiClient waits for this task to complete. + var _ = Closed(ex); + } } } } diff --git a/src/Discord.Net/Net/WebSockets/IWebSocketClient.cs b/src/Discord.Net/Net/WebSockets/IWebSocketClient.cs index 583aaa06d..7eccaabf2 100644 --- a/src/Discord.Net/Net/WebSockets/IWebSocketClient.cs +++ b/src/Discord.Net/Net/WebSockets/IWebSocketClient.cs @@ -8,6 +8,7 @@ namespace Discord.Net.WebSockets { event Func BinaryMessage; event Func TextMessage; + event Func Closed; void SetHeader(string key, string value); void SetCancelToken(CancellationToken cancelToken);