Browse Source

Added Reconnect/Resume

tags/1.0-rc
RogueException 9 years ago
parent
commit
4cc393f963
7 changed files with 179 additions and 33 deletions
  1. +15
    -1
      src/Discord.Net/API/DiscordAPIClient.cs
  2. +2
    -2
      src/Discord.Net/API/Gateway/GatewayOpCode.cs
  3. +1
    -1
      src/Discord.Net/API/Gateway/ResumeParams.cs
  4. +107
    -9
      src/Discord.Net/DiscordSocketClient.cs
  5. +16
    -0
      src/Discord.Net/Net/WebSocketException.cs
  6. +37
    -20
      src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs
  7. +1
    -0
      src/Discord.Net/Net/WebSockets/IWebSocketClient.cs

+ 15
- 1
src/Discord.Net/API/DiscordAPIClient.cs View File

@@ -7,7 +7,6 @@ using Discord.Net.Queue;
using Discord.Net.Rest; using Discord.Net.Rest;
using Discord.Net.WebSockets; using Discord.Net.WebSockets;
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
@@ -28,6 +27,7 @@ namespace Discord.API
public event Func<string, string, double, Task> SentRequest; public event Func<string, string, double, Task> SentRequest;
public event Func<int, Task> SentGatewayMessage; public event Func<int, Task> SentGatewayMessage;
public event Func<GatewayOpCode, int?, string, object, Task> ReceivedGatewayEvent; public event Func<GatewayOpCode, int?, string, object, Task> ReceivedGatewayEvent;
public event Func<Exception, Task> Disconnected;


private readonly RequestQueue _requestQueue; private readonly RequestQueue _requestQueue;
private readonly JsonSerializer _serializer; private readonly JsonSerializer _serializer;
@@ -75,6 +75,11 @@ namespace Discord.API
var msg = JsonConvert.DeserializeObject<WebSocketMessage>(text); var msg = JsonConvert.DeserializeObject<WebSocketMessage>(text);
await ReceivedGatewayEvent.RaiseAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); await ReceivedGatewayEvent.RaiseAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
}; };
_gatewayClient.Closed += async ex =>
{
await DisconnectAsync().ConfigureAwait(false);
await Disconnected.RaiseAsync(ex).ConfigureAwait(false);
};
} }


