Browse Source

Cleaned up ShardedClient, delayed connections

tags/1.0-rc
RogueException 8 years ago
parent
commit
203265cb65
3 changed files with 95 additions and 58 deletions
  1. +0
    -0
      src/Discord.Net.Core/Net/WebSocketClosedException.cs
  2. +34
    -13
      src/Discord.Net.WebSocket/DiscordShardedClient.cs
  3. +61
    -45
      src/Discord.Net.WebSocket/DiscordSocketClient.cs

src/Discord.Net.Core/Net/WebSocketException.cs → src/Discord.Net.Core/Net/WebSocketClosedException.cs View File


+ 34
- 13
src/Discord.Net.WebSocket/DiscordShardedClient.cs View File

@@ -5,12 +5,14 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Threading;

namespace Discord.WebSocket
{
public partial class DiscordShardedClient : BaseDiscordClient, IDiscordClient
{
private readonly DiscordSocketConfig _baseConfig;
private readonly SemaphoreSlim _connectionGroupLock;
private int[] _shardIds;
private Dictionary<int, int> _shardIdsToIndex;
private DiscordSocketClient[] _shards;
@@ -18,9 +20,9 @@ namespace Discord.WebSocket
private bool _automaticShards;
/// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary>
public int Latency { get; private set; }
internal UserStatus Status => _shards[0].Status;
internal Game? Game => _shards[0].Game;
public int Latency => GetLatency();
public UserStatus Status => _shards[0].Status;
public Game? Game => _shards[0].Game;

internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;
public new SocketSelfUser CurrentUser { get { return base.CurrentUser as SocketSelfUser; } private set { base.CurrentUser = value; } }
@@ -48,6 +50,7 @@ namespace Discord.WebSocket
_shardIdsToIndex = new Dictionary<int, int>();
config.DisplayInitialLog = false;
_baseConfig = config;
_connectionGroupLock = new SemaphoreSlim(1, 1);

if (config.TotalShards == null)
_automaticShards = true;
@@ -61,7 +64,7 @@ namespace Discord.WebSocket
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = config.Clone();
newConfig.ShardId = _shardIds[i];
_shards[i] = new DiscordSocketClient(newConfig);
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock);
RegisterEvents(_shards[i]);
}
}
@@ -83,7 +86,7 @@ namespace Discord.WebSocket
var newConfig = _baseConfig.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.TotalShards = _totalShards;
_shards[i] = new DiscordSocketClient(newConfig);
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock);
RegisterEvents(_shards[i]);
}
}
@@ -125,12 +128,11 @@ namespace Discord.WebSocket
}
private async Task ConnectInternalAsync(bool waitForGuilds)
{
for (int i = 0; i < _shards.Length; i++)
{
await _shards[i].ConnectAsync(waitForGuilds).ConfigureAwait(false);
if (i == 0)
CurrentUser = _shards[i].CurrentUser;
}
await Task.WhenAll(
_shards.Select(x => x.ConnectAsync(waitForGuilds))
).ConfigureAwait(false);

CurrentUser = _shards[0].CurrentUser;
}
/// <inheritdoc />
public async Task DisconnectAsync()
@@ -156,11 +158,11 @@ namespace Discord.WebSocket
}
private int GetShardIdFor(ulong guildId)
=> (int)((guildId >> 22) % (uint)_totalShards);
private int GetShardIdFor(IGuild guild)
public int GetShardIdFor(IGuild guild)
=> GetShardIdFor(guild.Id);
private DiscordSocketClient GetShardFor(ulong guildId)
=> GetShard(GetShardIdFor(guildId));
private DiscordSocketClient GetShardFor(IGuild guild)
public DiscordSocketClient GetShardFor(IGuild guild)
=> GetShardFor(guild.Id);

/// <inheritdoc />
@@ -269,6 +271,14 @@ namespace Discord.WebSocket
}
}

private int GetLatency()
{
int total = 0;
for (int i = 0; i < _shards.Length; i++)
total += _shards[i].Latency;
return (int)Math.Round(total / (double)_shards.Length);
}

public async Task SetStatusAsync(UserStatus status)
{
for (int i = 0; i < _shards.Length; i++)
@@ -283,6 +293,17 @@ namespace Discord.WebSocket
private void RegisterEvents(DiscordSocketClient client)
{
client.Log += (msg) => _logEvent.InvokeAsync(msg);
client.LoggedOut += () =>
{
var state = LoginState;
if (state == LoginState.LoggedIn || state == LoginState.LoggingIn)
{
//Should only happen if token is changed
var _ = LogoutAsync(); //Signal the logout, fire and forget
}
return Task.Delay(0);
};

client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel);
client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel);
client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel);


+ 61
- 45
src/Discord.Net.WebSocket/DiscordSocketClient.cs View File

@@ -17,6 +17,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using GameModel = Discord.API.Game;
using Discord.Net;

namespace Discord.WebSocket
{
@@ -25,6 +26,7 @@ namespace Discord.WebSocket
private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly Logger _gatewayLogger;
private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock;

private string _sessionId;
private int _lastSeq;
@@ -69,8 +71,9 @@ namespace Discord.WebSocket
/// <summary> Creates a new REST/WebSocket discord client. </summary>
public DiscordSocketClient() : this(new DiscordSocketConfig()) { }
/// <summary> Creates a new REST/WebSocket discord client. </summary>
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config)) { }
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client)
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null) { }
internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock) : this(config, CreateApiClient(config), groupLock) { }
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock)
: base(config, client)
{
ShardId = config.ShardId ?? 0;
@@ -86,6 +89,7 @@ namespace Discord.WebSocket
_nextAudioId = 1;
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
_connectionGroupLock = groupLock;

_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
@@ -171,53 +175,65 @@ namespace Discord.WebSocket
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);

ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
if (_connectionGroupLock != null)
await _connectionGroupLock.WaitAsync().ConfigureAwait(false);
try
{
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();

//Abort connection on timeout
var _ = Task.Run(async () =>
_canReconnect = true;
ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();

await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
//Abort connection on timeout
var _ = Task.Run(async () =>
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});

if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);

await _connectTask.Task.ConfigureAwait(false);
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}

await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);

await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
if (!isReconnecting)
_canReconnect = true;
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);

await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
}
catch (Exception)
{
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
}
}
catch (Exception)
finally
{
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
if (_connectionGroupLock != null)
{
await Task.Delay(5000).ConfigureAwait(false);
_connectionGroupLock.Release();
}
}
}
/// <inheritdoc />
@@ -290,13 +306,12 @@ namespace Discord.WebSocket

private async Task StartReconnectAsync(Exception ex)
{
if (ex == null)
{
if (_connectTask?.TrySetCanceled() ?? false) return;
}
else
if ((ex as WebSocketClosedException).CloseCode == 4004) //Bad Token
{
if (_connectTask?.TrySetException(ex) ?? false) return;
_canReconnect = false;
_connectTask?.TrySetException(ex);
await LogoutAsync().ConfigureAwait(false);
return;
}

await _connectionLock.WaitAsync().ConfigureAwait(false);
@@ -608,6 +623,7 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
_canReconnect = false;
_connectTask.TrySetException(new Exception("Processing READY failed", ex));
return;
}


Loading…
Cancel
Save