diff --git a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs index e0b5fc0b5..fb6670a90 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs @@ -20,6 +20,8 @@ namespace Discord.API.Gateway public User User { get; set; } [JsonProperty("session_id")] public string SessionId { get; set; } + [JsonProperty("resume_gateway_url")] + public string ResumeGatewayUrl { get; set; } [JsonProperty("read_state")] public ReadState[] ReadStates { get; set; } [JsonProperty("guilds")] diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index cca2de203..cf4a43967 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -28,6 +28,7 @@ namespace Discord.API private readonly bool _isExplicitUrl; private CancellationTokenSource _connectCancelToken; private string _gatewayUrl; + private string _resumeGatewayUrl; //Store our decompression streams for zlib shared state private MemoryStream _compressed; @@ -37,6 +38,14 @@ namespace Discord.API public ConnectionState ConnectionState { get; private set; } + /// + /// Sets the gateway URL used for resumes. + /// + public string ResumeGatewayUrl + { + set => _resumeGatewayUrl = FormatGatewayUrl(value); + } + public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, bool useSystemClock = true, Func defaultRatelimitCallback = null) @@ -157,6 +166,17 @@ namespace Discord.API #endif } + /// + /// Appends necessary query parameters to the specified gateway URL. + /// + private static string FormatGatewayUrl(string gatewayUrl) + { + if (gatewayUrl == null) + return null; + + return $"{gatewayUrl}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream"; + } + public async Task ConnectAsync() { await _stateLock.WaitAsync().ConfigureAwait(false); @@ -191,24 +211,41 @@ namespace Discord.API if (WebSocketClient != null) WebSocketClient.SetCancelToken(_connectCancelToken.Token); - if (!_isExplicitUrl) + string gatewayUrl; + if (_resumeGatewayUrl == null) + { + if (!_isExplicitUrl) + { + // TODO: 'GetGatewayAsync' -> 'GetBotGatewayAsync', but it could just be hardcoded to 'wss://gateway.discord.gg/' + var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); + gatewayUrl = _gatewayUrl = FormatGatewayUrl(gatewayResponse.Url); + } + else + { + gatewayUrl = _gatewayUrl; + } + } + else { - var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); - _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream"; + gatewayUrl = _resumeGatewayUrl; } #if DEBUG_PACKETS - Console.WriteLine("Connecting to gateway: " + _gatewayUrl); + Console.WriteLine("Connecting to gateway: " + gatewayUrl); #endif - await WebSocketClient.ConnectAsync(_gatewayUrl).ConfigureAwait(false); + await WebSocketClient.ConnectAsync(gatewayUrl).ConfigureAwait(false); ConnectionState = ConnectionState.Connected; } catch { if (!_isExplicitUrl) + { + // TODO: '_gatewayUrl = null' doesn't do anything, it's never null checked _gatewayUrl = null; //Uncache in case the gateway url changed + } + await DisconnectInternalAsync().ConfigureAwait(false); throw; } diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index f0b50aa8f..84ee3a0e4 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -322,7 +322,6 @@ namespace Discord.WebSocket } private async Task OnDisconnectingAsync(Exception ex) { - await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); await ApiClient.DisconnectAsync(ex).ConfigureAwait(false); @@ -353,6 +352,10 @@ namespace Discord.WebSocket if (guild.IsAvailable) await GuildUnavailableAsync(guild).ConfigureAwait(false); } + + _sessionId = null; + _lastSeq = 0; + ApiClient.ResumeGatewayUrl = null; } /// @@ -832,6 +835,7 @@ namespace Discord.WebSocket _sessionId = null; _lastSeq = 0; + ApiClient.ResumeGatewayUrl = null; if (_shardedClient != null) { @@ -889,6 +893,7 @@ namespace Discord.WebSocket AddPrivateChannel(data.PrivateChannels[i], state); _sessionId = data.SessionId; + ApiClient.ResumeGatewayUrl = data.ResumeGatewayUrl; _unavailableGuildCount = unavailableGuilds; CurrentUser = currentUser; _previousSessionUser = CurrentUser;