_serializer = serializer ?? new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer = serializer ?? new JsonSerializer { ContractResolver = new DiscordContractResolver() };
@@ -363,6 +368,15 @@ namespace Discord.API
}; };
await SendGatewayAsync(GatewayOpCode.Identify, msg, options: options).ConfigureAwait(false); await SendGatewayAsync(GatewayOpCode.Identify, msg, options: options).ConfigureAwait(false);
} }
public async Task SendResumeAsync(string sessionId, int lastSeq, RequestOptions options = null)
{
var msg = new ResumeParams()
{
SessionId = sessionId,
Sequence = lastSeq
};
await SendGatewayAsync(GatewayOpCode.Resume, msg, options: options).ConfigureAwait(false);
}
public async Task SendHeartbeatAsync(int lastSeq, RequestOptions options = null) public async Task SendHeartbeatAsync(int lastSeq, RequestOptions options = null)
{ {
await SendGatewayAsync(GatewayOpCode.Heartbeat, lastSeq, options: options).ConfigureAwait(false); await SendGatewayAsync(GatewayOpCode.Heartbeat, lastSeq, options: options).ConfigureAwait(false);


+ 2
- 2
src/Discord.Net/API/Gateway/GatewayOpCode.cs View File

@@ -2,7 +2,7 @@
{ {
public enum GatewayOpCode : byte public enum GatewayOpCode : byte
{ {
/// <summary> C←S - Used to send most events. </summary>
/// <summary> S→C - Used to send most events. </summary>
Dispatch = 0, Dispatch = 0,
/// <summary> C↔S - Used to keep the connection alive and measure latency. </summary> /// <summary> C↔S - Used to keep the connection alive and measure latency. </summary>
Heartbeat = 1, Heartbeat = 1,
@@ -16,7 +16,7 @@
VoiceServerPing = 5, VoiceServerPing = 5,
/// <summary> C→S - Used to resume a connection after a redirect occurs. </summary> /// <summary> C→S - Used to resume a connection after a redirect occurs. </summary>
Resume = 6, Resume = 6,
/// <summary> C←S - Used to notify a client that they must reconnect to another gateway. </summary>
/// <summary> S→C - Used to notify a client that they must reconnect to another gateway. </summary>
Reconnect = 7, Reconnect = 7,
/// <summary> C→S - Used to request all members that were withheld by large_threshold </summary> /// <summary> C→S - Used to request all members that were withheld by large_threshold </summary>
RequestGuildMembers = 8, RequestGuildMembers = 8,


+ 1
- 1
src/Discord.Net/API/Gateway/ResumeParams.cs View File

@@ -7,6 +7,6 @@ namespace Discord.API.Gateway
[JsonProperty("session_id")] [JsonProperty("session_id")]
public string SessionId { get; set; } public string SessionId { get; set; }
[JsonProperty("seq")] [JsonProperty("seq")]
public uint Sequence { get; set; }
public int Sequence { get; set; }
} }
} }

+ 107
- 9
src/Discord.Net/DiscordSocketClient.cs View File

@@ -55,8 +55,9 @@ namespace Discord
private ImmutableDictionary<string, VoiceRegion> _voiceRegions; private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask; private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _heartbeatCancelToken; private CancellationTokenSource _heartbeatCancelToken;
private Task _heartbeatTask;
private Task _heartbeatTask, _reconnectTask;
private long _heartbeatTime; private long _heartbeatTime;
private bool _isReconnecting;


/// <summary> Gets the shard if of this client. </summary> /// <summary> Gets the shard if of this client. </summary>
public int ShardId { get; } public int ShardId { get; }
@@ -64,9 +65,9 @@ namespace Discord
public ConnectionState ConnectionState { get; private set; } public ConnectionState ConnectionState { get; private set; }
/// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary> /// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary>
public int Latency { get; private set; } public int Latency { get; private set; }

internal IWebSocketClient GatewaySocket { get; private set; } internal IWebSocketClient GatewaySocket { get; private set; }
internal int MessageCacheSize { get; private set; } internal int MessageCacheSize { get; private set; }
//internal bool UsePermissionCache { get; private set; }
internal DataStore DataStore { get; private set; } internal DataStore DataStore { get; private set; }


internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser; internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser;
@@ -104,7 +105,6 @@ namespace Discord
_dataStoreProvider = config.DataStoreProvider; _dataStoreProvider = config.DataStoreProvider;


MessageCacheSize = config.MessageCacheSize; MessageCacheSize = config.MessageCacheSize;
//UsePermissionCache = config.UsePermissionsCache;
_enablePreUpdateEvents = config.EnablePreUpdateEvents; _enablePreUpdateEvents = config.EnablePreUpdateEvents;
_largeThreshold = config.LargeThreshold; _largeThreshold = config.LargeThreshold;
@@ -122,6 +122,16 @@ namespace Discord
ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {(GatewayOpCode)opCode}").ConfigureAwait(false); ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {(GatewayOpCode)opCode}").ConfigureAwait(false);
ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; ApiClient.ReceivedGatewayEvent += ProcessMessageAsync;
ApiClient.Disconnected += async ex =>
{
if (ex != null)
{
await _gatewayLogger.WarningAsync($"Connection Closed: {ex.Message}").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
}
else
await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
};
GatewaySocket = config.WebSocketProvider(); GatewaySocket = config.WebSocketProvider();


_voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>(); _voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>();
@@ -147,6 +157,7 @@ namespace Discord
await _connectionLock.WaitAsync().ConfigureAwait(false); await _connectionLock.WaitAsync().ConfigureAwait(false);
try try
{ {
_isReconnecting = false;
await ConnectInternalAsync().ConfigureAwait(false); await ConnectInternalAsync().ConfigureAwait(false);
} }
finally { _connectionLock.Release(); } finally { _connectionLock.Release(); }
@@ -157,6 +168,7 @@ namespace Discord
throw new InvalidOperationException("You must log in before connecting."); throw new InvalidOperationException("You must log in before connecting.");


ConnectionState = ConnectionState.Connecting; ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting");
try try
{ {
_connectTask = new TaskCompletionSource<bool>(); _connectTask = new TaskCompletionSource<bool>();
@@ -165,6 +177,7 @@ namespace Discord


await _connectTask.Task.ConfigureAwait(false); await _connectTask.Task.ConfigureAwait(false);
ConnectionState = ConnectionState.Connected; ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected");
} }
catch (Exception) catch (Exception)
{ {
@@ -180,6 +193,7 @@ namespace Discord
await _connectionLock.WaitAsync().ConfigureAwait(false); await _connectionLock.WaitAsync().ConfigureAwait(false);
try try
{ {
_isReconnecting = false;
await DisconnectInternalAsync().ConfigureAwait(false); await DisconnectInternalAsync().ConfigureAwait(false);
} }
finally { _connectionLock.Release(); } finally { _connectionLock.Release(); }
@@ -190,15 +204,62 @@ namespace Discord


if (ConnectionState == ConnectionState.Disconnected) return; if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting; ConnectionState = ConnectionState.Disconnecting;
await _gatewayLogger.InfoAsync("Disconnecting");


try { _heartbeatCancelToken.Cancel(); } catch { }
await ApiClient.DisconnectAsync().ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false);
await _heartbeatTask.ConfigureAwait(false); await _heartbeatTask.ConfigureAwait(false);
while (_largeGuilds.TryDequeue(out guildId)) { } while (_largeGuilds.TryDequeue(out guildId)) { }


ConnectionState = ConnectionState.Disconnected; ConnectionState = ConnectionState.Disconnected;
await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false);


await Disconnected.RaiseAsync().ConfigureAwait(false); await Disconnected.RaiseAsync().ConfigureAwait(false);
} }
private async Task StartReconnectAsync()
{
//TODO: Is this thread-safe?
while (true)
{
if (_reconnectTask != null) return;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (_reconnectTask != null) return;
_isReconnecting = true;
_reconnectTask = ReconnectInternalAsync();
}
finally { _connectionLock.Release(); }
}
}
private async Task ReconnectInternalAsync()
{
int nextReconnectDelay = 1000;
while (_isReconnecting)
{
try
{
await Task.Delay(nextReconnectDelay).ConfigureAwait(false);
nextReconnectDelay *= 2;
if (nextReconnectDelay > 30000)
nextReconnectDelay = 30000;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternalAsync().ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
return;
}
catch (Exception ex)
{
await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false);
}
}
_reconnectTask = null;
}


