@@ -17,6 +17,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using GameModel = Discord.API.Game;
using Discord.Net;
namespace Discord.WebSocket
{
@@ -25,6 +26,7 @@ namespace Discord.WebSocket
private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly Logger _gatewayLogger;
private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock;
private string _sessionId;
private int _lastSeq;
@@ -69,8 +71,9 @@ namespace Discord.WebSocket
/// <summary> Creates a new REST/WebSocket discord client. </summary>
public DiscordSocketClient() : this(new DiscordSocketConfig()) { }
/// <summary> Creates a new REST/WebSocket discord client. </summary>
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config)) { }
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client)
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null) { }
internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock) : this(config, CreateApiClient(config), groupLock) { }
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock)
: base(config, client)
{
ShardId = config.ShardId ?? 0;
@@ -86,6 +89,7 @@ namespace Discord.WebSocket
_nextAudioId = 1;
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
_connectionGroupLock = groupLock;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
@@ -171,53 +175,65 @@ namespace Discord.WebSocket
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
if (_connectionGroupLock != null)
await _connectionGroupLock.WaitAsync().ConfigureAwait(false);
try
{
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();
//Abort connection on timeout
var _ = Task.Run(async () =>
_canReconnect = true;
ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false );
connectTask.TrySetException(new TimeoutException()) ;
} );
var connectTask = new TaskCompletionSource<bool>( );
_connectTask = connectTask ;
_cancelToken = new CancellationTokenSource( );
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
//Abort connection on timeout
var _ = Task.Run(async () =>
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
if (!isReconnecting)
_canReconnect = true;
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
}
catch (Exception)
{
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
}
}
catch (Exception)
finally
{
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
if (_connectionGroupLock != null)
{
await Task.Delay(5000).ConfigureAwait(false);
_connectionGroupLock.Release();
}
}
}
/// <inheritdoc />
@@ -290,13 +306,12 @@ namespace Discord.WebSocket
private async Task StartReconnectAsync(Exception ex)
{
if (ex == null)
{
if (_connectTask?.TrySetCanceled() ?? false) return;
}
else
if ((ex as WebSocketClosedException).CloseCode == 4004) //Bad Token
{
if (_connectTask?.TrySetException(ex) ?? false) return;
_canReconnect = false;
_connectTask?.TrySetException(ex);
await LogoutAsync().ConfigureAwait(false);
return;
}
await _connectionLock.WaitAsync().ConfigureAwait(false);
@@ -608,6 +623,7 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
_canReconnect = false;
_connectTask.TrySetException(new Exception("Processing READY failed", ex));
return;
}