From ec673e186317cd7449eabf2adf29dc78a5e22b90 Mon Sep 17 00:00:00 2001 From: Paulo Date: Wed, 18 Nov 2020 23:40:09 -0300 Subject: [PATCH] feature: Implement gateway ratelimit (#1537) * Implement gateway ratelimit * Remove unused code * Share WebSocketRequestQueue between clients * Add global limit and a way to change gateway limits * Refactoring variable to fit lib standards * Update xml docs * Update xml docs * Move warning to remarks * Remove specific RequestQueue for WebSocket and other changes The only account limit is for identify that is dealt in a different way (exclusive semaphore), so websocket queues can be shared with REST and don't need to be shared between clients anymore. Also added the ratelimit for presence updates. * Add summary to IdentifySemaphoreName * Fix spacing * Add max_concurrency and other fixes - Add session_start_limit to GetBotGatewayResponse - Add GetBotGatewayAsync to IDiscordClient - Add master/slave semaphores to enable concurrency - Not store semaphore name as static - Clone GatewayLimits when cloning the Config * Add missing RequestQueue parameter and wrong nullable * Add RequeueQueue paramater to Webhook * Better xml documentation * Remove GatewayLimits class and other changes - Remove GatewayLimits - Transfer a few properties to DiscordSocketConfig - Remove unnecessary usings * Remove unnecessary using and wording * Remove more unnecessary usings * Change named Semaphores to SemaphoreSlim * Remove unused using * Update branch * Fix merge conflicts and update to new ratelimit * Fixing merge, ignore limit for heartbeat, and dispose * Missed one place and better xml docs. * Wait identify before opening the connection * Only request identify ticket when needed * Move identify control to sharded client * Better description for IdentifyMaxConcurrency * Add lock to InvalidSession --- .../Entities/Gateway/BotGateway.cs | 22 ++++ .../Entities/Gateway/SessionStartLimit.cs | 38 ++++++ src/Discord.Net.Core/IDiscordClient.cs | 10 ++ src/Discord.Net.Core/RequestOptions.cs | 1 + .../API/Common/SessionStartLimit.cs | 16 +++ .../API/Rest/GetBotGatewayResponse.cs | 4 +- src/Discord.Net.Rest/BaseDiscordClient.cs | 4 + src/Discord.Net.Rest/ClientHelper.cs | 17 +++ src/Discord.Net.Rest/DiscordRestClient.cs | 8 +- .../Net/Queue/GatewayBucket.cs | 53 ++++++++ .../Net/Queue/RequestQueue.cs | 38 +++++- .../Net/Queue/RequestQueueBucket.cs | 118 ++++++++++++++++-- .../Net/Queue/Requests/WebSocketRequest.cs | 6 +- .../DiscordShardedClient.cs | 50 ++++++-- .../DiscordSocketApiClient.cs | 11 +- .../DiscordSocketClient.cs | 33 +++-- .../DiscordSocketConfig.cs | 8 ++ 17 files changed, 397 insertions(+), 40 deletions(-) create mode 100644 src/Discord.Net.Core/Entities/Gateway/BotGateway.cs create mode 100644 src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs create mode 100644 src/Discord.Net.Rest/API/Common/SessionStartLimit.cs create mode 100644 src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs diff --git a/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs new file mode 100644 index 000000000..c9be0ac1f --- /dev/null +++ b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs @@ -0,0 +1,22 @@ +namespace Discord +{ + /// + /// Stores the gateway information related to the current bot. + /// + public class BotGateway + { + /// + /// Gets the WSS URL that can be used for connecting to the gateway. + /// + public string Url { get; internal set; } + /// + /// Gets the recommended number of shards to use when connecting. + /// + public int Shards { get; internal set; } + /// + /// Gets the that contains the information + /// about the current session start limit. + /// + public SessionStartLimit SessionStartLimit { get; internal set; } + } +} diff --git a/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs new file mode 100644 index 000000000..74ae96af1 --- /dev/null +++ b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs @@ -0,0 +1,38 @@ +namespace Discord +{ + /// + /// Stores the information related to the gateway identify request. + /// + public class SessionStartLimit + { + /// + /// Gets the total number of session starts the current user is allowed. + /// + /// + /// The maximum amount of session starts the current user is allowed. + /// + public int Total { get; internal set; } + /// + /// Gets the remaining number of session starts the current user is allowed. + /// + /// + /// The remaining amount of session starts the current user is allowed. + /// + public int Remaining { get; internal set; } + /// + /// Gets the number of milliseconds after which the limit resets. + /// + /// + /// The milliseconds until the limit resets back to the . + /// + public int ResetAfter { get; internal set; } + /// + /// Gets the maximum concurrent identify requests in a time window. + /// + /// + /// The maximum concurrent identify requests in a time window, + /// limited to the same rate limit key. + /// + public int MaxConcurrency { get; internal set; } + } +} diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs index f972cd71d..d7d6d2856 100644 --- a/src/Discord.Net.Core/IDiscordClient.cs +++ b/src/Discord.Net.Core/IDiscordClient.cs @@ -274,5 +274,15 @@ namespace Discord /// that represents the number of shards that should be used with this account. /// Task GetRecommendedShardCountAsync(RequestOptions options = null); + + /// + /// Gets the gateway information related to the bot. + /// + /// The options to be used when sending the request. + /// + /// A task that represents the asynchronous get operation. The task result contains a + /// that represents the gateway information related to the bot. + /// + Task GetBotGatewayAsync(RequestOptions options = null); } } diff --git a/src/Discord.Net.Core/RequestOptions.cs b/src/Discord.Net.Core/RequestOptions.cs index ad0a4e33f..dbb240273 100644 --- a/src/Discord.Net.Core/RequestOptions.cs +++ b/src/Discord.Net.Core/RequestOptions.cs @@ -61,6 +61,7 @@ namespace Discord internal BucketId BucketId { get; set; } internal bool IsClientBucket { get; set; } internal bool IsReactionBucket { get; set; } + internal bool IsGatewayBucket { get; set; } internal static RequestOptions CreateOrClone(RequestOptions options) { diff --git a/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs new file mode 100644 index 000000000..29d5ddf85 --- /dev/null +++ b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; + +namespace Discord.API.Rest +{ + internal class SessionStartLimit + { + [JsonProperty("total")] + public int Total { get; set; } + [JsonProperty("remaining")] + public int Remaining { get; set; } + [JsonProperty("reset_after")] + public int ResetAfter { get; set; } + [JsonProperty("max_concurrency")] + public int MaxConcurrency { get; set; } + } +} diff --git a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs index 111fcf3db..d3285051b 100644 --- a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs +++ b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs @@ -1,4 +1,4 @@ -#pragma warning disable CS1591 +#pragma warning disable CS1591 using Newtonsoft.Json; namespace Discord.API.Rest @@ -9,5 +9,7 @@ namespace Discord.API.Rest public string Url { get; set; } [JsonProperty("shards")] public int Shards { get; set; } + [JsonProperty("session_start_limit")] + public SessionStartLimit SessionStartLimit { get; set; } } } diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs index b641fa1c3..68589a4f1 100644 --- a/src/Discord.Net.Rest/BaseDiscordClient.cs +++ b/src/Discord.Net.Rest/BaseDiscordClient.cs @@ -152,6 +152,10 @@ namespace Discord.Rest public Task GetRecommendedShardCountAsync(RequestOptions options = null) => ClientHelper.GetRecommendShardCountAsync(this, options); + /// + public Task GetBotGatewayAsync(RequestOptions options = null) + => ClientHelper.GetBotGatewayAsync(this, options); + //IDiscordClient /// ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index 6ebdbcacb..8910e999a 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -184,5 +184,22 @@ namespace Discord.Rest var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); return response.Shards; } + + public static async Task GetBotGatewayAsync(BaseDiscordClient client, RequestOptions options) + { + var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); + return new BotGateway + { + Url = response.Url, + Shards = response.Shards, + SessionStartLimit = new SessionStartLimit + { + Total = response.SessionStartLimit.Total, + Remaining = response.SessionStartLimit.Remaining, + ResetAfter = response.SessionStartLimit.ResetAfter, + MaxConcurrency = response.SessionStartLimit.MaxConcurrency + } + }; + } } } diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index bef4e6b2a..48c40fdfa 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -29,10 +29,10 @@ namespace Discord.Rest internal DiscordRestClient(DiscordRestConfig config, API.DiscordRestApiClient api) : base(config, api) { } private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) - => new API.DiscordRestApiClient(config.RestClientProvider, - DiscordRestConfig.UserAgent, - rateLimitPrecision: config.RateLimitPrecision, - useSystemClock: config.UseSystemClock); + => new API.DiscordRestApiClient(config.RestClientProvider, + DiscordRestConfig.UserAgent, + rateLimitPrecision: config.RateLimitPrecision, + useSystemClock: config.UseSystemClock); internal override void Dispose(bool disposing) { diff --git a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs new file mode 100644 index 000000000..aa849018a --- /dev/null +++ b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs @@ -0,0 +1,53 @@ +using System.Collections.Immutable; + +namespace Discord.Net.Queue +{ + public enum GatewayBucketType + { + Unbucketed = 0, + Identify = 1, + PresenceUpdate = 2, + } + internal struct GatewayBucket + { + private static readonly ImmutableDictionary DefsByType; + private static readonly ImmutableDictionary DefsById; + + static GatewayBucket() + { + var buckets = new[] + { + // Limit is 120/60s, but 3 will be reserved for heartbeats (2 for possible heartbeats in the same timeframe and a possible failure) + new GatewayBucket(GatewayBucketType.Unbucketed, BucketId.Create(null, "", null), 117, 60), + new GatewayBucket(GatewayBucketType.Identify, BucketId.Create(null, "", null), 1, 5), + new GatewayBucket(GatewayBucketType.PresenceUpdate, BucketId.Create(null, "", null), 5, 60), + }; + + var builder = ImmutableDictionary.CreateBuilder(); + foreach (var bucket in buckets) + builder.Add(bucket.Type, bucket); + DefsByType = builder.ToImmutable(); + + var builder2 = ImmutableDictionary.CreateBuilder(); + foreach (var bucket in buckets) + builder2.Add(bucket.Id, bucket); + DefsById = builder2.ToImmutable(); + } + + public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; + public static GatewayBucket Get(BucketId id) => DefsById[id]; + + public GatewayBucketType Type { get; } + public BucketId Id { get; } + public int WindowCount { get; set; } + public int WindowSeconds { get; set; } + + public GatewayBucket(GatewayBucketType type, BucketId id, int count, int seconds) + { + Type = type; + Id = id; + WindowCount = count; + WindowSeconds = seconds; + } + } +} diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index 127a48cf3..2bf8e20b0 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -89,9 +89,18 @@ namespace Discord.Net.Queue } public async Task SendAsync(WebSocketRequest request) { - //TODO: Re-impl websocket buckets - request.CancelToken = _requestCancelToken; - await request.SendAsync().ConfigureAwait(false); + CancellationTokenSource createdTokenSource = null; + if (request.Options.CancelToken.CanBeCanceled) + { + createdTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_requestCancelToken, request.Options.CancelToken); + request.Options.CancelToken = createdTokenSource.Token; + } + else + request.Options.CancelToken = _requestCancelToken; + + var bucket = GetOrCreateBucket(request.Options, request); + await bucket.SendAsync(request).ConfigureAwait(false); + createdTokenSource?.Dispose(); } internal async Task EnterGlobalAsync(int id, RestRequest request) @@ -109,8 +118,23 @@ namespace Discord.Net.Queue { _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0)); } + internal async Task EnterGlobalAsync(int id, WebSocketRequest request) + { + //If this is a global request (unbucketed), it'll be dealt in EnterAsync + var requestBucket = GatewayBucket.Get(request.Options.BucketId); + if (requestBucket.Type == GatewayBucketType.Unbucketed) + return; + + //It's not a global request, so need to remove one from global (per-session) + var globalBucketType = GatewayBucket.Get(GatewayBucketType.Unbucketed); + var options = RequestOptions.CreateOrClone(request.Options); + options.BucketId = globalBucketType.Id; + var globalRequest = new WebSocketRequest(null, null, false, false, options); + var globalBucket = GetOrCreateBucket(options, globalRequest); + await globalBucket.TriggerAsync(id, globalRequest); + } - private RequestBucket GetOrCreateBucket(RequestOptions options, RestRequest request) + private RequestBucket GetOrCreateBucket(RequestOptions options, IRequest request) { var bucketId = options.BucketId; object obj = _buckets.GetOrAdd(bucketId, x => new RequestBucket(this, request, x)); @@ -137,6 +161,12 @@ namespace Discord.Net.Queue return (null, null); } + public void ClearGatewayBuckets() + { + foreach (var gwBucket in (GatewayBucketType[])Enum.GetValues(typeof(GatewayBucketType))) + _buckets.TryRemove(GatewayBucket.Get(gwBucket).Id, out _); + } + private async Task RunCleanup() { try diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs index edd55f158..3fb45e55d 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs @@ -25,7 +25,7 @@ namespace Discord.Net.Queue public int WindowCount { get; private set; } public DateTimeOffset LastAttemptAt { get; private set; } - public RequestBucket(RequestQueue queue, RestRequest request, BucketId id) + public RequestBucket(RequestQueue queue, IRequest request, BucketId id) { _queue = queue; Id = id; @@ -33,7 +33,9 @@ namespace Discord.Net.Queue _lock = new object(); if (request.Options.IsClientBucket) - WindowCount = ClientBucket.Get(Id).WindowCount; + WindowCount = ClientBucket.Get(request.Options.BucketId).WindowCount; + else if (request.Options.IsGatewayBucket) + WindowCount = GatewayBucket.Get(request.Options.BucketId).WindowCount; else WindowCount = 1; //Only allow one request until we get a header back _semaphore = WindowCount; @@ -154,8 +156,68 @@ namespace Discord.Net.Queue } } } + public async Task SendAsync(WebSocketRequest request) + { + int id = Interlocked.Increment(ref nextId); +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Start"); +#endif + LastAttemptAt = DateTimeOffset.UtcNow; + while (true) + { + await _queue.EnterGlobalAsync(id, request).ConfigureAwait(false); + await EnterAsync(id, request).ConfigureAwait(false); - private async Task EnterAsync(int id, RestRequest request) +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Sending..."); +#endif + try + { + await request.SendAsync().ConfigureAwait(false); + return; + } + catch (TimeoutException) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Timeout"); +#endif + if ((request.Options.RetryMode & RetryMode.RetryTimeouts) == 0) + throw; + + await Task.Delay(500).ConfigureAwait(false); + continue; //Retry + } + /*catch (Exception) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Error"); +#endif + if ((request.Options.RetryMode & RetryMode.RetryErrors) == 0) + throw; + + await Task.Delay(500); + continue; //Retry + }*/ + finally + { + UpdateRateLimit(id, request, default(RateLimitInfo), false); +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Stop"); +#endif + } + } + } + + internal async Task TriggerAsync(int id, IRequest request) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Trigger Bucket"); +#endif + await EnterAsync(id, request).ConfigureAwait(false); + UpdateRateLimit(id, request, default(RateLimitInfo), false); + } + + private async Task EnterAsync(int id, IRequest request) { int windowCount; DateTimeOffset? resetAt; @@ -186,8 +248,31 @@ namespace Discord.Net.Queue { if (!isRateLimited) { + bool ignoreRatelimit = false; isRateLimited = true; - await _queue.RaiseRateLimitTriggered(Id, null, $"{request.Method} {request.Endpoint}").ConfigureAwait(false); + switch (request) + { + case RestRequest restRequest: + await _queue.RaiseRateLimitTriggered(Id, null, $"{restRequest.Method} {restRequest.Endpoint}").ConfigureAwait(false); + break; + case WebSocketRequest webSocketRequest: + if (webSocketRequest.IgnoreLimit) + { + ignoreRatelimit = true; + break; + } + await _queue.RaiseRateLimitTriggered(Id, null, Id.Endpoint).ConfigureAwait(false); + break; + default: + throw new InvalidOperationException("Unknown request type"); + } + if (ignoreRatelimit) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Ignoring ratelimit"); +#endif + break; + } } ThrowRetryLimit(request); @@ -223,7 +308,7 @@ namespace Discord.Net.Queue } } - private void UpdateRateLimit(int id, RestRequest request, RateLimitInfo info, bool is429, bool redirected = false) + private void UpdateRateLimit(int id, IRequest request, RateLimitInfo info, bool is429, bool redirected = false) { if (WindowCount == 0) return; @@ -316,6 +401,23 @@ namespace Discord.Net.Queue Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(Id).WindowSeconds * 1000} ms)"); #endif } + else if (request.Options.IsGatewayBucket && request.Options.BucketId != null) + { + resetTick = DateTimeOffset.UtcNow.AddSeconds(GatewayBucket.Get(request.Options.BucketId).WindowSeconds); +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Gateway Bucket ({GatewayBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)"); +#endif + if (!hasQueuedReset) + { + _resetTick = resetTick; + LastAttemptAt = resetTick.Value; +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).Value.TotalMilliseconds)} ms"); +#endif + var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request); + } + return; + } if (resetTick == null) { @@ -336,12 +438,12 @@ namespace Discord.Net.Queue if (!hasQueuedReset) { - var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds)); + var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request); } } } } - private async Task QueueReset(int id, int millis) + private async Task QueueReset(int id, int millis, IRequest request) { while (true) { @@ -363,7 +465,7 @@ namespace Discord.Net.Queue } } - private void ThrowRetryLimit(RestRequest request) + private void ThrowRetryLimit(IRequest request) { if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0) throw new RateLimitedException(request); diff --git a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs index 81eb40b31..ebebd7bef 100644 --- a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs +++ b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs @@ -9,22 +9,22 @@ namespace Discord.Net.Queue public class WebSocketRequest : IRequest { public IWebSocketClient Client { get; } - public string BucketId { get; } public byte[] Data { get; } public bool IsText { get; } + public bool IgnoreLimit { get; } public DateTimeOffset? TimeoutAt { get; } public TaskCompletionSource Promise { get; } public RequestOptions Options { get; } public CancellationToken CancelToken { get; internal set; } - public WebSocketRequest(IWebSocketClient client, string bucketId, byte[] data, bool isText, RequestOptions options) + public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, bool ignoreLimit, RequestOptions options) { Preconditions.NotNull(options, nameof(options)); Client = client; - BucketId = bucketId; Data = data; IsText = isText; + IgnoreLimit = ignoreLimit; Options = options; TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; Promise = new TaskCompletionSource(); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index a8780a7b0..a2c89d4e5 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -12,12 +12,14 @@ namespace Discord.WebSocket public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient { private readonly DiscordSocketConfig _baseConfig; - private readonly SemaphoreSlim _connectionGroupLock; private readonly Dictionary _shardIdsToIndex; private readonly bool _automaticShards; private int[] _shardIds; private DiscordSocketClient[] _shards; private int _totalShards; + private SemaphoreSlim[] _identifySemaphores; + private object _semaphoreResetLock; + private Task _semaphoreResetTask; private bool _isDisposed; @@ -62,10 +64,10 @@ namespace Discord.WebSocket if (ids != null && config.TotalShards == null) throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified."); + _semaphoreResetLock = new object(); _shardIdsToIndex = new Dictionary(); config.DisplayInitialLog = false; _baseConfig = config; - _connectionGroupLock = new SemaphoreSlim(1, 1); if (config.TotalShards == null) _automaticShards = true; @@ -74,12 +76,15 @@ namespace Discord.WebSocket _totalShards = config.TotalShards.Value; _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray(); _shards = new DiscordSocketClient[_shardIds.Length]; + _identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency]; + for (int i = 0; i < config.IdentifyMaxConcurrency; i++) + _identifySemaphores[i] = new SemaphoreSlim(1, 1); for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); + _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } } @@ -88,21 +93,53 @@ namespace Discord.WebSocket => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, rateLimitPrecision: config.RateLimitPrecision); + internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token) + { + int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency; + await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false); + } + + internal void ReleaseIdentifyLock() + { + lock (_semaphoreResetLock) + { + if (_semaphoreResetTask == null) + _semaphoreResetTask = ResetSemaphoresAsync(); + } + } + + private async Task ResetSemaphoresAsync() + { + await Task.Delay(5000).ConfigureAwait(false); + lock (_semaphoreResetLock) + { + foreach (var semaphore in _identifySemaphores) + if (semaphore.CurrentCount == 0) + semaphore.Release(); + _semaphoreResetTask = null; + } + } + internal override async Task OnLoginAsync(TokenType tokenType, string token) { if (_automaticShards) { - var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false); - _shardIds = Enumerable.Range(0, shardCount).ToArray(); + var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); + _shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); _totalShards = _shardIds.Length; _shards = new DiscordSocketClient[_shardIds.Length]; + int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; + _baseConfig.IdentifyMaxConcurrency = maxConcurrency; + _identifySemaphores = new SemaphoreSlim[maxConcurrency]; + for (int i = 0; i < maxConcurrency; i++) + _identifySemaphores[i] = new SemaphoreSlim(1, 1); for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); + _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } } @@ -398,7 +435,6 @@ namespace Discord.WebSocket foreach (var client in _shards) client?.Dispose(); } - _connectionGroupLock?.Dispose(); } _isDisposed = true; diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 1b21bd666..07ebc87ec 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -132,6 +132,8 @@ namespace Discord.API if (WebSocketClient == null) throw new NotSupportedException("This client is not configured with WebSocket support."); + RequestQueue.ClearGatewayBuckets(); + //Re-create streams to reset the zlib state _compressed?.Dispose(); _decompressor?.Dispose(); @@ -205,7 +207,11 @@ namespace Discord.API payload = new SocketFrame { Operation = (int)opCode, Payload = payload }; if (payload != null) bytes = Encoding.UTF8.GetBytes(SerializeJson(payload)); - await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false); + + options.IsGatewayBucket = true; + if (options.BucketId == null) + options.BucketId = GatewayBucket.Get(GatewayBucketType.Unbucketed).Id; + await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, opCode == GatewayOpCode.Heartbeat, options)).ConfigureAwait(false); await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); } @@ -225,6 +231,8 @@ namespace Discord.API if (totalShards > 1) msg.ShardingParams = new int[] { shardID, totalShards }; + options.BucketId = GatewayBucket.Get(GatewayBucketType.Identify).Id; + if (gatewayIntents.HasValue) msg.Intents = (int)gatewayIntents.Value; else @@ -258,6 +266,7 @@ namespace Discord.API IsAFK = isAFK, Game = game }; + options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id; await SendGatewayAsync(GatewayOpCode.StatusUpdate, args, options: options).ConfigureAwait(false); } public async Task SendRequestMembersAsync(IEnumerable guildIds, RequestOptions options = null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index dfdad99fc..d53387afc 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -26,7 +26,7 @@ namespace Discord.WebSocket { private readonly ConcurrentQueue _largeGuilds; private readonly JsonSerializer _serializer; - private readonly SemaphoreSlim _connectionGroupLock; + private readonly DiscordShardedClient _shardedClient; private readonly DiscordSocketClient _parentClient; private readonly ConcurrentQueue _heartbeatTimes; private readonly ConnectionManager _connection; @@ -120,9 +120,9 @@ namespace Discord.WebSocket /// The configuration to be used with the client. #pragma warning disable IDISP004 public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } - internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { } + internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { } #pragma warning restore IDISP004 - private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient) + private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : base(config, client) { ShardId = config.ShardId ?? 0; @@ -148,7 +148,7 @@ namespace Discord.WebSocket _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); _nextAudioId = 1; - _connectionGroupLock = groupLock; + _shardedClient = shardedClient; _parentClient = parentClient; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; @@ -229,8 +229,12 @@ namespace Discord.WebSocket private async Task OnConnectingAsync() { - if (_connectionGroupLock != null) - await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false); + bool locked = false; + if (_shardedClient != null && _sessionId == null) + { + await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false); + locked = true; + } try { await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); @@ -255,11 +259,8 @@ namespace Discord.WebSocket } finally { - if (_connectionGroupLock != null) - { - await Task.Delay(5000).ConfigureAwait(false); - _connectionGroupLock.Release(); - } + if (locked) + _shardedClient.ReleaseIdentifyLock(); } } private async Task OnDisconnectingAsync(Exception ex) @@ -519,7 +520,15 @@ namespace Discord.WebSocket _sessionId = null; _lastSeq = 0; - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); + await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false); + try + { + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); + } + finally + { + _shardedClient.ReleaseIdentifyLock(); + } } break; case GatewayOpCode.Reconnect: diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 0e8fbe73f..6b0c5ebc4 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -126,6 +126,14 @@ namespace Discord.WebSocket public bool GuildSubscriptions { get; set; } = true; /// + /// Gets or sets the maximum identify concurrency. + /// + /// + /// This information is provided by Discord. + /// It is only used when using a and auto-sharding is disabled. + /// + public int IdentifyMaxConcurrency { get; set; } = 1; + /// Gets or sets the maximum wait time in milliseconds between GUILD_AVAILABLE events before firing READY. /// /// If zero, READY will fire as soon as it is received and all guilds will be unavailable.