/// <inheritdoc /> /// <inheritdoc />
public override Task<IVoiceRegion> GetVoiceRegionAsync(string id) public override Task<IVoiceRegion> GetVoiceRegionAsync(string id)
@@ -332,7 +393,10 @@ namespace Discord
await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false);
var data = (payload as JToken).ToObject<HelloEvent>(_serializer); var data = (payload as JToken).ToObject<HelloEvent>(_serializer);


await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
if (_sessionId != null)
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
else
await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _heartbeatCancelToken.Token); _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _heartbeatCancelToken.Token);
} }
break; break;
@@ -354,6 +418,24 @@ namespace Discord
await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false); await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false);
} }
break; break;
case GatewayOpCode.InvalidSession:
{
await _gatewayLogger.DebugAsync("Received InvalidSession").ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Failed to resume previous session");

_sessionId = null;
_lastSeq = 0;
await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
}
break;
case GatewayOpCode.Reconnect:
{
await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Server requested a reconnect");

await StartReconnectAsync().ConfigureAwait(false);
}
break;
case GatewayOpCode.Dispatch: case GatewayOpCode.Dispatch:
switch (type) switch (type)
{ {
@@ -380,6 +462,7 @@ namespace Discord
await Ready.RaiseAsync().ConfigureAwait(false); await Ready.RaiseAsync().ConfigureAwait(false);


_connectTask.TrySetResult(true); //Signal the .Connect() call to complete _connectTask.TrySetResult(true); //Signal the .Connect() call to complete
await _gatewayLogger.InfoAsync("Ready");
} }
break; break;


