@@ -47,8 +47,9 @@ namespace Discord.Net.WebSockets
public WebSocketState State => (WebSocketState)_state;
public WebSocketState State => (WebSocketState)_state;
protected int _state;
protected int _state;
protected ExceptionDispatchInfo _disconnectReason;
private bool _wasDisconnectUnexpected;
protected ExceptionDispatchInfo _disconnectReason;
protected bool _wasDisconnectUnexpected;
protected WebSocketState _disconnectState;
public CancellationToken ParentCancelToken { get; set; }
public CancellationToken ParentCancelToken { get; set; }
public CancellationToken CancelToken => _cancelToken;
public CancellationToken CancelToken => _cancelToken;
@@ -78,9 +79,7 @@ namespace Discord.Net.WebSockets
try
try
{
{
await Disconnect().ConfigureAwait(false);
_state = (int)WebSocketState.Connecting;
await Disconnect().ConfigureAwait(false);
_cancelTokenSource = new CancellationTokenSource();
_cancelTokenSource = new CancellationTokenSource();
if (ParentCancelToken != null)
if (ParentCancelToken != null)
@@ -91,12 +90,13 @@ namespace Discord.Net.WebSockets
await _engine.Connect(Host, _cancelToken).ConfigureAwait(false);
await _engine.Connect(Host, _cancelToken).ConfigureAwait(false);
_lastHeartbeat = DateTime.UtcNow;
_lastHeartbeat = DateTime.UtcNow;
_state = (int)WebSocketState.Connecting;
_runTask = RunTasks();
_runTask = RunTasks();
}
}
catch
catch (Exception ex)
{
{
await Disconnect().ConfigureAwait(false);
throw;
await DisconnectInternal (ex, isUnexpected: false ).ConfigureAwait(false);
throw; //Dont handle this exception internally, send up it upwards
}
}
}
}
protected void CompleteConnect()
protected void CompleteConnect()
@@ -108,33 +108,40 @@ namespace Discord.Net.WebSockets
=> Connect(_host, _cancelToken);*/
=> Connect(_host, _cancelToken);*/
public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false);
public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false);
protected Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false)
protected async Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false)
{
{
int oldState;
int oldState;
bool hasWriterLock;
bool hasWriterLock;
//If in either connecting or connected state, get a lock by being the first to switch to disconnecting
//If in either connecting or connected state, get a lock by being the first to switch to disconnecting
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting);
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting);
if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask ; //Already disconnected
if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected
hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change
hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change
if (!hasWriterLock)
if (!hasWriterLock)
{
{
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected);
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected);
if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask ; //Already disconnected
if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected
hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change
hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change
}
}
if (hasWriterLock)
if (hasWriterLock)
{
{
_wasDisconnectUnexpected = isUnexpected;
_wasDisconnectUnexpected = isUnexpected;
_disconnectState = (WebSocketState)oldState;
_disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null;
_disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null;
if (_disconnectState == WebSocketState.Connecting) //_runTask was never made
await Cleanup();
_cancelTokenSource.Cancel();
_cancelTokenSource.Cancel();
}
}
if (!skipAwait)
if (!skipAwait)
return _runTask ?? TaskHelper.CompletedTask;
{
Task task = _runTask ?? TaskHelper.CompletedTask;
await task;
}
else
else
return TaskHelper.CompletedTask;
await TaskHelper.CompletedTask;
}
}
protected virtual async Task RunTasks()
protected virtual async Task RunTasks()
@@ -143,19 +150,19 @@ namespace Discord.Net.WebSockets
Task firstTask = Task.WhenAny(tasks);
Task firstTask = Task.WhenAny(tasks);
Task allTasks = Task.WhenAll(tasks);
Task allTasks = Task.WhenAll(tasks);
//Wait until the first task ends/errors and capture the error
try { await firstTask.ConfigureAwait(false); }
try { await firstTask.ConfigureAwait(false); }
catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); }
catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); }
//When the first task ends, make sure the rest do too
//Ensure all other tasks are signaled to end.
await DisconnectInternal(skipAwait: true);
await DisconnectInternal(skipAwait: true);
//Wait for the remaining tasks to complete
try { await allTasks.ConfigureAwait(false); }
try { await allTasks.ConfigureAwait(false); }
catch { }
catch { }
bool wasUnexpected = _wasDisconnectUnexpected;
_wasDisconnectUnexpected = false;
await Cleanup(wasUnexpected).ConfigureAwait(false);
_runTask = null ;
//Clean up state variables and raise disconnect event
await Cleanup().ConfigureAwait(false);
}
}
protected virtual Task[] Run()
protected virtual Task[] Run()
{
{
@@ -164,12 +171,22 @@ namespace Discord.Net.WebSockets
.Concat(new Task[] { HeartbeatAsync(cancelToken) })
.Concat(new Task[] { HeartbeatAsync(cancelToken) })
.ToArray();
.ToArray();
}
}
protected virtual Task Cleanup(bool wasUnexpected )
protected virtual async Task Cleanup()
{
{
var disconnectState = _disconnectState;
_disconnectState = WebSocketState.Disconnected;
var wasDisconnectUnexpected = _wasDisconnectUnexpected;
_wasDisconnectUnexpected = false;
//Dont reset disconnectReason, we may called ThrowError() later
await _engine.Disconnect();
_cancelTokenSource = null;
_cancelTokenSource = null;
_state = (int)WebSocketState.Disconnected;
RaiseDisconnected(wasUnexpected, _disconnectReason?.SourceException);
return _engine.Disconnect();
var oldState = _state;
_state = (int)WebSocketState.Disconnected;
_runTask = null;
if (disconnectState == WebSocketState.Connected)
RaiseDisconnected(wasDisconnectUnexpected, _disconnectReason?.SourceException);
}
}
protected abstract Task ProcessMessage(string json);
protected abstract Task ProcessMessage(string json);