@@ -37,30 +37,33 @@ namespace Discord.Net.WebSockets
protected readonly IWebSocketEngine _engine;
protected readonly DiscordClient _client;
protected readonly LogMessageSeverity _logLevel;
protected readonly ManualResetEventSlim _connectedEvent;
public string Host { get; set; }
protected ExceptionDispatchInfo _disconnectReason;
protected bool _wasDisconnectUnexpected;
protected WebSocketState _disconnectState;
protected int _loginTimeout, _heartbeatInterval;
private DateTime _lastHeartbeat;
private Task _runTask;
public WebSocketState State => (WebSocketState)_state;
protected int _state;
protected ExceptionDispatchInfo _disconnectReason;
private bool _wasDisconnectUnexpected;
public CancellationToken ParentCancelToken { get; set; }
public CancellationToken CancelToken => _cancelToken;
private CancellationTokenSource _cancelTokenSource;
protected CancellationToken _cancelToken;
public string Host { get; set; }
public WebSocketState State => (WebSocketState)_state;
protected int _state;
public WebSocket(DiscordClient client)
{
_client = client;
_logLevel = client.Config.LogLevel;
_loginTimeout = client.Config.ConnectionTimeout;
_cancelToken = new CancellationToken(true);
_connectedEvent = new ManualResetEventSlim(false);
_engine = new BuiltInWebSocketEngine(client.Config.WebSocketInterval);
_engine.ProcessMessage += async (s, e) =>
@@ -78,9 +81,7 @@ namespace Discord.Net.WebSockets
try
{
await Disconnect().ConfigureAwait(false);
_state = (int)WebSocketState.Connecting;
await Disconnect().ConfigureAwait(false);
_cancelTokenSource = new CancellationTokenSource();
if (ParentCancelToken != null)
@@ -91,50 +92,59 @@ namespace Discord.Net.WebSockets
await _engine.Connect(Host, _cancelToken).ConfigureAwait(false);
_lastHeartbeat = DateTime.UtcNow;
_state = (int)WebSocketState.Connecting;
_runTask = RunTasks();
}
catch
catch (Exception ex)
{
await Disconnect().ConfigureAwait(false);
throw;
await DisconnectInternal (ex, isUnexpected: false ).ConfigureAwait(false);
throw; //Dont handle this exception internally, send up it upwards
}
}
protected void CompleteConnect()
{
_state = (int)WebSocketState.Connected;
_connectedEvent.Set();
RaiseConnected();
}
/*public Task Reconnect(CancellationToken cancelToken)
=> Connect(_host, _cancelToken);*/
public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false);
protected Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false)
protected async Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false)
{
int oldState;
bool hasWriterLock;
//If in either connecting or connected state, get a lock by being the first to switch to disconnecting
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting);
if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask ; //Already disconnected
if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected
hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change
if (!hasWriterLock)
{
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected);
if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask ; //Already disconnected
if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected
hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change
}
if (hasWriterLock)
{
_wasDisconnectUnexpected = isUnexpected;
_disconnectState = (WebSocketState)oldState;
_disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null;
if (_disconnectState == WebSocketState.Connecting) //_runTask was never made
await Cleanup();
_cancelTokenSource.Cancel();
}
if (!skipAwait)
return _runTask ?? TaskHelper.CompletedTask;
{
Task task = _runTask ?? TaskHelper.CompletedTask;
await task;
}
else
return TaskHelper.CompletedTask;
await TaskHelper.CompletedTask;
}
protected virtual async Task RunTasks()
@@ -143,19 +153,19 @@ namespace Discord.Net.WebSockets
Task firstTask = Task.WhenAny(tasks);
Task allTasks = Task.WhenAll(tasks);
//Wait until the first task ends/errors and capture the error
try { await firstTask.ConfigureAwait(false); }
catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); }
//When the first task ends, make sure the rest do too
//Ensure all other tasks are signaled to end.
await DisconnectInternal(skipAwait: true);
//Wait for the remaining tasks to complete
try { await allTasks.ConfigureAwait(false); }
catch { }
bool wasUnexpected = _wasDisconnectUnexpected;
_wasDisconnectUnexpected = false;
await Cleanup(wasUnexpected).ConfigureAwait(false);
_runTask = null ;
//Clean up state variables and raise disconnect event
await Cleanup().ConfigureAwait(false);
}
protected virtual Task[] Run()
{
@@ -164,12 +174,23 @@ namespace Discord.Net.WebSockets
.Concat(new Task[] { HeartbeatAsync(cancelToken) })
.ToArray();
}
protected virtual Task Cleanup(bool wasUnexpected )
protected virtual async Task Cleanup()
{
var disconnectState = _disconnectState;
_disconnectState = WebSocketState.Disconnected;
var wasDisconnectUnexpected = _wasDisconnectUnexpected;
_wasDisconnectUnexpected = false;
//Dont reset disconnectReason, we may called ThrowError() later
await _engine.Disconnect();
_cancelTokenSource = null;
_state = (int)WebSocketState.Disconnected;
RaiseDisconnected(wasUnexpected, _disconnectReason?.SourceException);
return _engine.Disconnect();
var oldState = _state;
_state = (int)WebSocketState.Disconnected;
_runTask = null;
_connectedEvent.Reset();
if (disconnectState == WebSocketState.Connected)
RaiseDisconnected(wasDisconnectUnexpected, _disconnectReason?.SourceException);
}
protected abstract Task ProcessMessage(string json);