diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs index 52d7e0cd5..592ad7e92 100644 --- a/src/Discord.Net.Rest/DiscordRestApiClient.cs +++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs @@ -51,7 +51,7 @@ namespace Discord.API internal JsonSerializer Serializer => _serializer; /// Unknown OAuth token type. - public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RequestQueue requestQueue, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, + public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) { _restClientProvider = restClientProvider; @@ -61,7 +61,7 @@ namespace Discord.API RateLimitPrecision = rateLimitPrecision; UseSystemClock = useSystemClock; - RequestQueue = requestQueue ?? new RequestQueue(); + RequestQueue = new RequestQueue(); _stateLock = new SemaphoreSlim(1, 1); SetBaseUrl(DiscordConfig.APIUrl); diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index 65af43a99..48c40fdfa 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -31,7 +31,6 @@ namespace Discord.Rest private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent, - null, rateLimitPrecision: config.RateLimitPrecision, useSystemClock: config.UseSystemClock); diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index 488e2c5c8..2bf8e20b0 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -23,10 +23,6 @@ namespace Discord.Net.Queue private CancellationToken _requestCancelToken; //Parent token + Clear token private DateTimeOffset _waitUntil; - private readonly SemaphoreSlim _masterIdentifySemaphore; - private readonly SemaphoreSlim _identifySemaphore; - private readonly int _identifySemaphoreMaxConcurrency; - private Task _cleanupTask; public RequestQueue() @@ -43,14 +39,6 @@ namespace Discord.Net.Queue _cleanupTask = RunCleanup(); } - public RequestQueue(SemaphoreSlim masterIdentifySemaphore, SemaphoreSlim slaveIdentifySemaphore, int slaveIdentifySemaphoreMaxConcurrency) - : this () - { - _masterIdentifySemaphore = masterIdentifySemaphore; - _identifySemaphore = slaveIdentifySemaphore; - _identifySemaphoreMaxConcurrency = slaveIdentifySemaphoreMaxConcurrency; - } - public async Task SetCancelTokenAsync(CancellationToken cancelToken) { await _tokenLock.WaitAsync().ConfigureAwait(false); @@ -145,42 +133,6 @@ namespace Discord.Net.Queue var globalBucket = GetOrCreateBucket(options, globalRequest); await globalBucket.TriggerAsync(id, globalRequest); } - internal void ReleaseIdentifySemaphore(int id) - { - if (_masterIdentifySemaphore == null) - throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); - - while (_identifySemaphore?.Wait(0) == true) //exhaust all tickets before releasing master - { } - _masterIdentifySemaphore.Release(); -#if DEBUG_LIMITS - Debug.WriteLine($"[{id}] Released identify master semaphore"); -#endif - } - - public async Task AcquireIdentifyTicket(CancellationToken cancellationToken) - { - try - { - if (_masterIdentifySemaphore == null) - throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); - - if (_identifySemaphore == null) - await _masterIdentifySemaphore.WaitAsync(cancellationToken); - else - { - bool master; - while (!(master = _masterIdentifySemaphore.Wait(0)) && !_identifySemaphore.Wait(0)) //To not block the thread - await Task.Delay(100, cancellationToken); - if (master && _identifySemaphoreMaxConcurrency > 1) - _identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1); - } -#if DEBUG_LIMITS - Debug.WriteLine($"[{id}] Acquired identify ticket"); -#endif - } - catch(OperationCanceledException) { } - } private RequestBucket GetOrCreateBucket(RequestOptions options, IRequest request) { @@ -245,8 +197,6 @@ namespace Discord.Net.Queue _tokenLock?.Dispose(); _clearToken?.Dispose(); _requestCancelTokenSource?.Dispose(); - _masterIdentifySemaphore?.Dispose(); - _identifySemaphore?.Dispose(); } } } diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs index ece18b819..3fb45e55d 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs @@ -457,8 +457,6 @@ namespace Discord.Net.Queue #if DEBUG_LIMITS Debug.WriteLine($"[{id}] * Reset *"); #endif - if (request is WebSocketRequest webSocketRequest && webSocketRequest.Options.BucketId == GatewayBucket.Get(GatewayBucketType.Identify).Id) - _queue.ReleaseIdentifySemaphore(id); _semaphore = WindowCount; _resetTick = null; return; diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index cceee2557..548bb75bf 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.IO; -using System.Threading; using System.Threading.Tasks; using Discord.API; using Discord.Rest; @@ -80,9 +79,8 @@ namespace Discord.WebSocket internal BaseSocketClient(DiscordSocketConfig config, DiscordRestApiClient client) : base(config, client) => BaseConfig = config; - private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) + private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, - identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency, rateLimitPrecision: config.RateLimitPrecision, useSystemClock: config.UseSystemClock); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 6c98d9e43..a2c89d4e5 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -17,6 +17,9 @@ namespace Discord.WebSocket private int[] _shardIds; private DiscordSocketClient[] _shards; private int _totalShards; + private SemaphoreSlim[] _identifySemaphores; + private object _semaphoreResetLock; + private Task _semaphoreResetTask; private bool _isDisposed; @@ -61,6 +64,7 @@ namespace Discord.WebSocket if (ids != null && config.TotalShards == null) throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified."); + _semaphoreResetLock = new object(); _shardIdsToIndex = new Dictionary(); config.DisplayInitialLog = false; _baseConfig = config; @@ -72,28 +76,49 @@ namespace Discord.WebSocket _totalShards = config.TotalShards.Value; _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray(); _shards = new DiscordSocketClient[_shardIds.Length]; - var masterIdentifySemaphore = new SemaphoreSlim(1, 1); - SemaphoreSlim[] identifySemaphores = null; - if (config.IdentifyMaxConcurrency > 1) - { - int maxSemaphores = (_shardIds.Length + config.IdentifyMaxConcurrency - 1) / config.IdentifyMaxConcurrency; - identifySemaphores = new SemaphoreSlim[maxSemaphores]; - for (int i = 0; i < maxSemaphores; i++) - identifySemaphores[i] = new SemaphoreSlim(0, config.IdentifyMaxConcurrency); - } + _identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency]; + for (int i = 0; i < config.IdentifyMaxConcurrency; i++) + _identifySemaphores[i] = new SemaphoreSlim(1, 1); for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency == 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency); + _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } } } private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, - null, null, 0, rateLimitPrecision: config.RateLimitPrecision); + rateLimitPrecision: config.RateLimitPrecision); + + internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token) + { + int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency; + await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false); + } + + internal void ReleaseIdentifyLock() + { + lock (_semaphoreResetLock) + { + if (_semaphoreResetTask == null) + _semaphoreResetTask = ResetSemaphoresAsync(); + } + } + + private async Task ResetSemaphoresAsync() + { + await Task.Delay(5000).ConfigureAwait(false); + lock (_semaphoreResetLock) + { + foreach (var semaphore in _identifySemaphores) + if (semaphore.CurrentCount == 0) + semaphore.Release(); + _semaphoreResetTask = null; + } + } internal override async Task OnLoginAsync(TokenType tokenType, string token) { @@ -105,22 +130,16 @@ namespace Discord.WebSocket _shards = new DiscordSocketClient[_shardIds.Length]; int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; _baseConfig.IdentifyMaxConcurrency = maxConcurrency; - var masterIdentifySemaphore = new SemaphoreSlim(1, 1); - SemaphoreSlim[] identifySemaphores = null; - if (maxConcurrency > 1) - { - int maxSemaphores = (_shardIds.Length + maxConcurrency - 1) / maxConcurrency; - identifySemaphores = new SemaphoreSlim[maxSemaphores]; - for (int i = 0; i < maxSemaphores; i++) - identifySemaphores[i] = new SemaphoreSlim(0, maxConcurrency); - } + _identifySemaphores = new SemaphoreSlim[maxConcurrency]; + for (int i = 0; i < maxConcurrency; i++) + _identifySemaphores[i] = new SemaphoreSlim(1, 1); for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency == 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency); + _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 47a7def29..07ebc87ec 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -38,11 +38,10 @@ namespace Discord.API public ConnectionState ConnectionState { get; private set; } public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, - SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency, string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) - : base(restClientProvider, userAgent, new RequestQueue(identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) + : base(restClientProvider, userAgent, defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) { _gatewayUrl = url; if (url != null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 823c2d2fa..9f448c658 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -26,6 +26,7 @@ namespace Discord.WebSocket { private readonly ConcurrentQueue _largeGuilds; private readonly JsonSerializer _serializer; + private readonly DiscordShardedClient _shardedClient; private readonly DiscordSocketClient _parentClient; private readonly ConcurrentQueue _heartbeatTimes; private readonly ConnectionManager _connection; @@ -118,10 +119,10 @@ namespace Discord.WebSocket /// /// The configuration to be used with the client. #pragma warning disable IDISP004 - public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config, new SemaphoreSlim(1, 1), null, 1), null) { } - internal DiscordSocketClient(DiscordSocketConfig config, DiscordSocketClient parentClient, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) : this(config, CreateApiClient(config, identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), parentClient) { } + public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } + internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { } #pragma warning restore IDISP004 - private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordSocketClient parentClient) + private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : base(config, client) { ShardId = config.ShardId ?? 0; @@ -147,6 +148,7 @@ namespace Discord.WebSocket _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); _nextAudioId = 1; + _shardedClient = shardedClient; _parentClient = parentClient; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; @@ -177,9 +179,8 @@ namespace Discord.WebSocket _voiceRegions = ImmutableDictionary.Create(); _largeGuilds = new ConcurrentQueue(); } - private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) - => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, - identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency, config.GatewayHost, + private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) + => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost, rateLimitPrecision: config.RateLimitPrecision); /// internal override void Dispose(bool disposing) @@ -228,28 +229,39 @@ namespace Discord.WebSocket private async Task OnConnectingAsync() { - if (_sessionId == null) - await ApiClient.RequestQueue.AcquireIdentifyTicket(_connection.CancelToken); - - await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); - await ApiClient.ConnectAsync().ConfigureAwait(false); - - if (_sessionId != null) + bool locked = false; + if (_shardedClient != null && _sessionId == null) { - await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); - await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); + await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false); + locked = true; } - else + try { - await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); - } + await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); + await ApiClient.ConnectAsync().ConfigureAwait(false); - //Wait for READY - await _connection.WaitAsync().ConfigureAwait(false); + if (_sessionId != null) + { + await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); + await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); + } + else + { + await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); + } - await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); - await SendStatusAsync().ConfigureAwait(false); + //Wait for READY + await _connection.WaitAsync().ConfigureAwait(false); + + await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); + await SendStatusAsync().ConfigureAwait(false); + } + finally + { + if (locked) + _shardedClient.ReleaseIdentifyLock(); + } } private async Task OnDisconnectingAsync(Exception ex) { diff --git a/src/Discord.Net.Webhook/DiscordWebhookClient.cs b/src/Discord.Net.Webhook/DiscordWebhookClient.cs index c39d377c7..a6d4ef183 100644 --- a/src/Discord.Net.Webhook/DiscordWebhookClient.cs +++ b/src/Discord.Net.Webhook/DiscordWebhookClient.cs @@ -84,7 +84,7 @@ namespace Discord.Webhook ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false); } private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) - => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent, null); + => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent); /// Sends a message to the channel for this webhook. /// Returns the ID of the created message. public Task SendMessageAsync(string text = null, bool isTTS = false, IEnumerable embeds = null,