diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs
index 52d7e0cd5..592ad7e92 100644
--- a/src/Discord.Net.Rest/DiscordRestApiClient.cs
+++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs
@@ -51,7 +51,7 @@ namespace Discord.API
internal JsonSerializer Serializer => _serializer;
/// Unknown OAuth token type.
- public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RequestQueue requestQueue, RetryMode defaultRetryMode = RetryMode.AlwaysRetry,
+ public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RetryMode defaultRetryMode = RetryMode.AlwaysRetry,
JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true)
{
_restClientProvider = restClientProvider;
@@ -61,7 +61,7 @@ namespace Discord.API
RateLimitPrecision = rateLimitPrecision;
UseSystemClock = useSystemClock;
- RequestQueue = requestQueue ?? new RequestQueue();
+ RequestQueue = new RequestQueue();
_stateLock = new SemaphoreSlim(1, 1);
SetBaseUrl(DiscordConfig.APIUrl);
diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs
index 65af43a99..48c40fdfa 100644
--- a/src/Discord.Net.Rest/DiscordRestClient.cs
+++ b/src/Discord.Net.Rest/DiscordRestClient.cs
@@ -31,7 +31,6 @@ namespace Discord.Rest
private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config)
=> new API.DiscordRestApiClient(config.RestClientProvider,
DiscordRestConfig.UserAgent,
- null,
rateLimitPrecision: config.RateLimitPrecision,
useSystemClock: config.UseSystemClock);
diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
index 488e2c5c8..2bf8e20b0 100644
--- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
+++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
@@ -23,10 +23,6 @@ namespace Discord.Net.Queue
private CancellationToken _requestCancelToken; //Parent token + Clear token
private DateTimeOffset _waitUntil;
- private readonly SemaphoreSlim _masterIdentifySemaphore;
- private readonly SemaphoreSlim _identifySemaphore;
- private readonly int _identifySemaphoreMaxConcurrency;
-
private Task _cleanupTask;
public RequestQueue()
@@ -43,14 +39,6 @@ namespace Discord.Net.Queue
_cleanupTask = RunCleanup();
}
- public RequestQueue(SemaphoreSlim masterIdentifySemaphore, SemaphoreSlim slaveIdentifySemaphore, int slaveIdentifySemaphoreMaxConcurrency)
- : this ()
- {
- _masterIdentifySemaphore = masterIdentifySemaphore;
- _identifySemaphore = slaveIdentifySemaphore;
- _identifySemaphoreMaxConcurrency = slaveIdentifySemaphoreMaxConcurrency;
- }
-
public async Task SetCancelTokenAsync(CancellationToken cancelToken)
{
await _tokenLock.WaitAsync().ConfigureAwait(false);
@@ -145,42 +133,6 @@ namespace Discord.Net.Queue
var globalBucket = GetOrCreateBucket(options, globalRequest);
await globalBucket.TriggerAsync(id, globalRequest);
}
- internal void ReleaseIdentifySemaphore(int id)
- {
- if (_masterIdentifySemaphore == null)
- throw new InvalidOperationException("Not a RequestQueue with WebSocket data.");
-
- while (_identifySemaphore?.Wait(0) == true) //exhaust all tickets before releasing master
- { }
- _masterIdentifySemaphore.Release();
-#if DEBUG_LIMITS
- Debug.WriteLine($"[{id}] Released identify master semaphore");
-#endif
- }
-
- public async Task AcquireIdentifyTicket(CancellationToken cancellationToken)
- {
- try
- {
- if (_masterIdentifySemaphore == null)
- throw new InvalidOperationException("Not a RequestQueue with WebSocket data.");
-
- if (_identifySemaphore == null)
- await _masterIdentifySemaphore.WaitAsync(cancellationToken);
- else
- {
- bool master;
- while (!(master = _masterIdentifySemaphore.Wait(0)) && !_identifySemaphore.Wait(0)) //To not block the thread
- await Task.Delay(100, cancellationToken);
- if (master && _identifySemaphoreMaxConcurrency > 1)
- _identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1);
- }
-#if DEBUG_LIMITS
- Debug.WriteLine($"[{id}] Acquired identify ticket");
-#endif
- }
- catch(OperationCanceledException) { }
- }
private RequestBucket GetOrCreateBucket(RequestOptions options, IRequest request)
{
@@ -245,8 +197,6 @@ namespace Discord.Net.Queue
_tokenLock?.Dispose();
_clearToken?.Dispose();
_requestCancelTokenSource?.Dispose();
- _masterIdentifySemaphore?.Dispose();
- _identifySemaphore?.Dispose();
}
}
}
diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
index ece18b819..3fb45e55d 100644
--- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
+++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
@@ -457,8 +457,6 @@ namespace Discord.Net.Queue
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] * Reset *");
#endif
- if (request is WebSocketRequest webSocketRequest && webSocketRequest.Options.BucketId == GatewayBucket.Get(GatewayBucketType.Identify).Id)
- _queue.ReleaseIdentifySemaphore(id);
_semaphore = WindowCount;
_resetTick = null;
return;
diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs
index cceee2557..548bb75bf 100644
--- a/src/Discord.Net.WebSocket/BaseSocketClient.cs
+++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs
@@ -1,6 +1,5 @@
using System.Collections.Generic;
using System.IO;
-using System.Threading;
using System.Threading.Tasks;
using Discord.API;
using Discord.Rest;
@@ -80,9 +79,8 @@ namespace Discord.WebSocket
internal BaseSocketClient(DiscordSocketConfig config, DiscordRestApiClient client)
: base(config, client) => BaseConfig = config;
- private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency)
+ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
- identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency,
rateLimitPrecision: config.RateLimitPrecision,
useSystemClock: config.UseSystemClock);
diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index 6c98d9e43..a2c89d4e5 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -17,6 +17,9 @@ namespace Discord.WebSocket
private int[] _shardIds;
private DiscordSocketClient[] _shards;
private int _totalShards;
+ private SemaphoreSlim[] _identifySemaphores;
+ private object _semaphoreResetLock;
+ private Task _semaphoreResetTask;
private bool _isDisposed;
@@ -61,6 +64,7 @@ namespace Discord.WebSocket
if (ids != null && config.TotalShards == null)
throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
+ _semaphoreResetLock = new object();
_shardIdsToIndex = new Dictionary();
config.DisplayInitialLog = false;
_baseConfig = config;
@@ -72,28 +76,49 @@ namespace Discord.WebSocket
_totalShards = config.TotalShards.Value;
_shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
_shards = new DiscordSocketClient[_shardIds.Length];
- var masterIdentifySemaphore = new SemaphoreSlim(1, 1);
- SemaphoreSlim[] identifySemaphores = null;
- if (config.IdentifyMaxConcurrency > 1)
- {
- int maxSemaphores = (_shardIds.Length + config.IdentifyMaxConcurrency - 1) / config.IdentifyMaxConcurrency;
- identifySemaphores = new SemaphoreSlim[maxSemaphores];
- for (int i = 0; i < maxSemaphores; i++)
- identifySemaphores[i] = new SemaphoreSlim(0, config.IdentifyMaxConcurrency);
- }
+ _identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency];
+ for (int i = 0; i < config.IdentifyMaxConcurrency; i++)
+ _identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = config.Clone();
newConfig.ShardId = _shardIds[i];
- _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency == 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency);
+ _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
}
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
- null, null, 0, rateLimitPrecision: config.RateLimitPrecision);
+ rateLimitPrecision: config.RateLimitPrecision);
+
+ internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token)
+ {
+ int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency;
+ await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false);
+ }
+
+ internal void ReleaseIdentifyLock()
+ {
+ lock (_semaphoreResetLock)
+ {
+ if (_semaphoreResetTask == null)
+ _semaphoreResetTask = ResetSemaphoresAsync();
+ }
+ }
+
+ private async Task ResetSemaphoresAsync()
+ {
+ await Task.Delay(5000).ConfigureAwait(false);
+ lock (_semaphoreResetLock)
+ {
+ foreach (var semaphore in _identifySemaphores)
+ if (semaphore.CurrentCount == 0)
+ semaphore.Release();
+ _semaphoreResetTask = null;
+ }
+ }
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
@@ -105,22 +130,16 @@ namespace Discord.WebSocket
_shards = new DiscordSocketClient[_shardIds.Length];
int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency;
_baseConfig.IdentifyMaxConcurrency = maxConcurrency;
- var masterIdentifySemaphore = new SemaphoreSlim(1, 1);
- SemaphoreSlim[] identifySemaphores = null;
- if (maxConcurrency > 1)
- {
- int maxSemaphores = (_shardIds.Length + maxConcurrency - 1) / maxConcurrency;
- identifySemaphores = new SemaphoreSlim[maxSemaphores];
- for (int i = 0; i < maxSemaphores; i++)
- identifySemaphores[i] = new SemaphoreSlim(0, maxConcurrency);
- }
+ _identifySemaphores = new SemaphoreSlim[maxConcurrency];
+ for (int i = 0; i < maxConcurrency; i++)
+ _identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = _baseConfig.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.TotalShards = _totalShards;
- _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency == 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency);
+ _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
index 47a7def29..07ebc87ec 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
@@ -38,11 +38,10 @@ namespace Discord.API
public ConnectionState ConnectionState { get; private set; }
public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent,
- SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency,
string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null,
RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second,
bool useSystemClock = true)
- : base(restClientProvider, userAgent, new RequestQueue(identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock)
+ : base(restClientProvider, userAgent, defaultRetryMode, serializer, rateLimitPrecision, useSystemClock)
{
_gatewayUrl = url;
if (url != null)
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index 823c2d2fa..9f448c658 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -26,6 +26,7 @@ namespace Discord.WebSocket
{
private readonly ConcurrentQueue _largeGuilds;
private readonly JsonSerializer _serializer;
+ private readonly DiscordShardedClient _shardedClient;
private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue _heartbeatTimes;
private readonly ConnectionManager _connection;
@@ -118,10 +119,10 @@ namespace Discord.WebSocket
///
/// The configuration to be used with the client.
#pragma warning disable IDISP004
- public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config, new SemaphoreSlim(1, 1), null, 1), null) { }
- internal DiscordSocketClient(DiscordSocketConfig config, DiscordSocketClient parentClient, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) : this(config, CreateApiClient(config, identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), parentClient) { }
+ public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { }
+ internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { }
#pragma warning restore IDISP004
- private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordSocketClient parentClient)
+ private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient)
: base(config, client)
{
ShardId = config.ShardId ?? 0;
@@ -147,6 +148,7 @@ namespace Discord.WebSocket
_connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex);
_nextAudioId = 1;
+ _shardedClient = shardedClient;
_parentClient = parentClient;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
@@ -177,9 +179,8 @@ namespace Discord.WebSocket
_voiceRegions = ImmutableDictionary.Create();
_largeGuilds = new ConcurrentQueue();
}
- private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency)
- => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
- identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency, config.GatewayHost,
+ private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
+ => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost,
rateLimitPrecision: config.RateLimitPrecision);
///
internal override void Dispose(bool disposing)
@@ -228,28 +229,39 @@ namespace Discord.WebSocket
private async Task OnConnectingAsync()
{
- if (_sessionId == null)
- await ApiClient.RequestQueue.AcquireIdentifyTicket(_connection.CancelToken);
-
- await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
- await ApiClient.ConnectAsync().ConfigureAwait(false);
-
- if (_sessionId != null)
+ bool locked = false;
+ if (_shardedClient != null && _sessionId == null)
{
- await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
- await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
+ await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false);
+ locked = true;
}
- else
+ try
{
- await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
- await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
- }
+ await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
+ await ApiClient.ConnectAsync().ConfigureAwait(false);
- //Wait for READY
- await _connection.WaitAsync().ConfigureAwait(false);
+ if (_sessionId != null)
+ {
+ await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
+ await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
+ }
+ else
+ {
+ await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
+ await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
+ }
- await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
- await SendStatusAsync().ConfigureAwait(false);
+ //Wait for READY
+ await _connection.WaitAsync().ConfigureAwait(false);
+
+ await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
+ await SendStatusAsync().ConfigureAwait(false);
+ }
+ finally
+ {
+ if (locked)
+ _shardedClient.ReleaseIdentifyLock();
+ }
}
private async Task OnDisconnectingAsync(Exception ex)
{
diff --git a/src/Discord.Net.Webhook/DiscordWebhookClient.cs b/src/Discord.Net.Webhook/DiscordWebhookClient.cs
index c39d377c7..a6d4ef183 100644
--- a/src/Discord.Net.Webhook/DiscordWebhookClient.cs
+++ b/src/Discord.Net.Webhook/DiscordWebhookClient.cs
@@ -84,7 +84,7 @@ namespace Discord.Webhook
ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false);
}
private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config)
- => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent, null);
+ => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent);
/// Sends a message to the channel for this webhook.
/// Returns the ID of the created message.
public Task SendMessageAsync(string text = null, bool isTTS = false, IEnumerable