@@ -410,7 +493,11 @@ namespace Discord
} }
} }


await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false);
if (data.Unavailable != true)
{
await _gatewayLogger.InfoAsync($"Connected to {data.Name}").ConfigureAwait(false);
await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false);
}
} }
break; break;
case "GUILD_UPDATE": case "GUILD_UPDATE":
@@ -442,11 +529,17 @@ namespace Discord
var guild = RemoveGuild(data.Id); var guild = RemoveGuild(data.Id);
if (guild != null) if (guild != null)
{ {
foreach (var member in guild.Members)
member.User.RemoveRef();

await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false); await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false);
await _gatewayLogger.InfoAsync($"Disconnected from {data.Name}").ConfigureAwait(false);
if (data.Unavailable != true) if (data.Unavailable != true)
{
await LeftGuild.RaiseAsync(guild).ConfigureAwait(false); await LeftGuild.RaiseAsync(guild).ConfigureAwait(false);
foreach (var member in guild.Members)
member.User.RemoveRef();
await _gatewayLogger.InfoAsync($"Left {data.Name}").ConfigureAwait(false);
}

} }
else else
{ {
@@ -987,11 +1080,16 @@ namespace Discord
var state = ConnectionState; var state = ConnectionState;
while (state == ConnectionState.Connecting || state == ConnectionState.Connected) while (state == ConnectionState.Connecting || state == ConnectionState.Connected)
{ {
//if (_heartbeatTime != 0) //TODO: Connection lost, reconnect
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);


if (_heartbeatTime != 0) //Server never responded to our last heartbeat
{
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
return;
}
_heartbeatTime = Environment.TickCount; _heartbeatTime = Environment.TickCount;
await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false); await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false);
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
} }
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }


+ 16
- 0
src/Discord.Net/Net/WebSocketException.cs View File

@@ -0,0 +1,16 @@
using System;
namespace Discord.Net
{
public class WebSocketClosedException : Exception
{
public int CloseCode { get; }
public string Reason { get; }

public WebSocketClosedException(int closeCode, string reason = null)
: base($"The server sent close {closeCode}{(reason != null ? $": \"{reason}\"" : "")}")
{
CloseCode = closeCode;
Reason = reason;
}
}
}

+ 37
- 20
src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs View File

@@ -1,5 +1,6 @@
using Discord.Extensions; using Discord.Extensions;
using System; using System;
using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.IO; using System.IO;
using System.Net.WebSockets; using System.Net.WebSockets;
@@ -17,9 +18,11 @@ namespace Discord.Net.WebSockets


public event Func<byte[], int, int, Task> BinaryMessage; public event Func<byte[], int, int, Task> BinaryMessage;
public event Func<string, Task> TextMessage; public event Func<string, Task> TextMessage;
private readonly ClientWebSocket _client;
public event Func<Exception, Task> Closed;
private readonly SemaphoreSlim _sendLock; private readonly SemaphoreSlim _sendLock;
private readonly Dictionary<string, string> _headers;
private ClientWebSocket _client;
private Task _task; private Task _task;
private CancellationTokenSource _cancelTokenSource; private CancellationTokenSource _cancelTokenSource;
private CancellationToken _cancelToken, _parentToken; private CancellationToken _cancelToken, _parentToken;
@@ -27,14 +30,11 @@ namespace Discord.Net.WebSockets


