diff --git a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs
index e0b5fc0b5..fb6670a90 100644
--- a/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs
+++ b/src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs
@@ -20,6 +20,8 @@ namespace Discord.API.Gateway
public User User { get; set; }
[JsonProperty("session_id")]
public string SessionId { get; set; }
+ [JsonProperty("resume_gateway_url")]
+ public string ResumeGatewayUrl { get; set; }
[JsonProperty("read_state")]
public ReadState[] ReadStates { get; set; }
[JsonProperty("guilds")]
diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index 9fc717762..dcee36736 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -139,9 +139,9 @@ namespace Discord.WebSocket
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
+ var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
if (_automaticShards)
{
- var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
_shardIds = Enumerable.Range(0, botGateway.Shards).ToArray();
_totalShards = _shardIds.Length;
_shards = new DiscordSocketClient[_shardIds.Length];
@@ -163,7 +163,12 @@ namespace Discord.WebSocket
//Assume thread safe: already in a connection lock
for (int i = 0; i < _shards.Length; i++)
+ {
+ // Set the gateway URL to the one returned by Discord, if a custom one isn't set.
+ _shards[i].ApiClient.GatewayUrl = botGateway.Url;
+
await _shards[i].LoginAsync(tokenType, token);
+ }
if(_defaultStickers.Length == 0 && _baseConfig.AlwaysDownloadDefaultStickers)
await DownloadDefaultStickersAsync().ConfigureAwait(false);
@@ -175,7 +180,12 @@ namespace Discord.WebSocket
if (_shards != null)
{
for (int i = 0; i < _shards.Length; i++)
+ {
+ // Reset the gateway URL set for the shard.
+ _shards[i].ApiClient.GatewayUrl = null;
+
await _shards[i].LogoutAsync();
+ }
}
if (_automaticShards)
diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
index cca2de203..465c47a1d 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
@@ -28,6 +28,7 @@ namespace Discord.API
private readonly bool _isExplicitUrl;
private CancellationTokenSource _connectCancelToken;
private string _gatewayUrl;
+ private string _resumeGatewayUrl;
//Store our decompression streams for zlib shared state
private MemoryStream _compressed;
@@ -37,6 +38,32 @@ namespace Discord.API
public ConnectionState ConnectionState { get; private set; }
+ ///
+ /// Sets the gateway URL used for identifies.
+ ///
+ ///
+ /// If a custom URL is set, setting this property does nothing.
+ ///
+ public string GatewayUrl
+ {
+ set
+ {
+ // Makes the sharded client not override the custom value.
+ if (_isExplicitUrl)
+ return;
+
+ _gatewayUrl = FormatGatewayUrl(value);
+ }
+ }
+
+ ///
+ /// Sets the gateway URL used for resumes.
+ ///
+ public string ResumeGatewayUrl
+ {
+ set => _resumeGatewayUrl = FormatGatewayUrl(value);
+ }
+
public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent,
string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null,
bool useSystemClock = true, Func defaultRatelimitCallback = null)
@@ -157,6 +184,17 @@ namespace Discord.API
#endif
}
+ ///
+ /// Appends necessary query parameters to the specified gateway URL.
+ ///
+ private static string FormatGatewayUrl(string gatewayUrl)
+ {
+ if (gatewayUrl == null)
+ return null;
+
+ return $"{gatewayUrl}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream";
+ }
+
public async Task ConnectAsync()
{
await _stateLock.WaitAsync().ConfigureAwait(false);
@@ -191,24 +229,32 @@ namespace Discord.API
if (WebSocketClient != null)
WebSocketClient.SetCancelToken(_connectCancelToken.Token);
- if (!_isExplicitUrl)
+ string gatewayUrl;
+ if (_resumeGatewayUrl == null)
+ {
+ if (!_isExplicitUrl && _gatewayUrl == null)
+ {
+ var gatewayResponse = await GetBotGatewayAsync().ConfigureAwait(false);
+ _gatewayUrl = FormatGatewayUrl(gatewayResponse.Url);
+ }
+
+ gatewayUrl = _gatewayUrl;
+ }
+ else
{
- var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false);
- _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream";
+ gatewayUrl = _resumeGatewayUrl;
}
#if DEBUG_PACKETS
- Console.WriteLine("Connecting to gateway: " + _gatewayUrl);
+ Console.WriteLine("Connecting to gateway: " + gatewayUrl);
#endif
- await WebSocketClient.ConnectAsync(_gatewayUrl).ConfigureAwait(false);
+ await WebSocketClient.ConnectAsync(gatewayUrl).ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
}
catch
{
- if (!_isExplicitUrl)
- _gatewayUrl = null; //Uncache in case the gateway url changed
await DisconnectInternalAsync().ConfigureAwait(false);
throw;
}
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index 670ed4567..1cc35f761 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -322,7 +322,6 @@ namespace Discord.WebSocket
}
private async Task OnDisconnectingAsync(Exception ex)
{
-
await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync(ex).ConfigureAwait(false);
@@ -353,6 +352,10 @@ namespace Discord.WebSocket
if (guild.IsAvailable)
await GuildUnavailableAsync(guild).ConfigureAwait(false);
}
+
+ _sessionId = null;
+ _lastSeq = 0;
+ ApiClient.ResumeGatewayUrl = null;
}
///
@@ -834,6 +837,7 @@ namespace Discord.WebSocket
_sessionId = null;
_lastSeq = 0;
+ ApiClient.ResumeGatewayUrl = null;
if (_shardedClient != null)
{
@@ -891,6 +895,7 @@ namespace Discord.WebSocket
AddPrivateChannel(data.PrivateChannels[i], state);
_sessionId = data.SessionId;
+ ApiClient.ResumeGatewayUrl = data.ResumeGatewayUrl;
_unavailableGuildCount = unavailableGuilds;
CurrentUser = currentUser;
_previousSessionUser = CurrentUser;