@@ -17,29 +17,27 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using GameModel = Discord.API.Game;
using Discord.Net;
namespace Discord.WebSocket
{
public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient
{
private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly Logger _gatewayLogger;
private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock;
private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue<long> _heartbeatTimes;
private readonly ConnectionManager _connection;
private readonly Logger _gatewayLogger;
private readonly SemaphoreSlim _stateLock;
private string _sessionId;
private int _lastSeq;
private ImmutableDictionary<string, RestVoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _cancelToken, _reconnectCancelToken;
private Task _heartbeatTask, _guildDownloadTask, _reconnectTask;
private Task _heartbeatTask, _guildDownloadTask;
private int _unavailableGuilds;
private long _lastGuildAvailableTime, _lastMessageTime;
private int _nextAudioId;
private bool _canReconnect;
private DateTimeOffset? _statusSince;
private RestApplication _applicationInfo;
private ConcurrentHashSet<ulong> _downloadUsersFor;
@@ -59,7 +57,6 @@ namespace Discord.WebSocket
internal int LargeThreshold { get; private set; }
internal AudioMode AudioMode { get; private set; }
internal ClientState State { get; private set; }
internal int ConnectionTimeout { get; private set; }
internal UdpSocketProvider UdpSocketProvider { get; private set; }
internal WebSocketProvider WebSocketProvider { get; private set; }
internal bool AlwaysDownloadUsers { get; private set; }
@@ -90,35 +87,28 @@ namespace Discord.WebSocket
UdpSocketProvider = config.UdpSocketProvider;
WebSocketProvider = config.WebSocketProvider;
AlwaysDownloadUsers = config.AlwaysDownloadUsers;
ConnectionTimeout = config.ConnectionTimeout;
State = new ClientState(0, 0);
_downloadUsersFor = new ConcurrentHashSet<ulong>();
_heartbeatTimes = new ConcurrentQueue<long>();
_stateLock = new SemaphoreSlim(1, 1);
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : $"Shard #{ShardId}");
_connection = new ConnectionManager(_stateLock, _gatewayLogger, config.ConnectionTimeout,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_nextAudioId = 1;
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
_connectionGroupLock = groupLock;
_parentClient = parentClient;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
{
_gatewayLogger.WarningAsync(e.ErrorContext.Error).GetAwaiter().GetResult();
_gatewayLogger.WarningAsync("Serializer Error", e.ErrorContext.Error).GetAwaiter().GetResult();
e.ErrorContext.Handled = true;
};
ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedGatewayEvent += ProcessMessageAsync;
ApiClient.Disconnected += async ex =>
{
if (ex != null)
{
await _gatewayLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
await StartReconnectAsync(ex).ConfigureAwait(false);
}
else
await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
};
LeftGuild += async g => await _gatewayLogger.InfoAsync($"Left {g.Name}").ConfigureAwait(false);
JoinedGuild += async g => await _gatewayLogger.InfoAsync($"Joined {g.Name}").ConfigureAwait(false);
@@ -143,8 +133,16 @@ namespace Discord.WebSocket
}
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost);
internal override void Dispose(bool disposing)
{
if (disposing)
{
StopAsync().GetAwaiter().GetResult();
ApiClient.Dispose();
}
}
protected override async Task OnLoginAsync(TokenType tokenType, string token)
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
if (_parentClient == null)
{
@@ -154,92 +152,49 @@ namespace Discord.WebSocket
else
_voiceRegions = _parentClient._voiceRegions;
}
protected override async Task OnLogoutAsync()
internal override async Task OnLogoutAsync()
{
if (ConnectionState != ConnectionState.Disconnected)
await DisconnectInternalAsync(null, false).ConfigureAwait(false);
await StopAsync().ConfigureAwait(false);
_applicationInfo = null;
_voiceRegions = ImmutableDictionary.Create<string, RestVoiceRegion>();
_downloadUsersFor.Clear();
}
public async Task StartAsync()
=> await _connection.StartAsync().ConfigureAwait(false);
public async Task StopAsync()
=> await _connection.StopAsync().ConfigureAwait(false);
/// <inheritdoc />
public async Task ConnectAsync()
private async Task OnConnectingAsync()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternalAsync(false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task ConnectInternalAsync(bool isReconnecting)
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("Client is not logged in.");
if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();
var state = ConnectionState;
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
if (_connectionGroupLock != null)
await _connectionGroupLock.WaitAsync().ConfigureAwait(false);
await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false);
try
{
_canReconnect = true;
ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();
//Abort connection on timeout
var _ = Task.Run(async () =>
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
await _connectTask.Task.ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
.Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
catch (Exception)
else
{
await DisconnectInternalAsync(null, isReconnecting ).ConfigureAwait(false);
throw;
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
//Wait for READY
await _connection.WaitAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);
await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
.Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
}
finally
{
@@ -250,41 +205,11 @@ namespace Discord.WebSocket
}
}
}
/// <inheritdoc />
public async Task DisconnectAsync()
{
if (_connectTask?.TrySetCanceled() ?? false) return;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync(null, false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting)
private async Task OnDisconnectingAsync(Exception ex)
{
if (!isReconnecting)
{
_canReconnect = false;
_sessionId = null;
_lastSeq = 0;
if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();
}
ulong guildId;
if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting;
await _gatewayLogger.InfoAsync("Disconnecting").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Cancelling current tasks").ConfigureAwait(false);
//Signal tasks to complete
try { _cancelToken.Cancel(); } catch { }
await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
//Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false);
//Wait for tasks to complete
@@ -294,8 +219,8 @@ namespace Discord.WebSocket
await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null;
long times ;
while (_heartbeatTimes.TryDequeue(out times )) { }
long time;
while (_heartbeatTimes.TryDequeue(out time)) { }
_lastMessageTime = 0;
await _gatewayLogger.DebugAsync("Waiting for guild downloader").ConfigureAwait(false);
@@ -315,70 +240,6 @@ namespace Discord.WebSocket
if (guild._available)
await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false);
}
ConnectionState = ConnectionState.Disconnected;
await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false);
await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
}
private async Task StartReconnectAsync(Exception ex)
{
if ((ex as WebSocketClosedException)?.CloseCode == 4004) //Bad Token
{
_canReconnect = false;
_connectTask?.TrySetException(ex);
await LogoutAsync().ConfigureAwait(false);
return;
}
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (!_canReconnect || _reconnectTask != null) return;
_reconnectCancelToken = new CancellationTokenSource();
_reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token);
}
finally { _connectionLock.Release(); }
}
private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken)
{
try
{
Random jitter = new Random();
int nextReconnectDelay = 1000;
while (true)
{
await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false);
nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250);
if (nextReconnectDelay > 60000)
nextReconnectDelay = 60000;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (cancelToken.IsCancellationRequested) return;
await ConnectInternalAsync(true).ConfigureAwait(false);
_reconnectTask = null;
return;
}
catch (Exception ex2)
{
await _gatewayLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
}
catch (OperationCanceledException)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await _gatewayLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false);
_reconnectTask = null;
}
finally { _connectionLock.Release(); }
}
}
/// <inheritdoc />
@@ -555,7 +416,7 @@ namespace Discord.WebSocket
await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false);
var data = (payload as JToken).ToObject<HelloEvent>(_serializer);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelToken.Token, _gatewayLogger );
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.C ancelToken);
}
break;
case GatewayOpCode.Heartbeat:
@@ -593,9 +454,7 @@ namespace Discord.WebSocket
case GatewayOpCode.Reconnect:
{
await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Server requested a reconnect").ConfigureAwait(false);
await StartReconnectAsync(new Exception("Server requested a reconnect")).ConfigureAwait(false);
_connection.Error(new Exception("Server requested a reconnect"));
}
break;
case GatewayOpCode.Dispatch:
@@ -633,8 +492,7 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
_canReconnect = false;
_connectTask.TrySetException(new Exception("Processing READY failed", ex));
_connection.CriticalError(new Exception("Processing READY failed", ex));
return;
}
@@ -642,11 +500,11 @@ namespace Discord.WebSocket
await SyncGuildsAsync().ConfigureAwait(false);
_lastGuildAvailableTime = Environment.TickCount;
_guildDownloadTask = WaitForGuildsAsync(_cancelToken. Token, _gatewayLogger);
_guildDownloadTask = WaitForGuildsAsync(_connection.C ancelToken, _gatewayLogger);
await _readyEvent.InvokeAsync().ConfigureAwait(false);
var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var _ = _connection.CompleteAsync();
await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false);
}
break;
@@ -654,7 +512,7 @@ namespace Discord.WebSocket
{
await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false);
var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var _ = _connection.CompleteAsync();
//Notify the client that these guilds are available again
foreach (var guild in State.Guilds)
@@ -1356,7 +1214,6 @@ namespace Discord.WebSocket
SocketUserMessage cachedMsg = channel.GetCachedMessage(data.MessageId) as SocketUserMessage;
var user = await channel.GetUserAsync(data.UserId, CacheMode.CacheOnly);
SocketReaction reaction = SocketReaction.Create(data, channel, cachedMsg, Optional.Create(user));
if (cachedMsg != null)
{
cachedMsg.AddReaction(reaction);
@@ -1691,11 +1548,11 @@ namespace Discord.WebSocket
}
}
private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken, Logger logger )
private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken)
{
try
{
await l ogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
await _gatewayL ogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
while (!cancelToken.IsCancellationRequested)
{
var now = Environment.TickCount;
@@ -1705,8 +1562,7 @@ namespace Discord.WebSocket
{
if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? true))
{
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false);
_connection.Error(new Exception("Server missed last heartbeat"));
return;
}
}
@@ -1718,20 +1574,20 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
await l ogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
await _gatewayL ogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
}
await l ogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
await _gatewayL ogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (OperationCanceledException)
{
await l ogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
await _gatewayL ogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (Exception ex)
{
await l ogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
await _gatewayL ogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
}
public async Task WaitForGuildsAsync()
@@ -1805,8 +1661,7 @@ namespace Discord.WebSocket
}
//IDiscordClient
Task IDiscordClient.ConnectAsync()
=> ConnectAsync();
ConnectionState IDiscordClient.ConnectionState => _connection.State;
async Task<IApplication> IDiscordClient.GetApplicationInfoAsync()
=> await GetApplicationInfoAsync().ConfigureAwait(false);
@@ -1842,5 +1697,10 @@ namespace Discord.WebSocket
=> Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id)
=> Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));
async Task IDiscordClient.StartAsync()
=> await StartAsync().ConfigureAwait(false);
async Task IDiscordClient.StopAsync()
=> await StopAsync().ConfigureAwait(false);
}
}