public DefaultWebSocketClient() public DefaultWebSocketClient()
{ {
_client = new ClientWebSocket();
_client.Options.Proxy = null;
_client.Options.KeepAliveInterval = TimeSpan.Zero;

_sendLock = new SemaphoreSlim(1, 1); _sendLock = new SemaphoreSlim(1, 1);
_cancelTokenSource = new CancellationTokenSource(); _cancelTokenSource = new CancellationTokenSource();
_cancelToken = CancellationToken.None; _cancelToken = CancellationToken.None;
_parentToken = CancellationToken.None; _parentToken = CancellationToken.None;
_headers = new Dictionary<string, string>();
} }
private void Dispose(bool disposing) private void Dispose(bool disposing)
{ {
@@ -58,6 +58,15 @@ namespace Discord.Net.WebSockets
_cancelTokenSource = new CancellationTokenSource(); _cancelTokenSource = new CancellationTokenSource();
_cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token;


_client = new ClientWebSocket();
_client.Options.Proxy = null;
_client.Options.KeepAliveInterval = TimeSpan.Zero;
foreach (var header in _headers)
{
if (header.Value != null)
_client.Options.SetRequestHeader(header.Key, header.Value);
}

await _client.ConnectAsync(new Uri(host), _cancelToken).ConfigureAwait(false); await _client.ConnectAsync(new Uri(host), _cancelToken).ConfigureAwait(false);
_task = RunAsync(_cancelToken); _task = RunAsync(_cancelToken);
} }
@@ -66,7 +75,7 @@ namespace Discord.Net.WebSockets
//Assume locked //Assume locked
_cancelTokenSource.Cancel(); _cancelTokenSource.Cancel();
if (_client.State == WebSocketState.Open)
if (_client != null && _client.State == WebSocketState.Open)
{ {
try try
{ {
@@ -82,7 +91,7 @@ namespace Discord.Net.WebSockets


public void SetHeader(string key, string value) public void SetHeader(string key, string value)
{ {
_client.Options.SetRequestHeader(key, value);
_headers[key] = value;
} }
public void SetCancelToken(CancellationToken cancelToken) public void SetCancelToken(CancellationToken cancelToken)
{ {
@@ -148,28 +157,36 @@ namespace Discord.Net.WebSockets
throw new Exception("Connection timed out."); throw new Exception("Connection timed out.");
} }


if (result.MessageType == WebSocketMessageType.Close)
throw new WebSocketException((int)result.CloseStatus.Value, result.CloseStatusDescription);
else
if (result.Count > 0)
stream.Write(buffer.Array, 0, result.Count); stream.Write(buffer.Array, 0, result.Count);

} }
while (result == null || !result.EndOfMessage); while (result == null || !result.EndOfMessage);


var array = stream.ToArray(); var array = stream.ToArray();
if (result.MessageType == WebSocketMessageType.Binary)
await BinaryMessage.RaiseAsync(array, 0, array.Length).ConfigureAwait(false);
else if (result.MessageType == WebSocketMessageType.Text)
{
string text = Encoding.UTF8.GetString(array, 0, array.Length);
await TextMessage.RaiseAsync(text).ConfigureAwait(false);
}

stream.Position = 0; stream.Position = 0;
stream.SetLength(0); stream.SetLength(0);

switch (result.MessageType)
{
case WebSocketMessageType.Binary:
await BinaryMessage(array, 0, array.Length).ConfigureAwait(false);
break;
case WebSocketMessageType.Text:
string text = Encoding.UTF8.GetString(array, 0, array.Length);
await TextMessage(text).ConfigureAwait(false);
break;
case WebSocketMessageType.Close:
var _ = Closed(new WebSocketClosedException((int)result.CloseStatus, result.CloseStatusDescription));
return;
}
} }
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }
catch (Exception ex)
{
//This cannot be awaited otherwise we'll deadlock when DiscordApiClient waits for this task to complete.
var _ = Closed(ex);
}
} }
} }
} }

+ 1
- 0
src/Discord.Net/Net/WebSockets/IWebSocketClient.cs View File

@@ -8,6 +8,7 @@ namespace Discord.Net.WebSockets
{ {
event Func<byte[], int, int, Task> BinaryMessage; event Func<byte[], int, int, Task> BinaryMessage;
event Func<string, Task> TextMessage; event Func<string, Task> TextMessage;
event Func<Exception, Task> Closed;


void SetHeader(string key, string value); void SetHeader(string key, string value);
void SetCancelToken(CancellationToken cancelToken); void SetCancelToken(CancellationToken cancelToken);


Loading…
Cancel
Save