@@ -17,6 +17,9 @@ namespace Discord.WebSocket
private int[] _shardIds;
private int[] _shardIds;
private DiscordSocketClient[] _shards;
private DiscordSocketClient[] _shards;
private int _totalShards;
private int _totalShards;
private SemaphoreSlim[] _identifySemaphores;
private object _semaphoreResetLock;
private Task _semaphoreResetTask;
private bool _isDisposed;
private bool _isDisposed;
@@ -61,6 +64,7 @@ namespace Discord.WebSocket
if (ids != null && config.TotalShards == null)
if (ids != null && config.TotalShards == null)
throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
_semaphoreResetLock = new object();
_shardIdsToIndex = new Dictionary<int, int>();
_shardIdsToIndex = new Dictionary<int, int>();
config.DisplayInitialLog = false;
config.DisplayInitialLog = false;
_baseConfig = config;
_baseConfig = config;
@@ -72,28 +76,49 @@ namespace Discord.WebSocket
_totalShards = config.TotalShards.Value;
_totalShards = config.TotalShards.Value;
_shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
_shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
_shards = new DiscordSocketClient[_shardIds.Length];
_shards = new DiscordSocketClient[_shardIds.Length];
var masterIdentifySemaphore = new SemaphoreSlim(1, 1);
SemaphoreSlim[] identifySemaphores = null;
if (config.IdentifyMaxConcurrency > 1)
{
int maxSemaphores = (_shardIds.Length + config.IdentifyMaxConcurrency - 1) / config.IdentifyMaxConcurrency;
identifySemaphores = new SemaphoreSlim[maxSemaphores];
for (int i = 0; i < maxSemaphores; i++)
identifySemaphores[i] = new SemaphoreSlim(0, config.IdentifyMaxConcurrency);
}
_identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency];
for (int i = 0; i < config.IdentifyMaxConcurrency; i++)
_identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
for (int i = 0; i < _shardIds.Length; i++)
{
{
_shardIdsToIndex.Add(_shardIds[i], i);
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = config.Clone();
var newConfig = config.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.ShardId = _shardIds[i];
_shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency == 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency );
_shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
RegisterEvents(_shards[i], i == 0);
}
}
}
}
}
}
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
null, null, 0, rateLimitPrecision: config.RateLimitPrecision);
rateLimitPrecision: config.RateLimitPrecision);
internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token)
{
int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency;
await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false);
}
internal void ReleaseIdentifyLock()
{
lock (_semaphoreResetLock)
{
if (_semaphoreResetTask == null)
_semaphoreResetTask = ResetSemaphoresAsync();
}
}
private async Task ResetSemaphoresAsync()
{
await Task.Delay(5000).ConfigureAwait(false);
lock (_semaphoreResetLock)
{
foreach (var semaphore in _identifySemaphores)
if (semaphore.CurrentCount == 0)
semaphore.Release();
_semaphoreResetTask = null;
}
}
internal override async Task OnLoginAsync(TokenType tokenType, string token)
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
{
@@ -105,22 +130,16 @@ namespace Discord.WebSocket
_shards = new DiscordSocketClient[_shardIds.Length];
_shards = new DiscordSocketClient[_shardIds.Length];
int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency;
int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency;
_baseConfig.IdentifyMaxConcurrency = maxConcurrency;
_baseConfig.IdentifyMaxConcurrency = maxConcurrency;
var masterIdentifySemaphore = new SemaphoreSlim(1, 1);
SemaphoreSlim[] identifySemaphores = null;
if (maxConcurrency > 1)
{
int maxSemaphores = (_shardIds.Length + maxConcurrency - 1) / maxConcurrency;
identifySemaphores = new SemaphoreSlim[maxSemaphores];
for (int i = 0; i < maxSemaphores; i++)
identifySemaphores[i] = new SemaphoreSlim(0, maxConcurrency);
}
_identifySemaphores = new SemaphoreSlim[maxConcurrency];
for (int i = 0; i < maxConcurrency; i++)
_identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
for (int i = 0; i < _shardIds.Length; i++)
{
{
_shardIdsToIndex.Add(_shardIds[i], i);
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = _baseConfig.Clone();
var newConfig = _baseConfig.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.ShardId = _shardIds[i];
newConfig.TotalShards = _totalShards;
newConfig.TotalShards = _totalShards;
_shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency == 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency );
_shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
RegisterEvents(_shards[i], i == 0);
}
}
}
}