* Implement gateway ratelimit * Remove unused code * Share WebSocketRequestQueue between clients * Add global limit and a way to change gateway limits * Refactoring variable to fit lib standards * Update xml docs * Update xml docs * Move warning to remarks * Remove specific RequestQueue for WebSocket and other changes The only account limit is for identify that is dealt in a different way (exclusive semaphore), so websocket queues can be shared with REST and don't need to be shared between clients anymore. Also added the ratelimit for presence updates. * Add summary to IdentifySemaphoreName * Fix spacing * Add max_concurrency and other fixes - Add session_start_limit to GetBotGatewayResponse - Add GetBotGatewayAsync to IDiscordClient - Add master/slave semaphores to enable concurrency - Not store semaphore name as static - Clone GatewayLimits when cloning the Config * Add missing RequestQueue parameter and wrong nullable * Add RequeueQueue paramater to Webhook * Better xml documentation * Remove GatewayLimits class and other changes - Remove GatewayLimits - Transfer a few properties to DiscordSocketConfig - Remove unnecessary usings * Remove unnecessary using and wording * Remove more unnecessary usings * Change named Semaphores to SemaphoreSlim * Remove unused using * Update branch * Fix merge conflicts and update to new ratelimit * Fixing merge, ignore limit for heartbeat, and dispose * Missed one place and better xml docs. * Wait identify before opening the connection * Only request identify ticket when needed * Move identify control to sharded client * Better description for IdentifyMaxConcurrency * Add lock to InvalidSessiontags/2.3.0
| @@ -0,0 +1,22 @@ | |||||
| namespace Discord | |||||
| { | |||||
| /// <summary> | |||||
| /// Stores the gateway information related to the current bot. | |||||
| /// </summary> | |||||
| public class BotGateway | |||||
| { | |||||
| /// <summary> | |||||
| /// Gets the WSS URL that can be used for connecting to the gateway. | |||||
| /// </summary> | |||||
| public string Url { get; internal set; } | |||||
| /// <summary> | |||||
| /// Gets the recommended number of shards to use when connecting. | |||||
| /// </summary> | |||||
| public int Shards { get; internal set; } | |||||
| /// <summary> | |||||
| /// Gets the <see cref="SessionStartLimit"/> that contains the information | |||||
| /// about the current session start limit. | |||||
| /// </summary> | |||||
| public SessionStartLimit SessionStartLimit { get; internal set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,38 @@ | |||||
| namespace Discord | |||||
| { | |||||
| /// <summary> | |||||
| /// Stores the information related to the gateway identify request. | |||||
| /// </summary> | |||||
| public class SessionStartLimit | |||||
| { | |||||
| /// <summary> | |||||
| /// Gets the total number of session starts the current user is allowed. | |||||
| /// </summary> | |||||
| /// <returns> | |||||
| /// The maximum amount of session starts the current user is allowed. | |||||
| /// </returns> | |||||
| public int Total { get; internal set; } | |||||
| /// <summary> | |||||
| /// Gets the remaining number of session starts the current user is allowed. | |||||
| /// </summary> | |||||
| /// <returns> | |||||
| /// The remaining amount of session starts the current user is allowed. | |||||
| /// </returns> | |||||
| public int Remaining { get; internal set; } | |||||
| /// <summary> | |||||
| /// Gets the number of milliseconds after which the limit resets. | |||||
| /// </summary> | |||||
| /// <returns> | |||||
| /// The milliseconds until the limit resets back to the <see cref="Total"/>. | |||||
| /// </returns> | |||||
| public int ResetAfter { get; internal set; } | |||||
| /// <summary> | |||||
| /// Gets the maximum concurrent identify requests in a time window. | |||||
| /// </summary> | |||||
| /// <returns> | |||||
| /// The maximum concurrent identify requests in a time window, | |||||
| /// limited to the same rate limit key. | |||||
| /// </returns> | |||||
| public int MaxConcurrency { get; internal set; } | |||||
| } | |||||
| } | |||||
| @@ -274,5 +274,15 @@ namespace Discord | |||||
| /// that represents the number of shards that should be used with this account. | /// that represents the number of shards that should be used with this account. | ||||
| /// </returns> | /// </returns> | ||||
| Task<int> GetRecommendedShardCountAsync(RequestOptions options = null); | Task<int> GetRecommendedShardCountAsync(RequestOptions options = null); | ||||
| /// <summary> | |||||
| /// Gets the gateway information related to the bot. | |||||
| /// </summary> | |||||
| /// <param name="options">The options to be used when sending the request.</param> | |||||
| /// <returns> | |||||
| /// A task that represents the asynchronous get operation. The task result contains a <see cref="BotGateway"/> | |||||
| /// that represents the gateway information related to the bot. | |||||
| /// </returns> | |||||
| Task<BotGateway> GetBotGatewayAsync(RequestOptions options = null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -61,6 +61,7 @@ namespace Discord | |||||
| internal BucketId BucketId { get; set; } | internal BucketId BucketId { get; set; } | ||||
| internal bool IsClientBucket { get; set; } | internal bool IsClientBucket { get; set; } | ||||
| internal bool IsReactionBucket { get; set; } | internal bool IsReactionBucket { get; set; } | ||||
| internal bool IsGatewayBucket { get; set; } | |||||
| internal static RequestOptions CreateOrClone(RequestOptions options) | internal static RequestOptions CreateOrClone(RequestOptions options) | ||||
| { | { | ||||
| @@ -0,0 +1,16 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Discord.API.Rest | |||||
| { | |||||
| internal class SessionStartLimit | |||||
| { | |||||
| [JsonProperty("total")] | |||||
| public int Total { get; set; } | |||||
| [JsonProperty("remaining")] | |||||
| public int Remaining { get; set; } | |||||
| [JsonProperty("reset_after")] | |||||
| public int ResetAfter { get; set; } | |||||
| [JsonProperty("max_concurrency")] | |||||
| public int MaxConcurrency { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,4 @@ | |||||
| #pragma warning disable CS1591 | |||||
| #pragma warning disable CS1591 | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| namespace Discord.API.Rest | namespace Discord.API.Rest | ||||
| @@ -9,5 +9,7 @@ namespace Discord.API.Rest | |||||
| public string Url { get; set; } | public string Url { get; set; } | ||||
| [JsonProperty("shards")] | [JsonProperty("shards")] | ||||
| public int Shards { get; set; } | public int Shards { get; set; } | ||||
| [JsonProperty("session_start_limit")] | |||||
| public SessionStartLimit SessionStartLimit { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -152,6 +152,10 @@ namespace Discord.Rest | |||||
| public Task<int> GetRecommendedShardCountAsync(RequestOptions options = null) | public Task<int> GetRecommendedShardCountAsync(RequestOptions options = null) | ||||
| => ClientHelper.GetRecommendShardCountAsync(this, options); | => ClientHelper.GetRecommendShardCountAsync(this, options); | ||||
| /// <inheritdoc /> | |||||
| public Task<BotGateway> GetBotGatewayAsync(RequestOptions options = null) | |||||
| => ClientHelper.GetBotGatewayAsync(this, options); | |||||
| //IDiscordClient | //IDiscordClient | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; | ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; | ||||
| @@ -184,5 +184,22 @@ namespace Discord.Rest | |||||
| var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); | var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); | ||||
| return response.Shards; | return response.Shards; | ||||
| } | } | ||||
| public static async Task<BotGateway> GetBotGatewayAsync(BaseDiscordClient client, RequestOptions options) | |||||
| { | |||||
| var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); | |||||
| return new BotGateway | |||||
| { | |||||
| Url = response.Url, | |||||
| Shards = response.Shards, | |||||
| SessionStartLimit = new SessionStartLimit | |||||
| { | |||||
| Total = response.SessionStartLimit.Total, | |||||
| Remaining = response.SessionStartLimit.Remaining, | |||||
| ResetAfter = response.SessionStartLimit.ResetAfter, | |||||
| MaxConcurrency = response.SessionStartLimit.MaxConcurrency | |||||
| } | |||||
| }; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -29,10 +29,10 @@ namespace Discord.Rest | |||||
| internal DiscordRestClient(DiscordRestConfig config, API.DiscordRestApiClient api) : base(config, api) { } | internal DiscordRestClient(DiscordRestConfig config, API.DiscordRestApiClient api) : base(config, api) { } | ||||
| private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) | private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) | ||||
| => new API.DiscordRestApiClient(config.RestClientProvider, | |||||
| DiscordRestConfig.UserAgent, | |||||
| rateLimitPrecision: config.RateLimitPrecision, | |||||
| useSystemClock: config.UseSystemClock); | |||||
| => new API.DiscordRestApiClient(config.RestClientProvider, | |||||
| DiscordRestConfig.UserAgent, | |||||
| rateLimitPrecision: config.RateLimitPrecision, | |||||
| useSystemClock: config.UseSystemClock); | |||||
| internal override void Dispose(bool disposing) | internal override void Dispose(bool disposing) | ||||
| { | { | ||||
| @@ -0,0 +1,53 @@ | |||||
| using System.Collections.Immutable; | |||||
| namespace Discord.Net.Queue | |||||
| { | |||||
| public enum GatewayBucketType | |||||
| { | |||||
| Unbucketed = 0, | |||||
| Identify = 1, | |||||
| PresenceUpdate = 2, | |||||
| } | |||||
| internal struct GatewayBucket | |||||
| { | |||||
| private static readonly ImmutableDictionary<GatewayBucketType, GatewayBucket> DefsByType; | |||||
| private static readonly ImmutableDictionary<BucketId, GatewayBucket> DefsById; | |||||
| static GatewayBucket() | |||||
| { | |||||
| var buckets = new[] | |||||
| { | |||||
| // Limit is 120/60s, but 3 will be reserved for heartbeats (2 for possible heartbeats in the same timeframe and a possible failure) | |||||
| new GatewayBucket(GatewayBucketType.Unbucketed, BucketId.Create(null, "<gateway-unbucketed>", null), 117, 60), | |||||
| new GatewayBucket(GatewayBucketType.Identify, BucketId.Create(null, "<gateway-identify>", null), 1, 5), | |||||
| new GatewayBucket(GatewayBucketType.PresenceUpdate, BucketId.Create(null, "<gateway-presenceupdate>", null), 5, 60), | |||||
| }; | |||||
| var builder = ImmutableDictionary.CreateBuilder<GatewayBucketType, GatewayBucket>(); | |||||
| foreach (var bucket in buckets) | |||||
| builder.Add(bucket.Type, bucket); | |||||
| DefsByType = builder.ToImmutable(); | |||||
| var builder2 = ImmutableDictionary.CreateBuilder<BucketId, GatewayBucket>(); | |||||
| foreach (var bucket in buckets) | |||||
| builder2.Add(bucket.Id, bucket); | |||||
| DefsById = builder2.ToImmutable(); | |||||
| } | |||||
| public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; | |||||
| public static GatewayBucket Get(BucketId id) => DefsById[id]; | |||||
| public GatewayBucketType Type { get; } | |||||
| public BucketId Id { get; } | |||||
| public int WindowCount { get; set; } | |||||
| public int WindowSeconds { get; set; } | |||||
| public GatewayBucket(GatewayBucketType type, BucketId id, int count, int seconds) | |||||
| { | |||||
| Type = type; | |||||
| Id = id; | |||||
| WindowCount = count; | |||||
| WindowSeconds = seconds; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -89,9 +89,18 @@ namespace Discord.Net.Queue | |||||
| } | } | ||||
| public async Task SendAsync(WebSocketRequest request) | public async Task SendAsync(WebSocketRequest request) | ||||
| { | { | ||||
| //TODO: Re-impl websocket buckets | |||||
| request.CancelToken = _requestCancelToken; | |||||
| await request.SendAsync().ConfigureAwait(false); | |||||
| CancellationTokenSource createdTokenSource = null; | |||||
| if (request.Options.CancelToken.CanBeCanceled) | |||||
| { | |||||
| createdTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_requestCancelToken, request.Options.CancelToken); | |||||
| request.Options.CancelToken = createdTokenSource.Token; | |||||
| } | |||||
| else | |||||
| request.Options.CancelToken = _requestCancelToken; | |||||
| var bucket = GetOrCreateBucket(request.Options, request); | |||||
| await bucket.SendAsync(request).ConfigureAwait(false); | |||||
| createdTokenSource?.Dispose(); | |||||
| } | } | ||||
| internal async Task EnterGlobalAsync(int id, RestRequest request) | internal async Task EnterGlobalAsync(int id, RestRequest request) | ||||
| @@ -109,8 +118,23 @@ namespace Discord.Net.Queue | |||||
| { | { | ||||
| _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0)); | _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0)); | ||||
| } | } | ||||
| internal async Task EnterGlobalAsync(int id, WebSocketRequest request) | |||||
| { | |||||
| //If this is a global request (unbucketed), it'll be dealt in EnterAsync | |||||
| var requestBucket = GatewayBucket.Get(request.Options.BucketId); | |||||
| if (requestBucket.Type == GatewayBucketType.Unbucketed) | |||||
| return; | |||||
| //It's not a global request, so need to remove one from global (per-session) | |||||
| var globalBucketType = GatewayBucket.Get(GatewayBucketType.Unbucketed); | |||||
| var options = RequestOptions.CreateOrClone(request.Options); | |||||
| options.BucketId = globalBucketType.Id; | |||||
| var globalRequest = new WebSocketRequest(null, null, false, false, options); | |||||
| var globalBucket = GetOrCreateBucket(options, globalRequest); | |||||
| await globalBucket.TriggerAsync(id, globalRequest); | |||||
| } | |||||
| private RequestBucket GetOrCreateBucket(RequestOptions options, RestRequest request) | |||||
| private RequestBucket GetOrCreateBucket(RequestOptions options, IRequest request) | |||||
| { | { | ||||
| var bucketId = options.BucketId; | var bucketId = options.BucketId; | ||||
| object obj = _buckets.GetOrAdd(bucketId, x => new RequestBucket(this, request, x)); | object obj = _buckets.GetOrAdd(bucketId, x => new RequestBucket(this, request, x)); | ||||
| @@ -137,6 +161,12 @@ namespace Discord.Net.Queue | |||||
| return (null, null); | return (null, null); | ||||
| } | } | ||||
| public void ClearGatewayBuckets() | |||||
| { | |||||
| foreach (var gwBucket in (GatewayBucketType[])Enum.GetValues(typeof(GatewayBucketType))) | |||||
| _buckets.TryRemove(GatewayBucket.Get(gwBucket).Id, out _); | |||||
| } | |||||
| private async Task RunCleanup() | private async Task RunCleanup() | ||||
| { | { | ||||
| try | try | ||||
| @@ -25,7 +25,7 @@ namespace Discord.Net.Queue | |||||
| public int WindowCount { get; private set; } | public int WindowCount { get; private set; } | ||||
| public DateTimeOffset LastAttemptAt { get; private set; } | public DateTimeOffset LastAttemptAt { get; private set; } | ||||
| public RequestBucket(RequestQueue queue, RestRequest request, BucketId id) | |||||
| public RequestBucket(RequestQueue queue, IRequest request, BucketId id) | |||||
| { | { | ||||
| _queue = queue; | _queue = queue; | ||||
| Id = id; | Id = id; | ||||
| @@ -33,7 +33,9 @@ namespace Discord.Net.Queue | |||||
| _lock = new object(); | _lock = new object(); | ||||
| if (request.Options.IsClientBucket) | if (request.Options.IsClientBucket) | ||||
| WindowCount = ClientBucket.Get(Id).WindowCount; | |||||
| WindowCount = ClientBucket.Get(request.Options.BucketId).WindowCount; | |||||
| else if (request.Options.IsGatewayBucket) | |||||
| WindowCount = GatewayBucket.Get(request.Options.BucketId).WindowCount; | |||||
| else | else | ||||
| WindowCount = 1; //Only allow one request until we get a header back | WindowCount = 1; //Only allow one request until we get a header back | ||||
| _semaphore = WindowCount; | _semaphore = WindowCount; | ||||
| @@ -154,8 +156,68 @@ namespace Discord.Net.Queue | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| public async Task SendAsync(WebSocketRequest request) | |||||
| { | |||||
| int id = Interlocked.Increment(ref nextId); | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Start"); | |||||
| #endif | |||||
| LastAttemptAt = DateTimeOffset.UtcNow; | |||||
| while (true) | |||||
| { | |||||
| await _queue.EnterGlobalAsync(id, request).ConfigureAwait(false); | |||||
| await EnterAsync(id, request).ConfigureAwait(false); | |||||
| private async Task EnterAsync(int id, RestRequest request) | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Sending..."); | |||||
| #endif | |||||
| try | |||||
| { | |||||
| await request.SendAsync().ConfigureAwait(false); | |||||
| return; | |||||
| } | |||||
| catch (TimeoutException) | |||||
| { | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Timeout"); | |||||
| #endif | |||||
| if ((request.Options.RetryMode & RetryMode.RetryTimeouts) == 0) | |||||
| throw; | |||||
| await Task.Delay(500).ConfigureAwait(false); | |||||
| continue; //Retry | |||||
| } | |||||
| /*catch (Exception) | |||||
| { | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Error"); | |||||
| #endif | |||||
| if ((request.Options.RetryMode & RetryMode.RetryErrors) == 0) | |||||
| throw; | |||||
| await Task.Delay(500); | |||||
| continue; //Retry | |||||
| }*/ | |||||
| finally | |||||
| { | |||||
| UpdateRateLimit(id, request, default(RateLimitInfo), false); | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Stop"); | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| internal async Task TriggerAsync(int id, IRequest request) | |||||
| { | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Trigger Bucket"); | |||||
| #endif | |||||
| await EnterAsync(id, request).ConfigureAwait(false); | |||||
| UpdateRateLimit(id, request, default(RateLimitInfo), false); | |||||
| } | |||||
| private async Task EnterAsync(int id, IRequest request) | |||||
| { | { | ||||
| int windowCount; | int windowCount; | ||||
| DateTimeOffset? resetAt; | DateTimeOffset? resetAt; | ||||
| @@ -186,8 +248,31 @@ namespace Discord.Net.Queue | |||||
| { | { | ||||
| if (!isRateLimited) | if (!isRateLimited) | ||||
| { | { | ||||
| bool ignoreRatelimit = false; | |||||
| isRateLimited = true; | isRateLimited = true; | ||||
| await _queue.RaiseRateLimitTriggered(Id, null, $"{request.Method} {request.Endpoint}").ConfigureAwait(false); | |||||
| switch (request) | |||||
| { | |||||
| case RestRequest restRequest: | |||||
| await _queue.RaiseRateLimitTriggered(Id, null, $"{restRequest.Method} {restRequest.Endpoint}").ConfigureAwait(false); | |||||
| break; | |||||
| case WebSocketRequest webSocketRequest: | |||||
| if (webSocketRequest.IgnoreLimit) | |||||
| { | |||||
| ignoreRatelimit = true; | |||||
| break; | |||||
| } | |||||
| await _queue.RaiseRateLimitTriggered(Id, null, Id.Endpoint).ConfigureAwait(false); | |||||
| break; | |||||
| default: | |||||
| throw new InvalidOperationException("Unknown request type"); | |||||
| } | |||||
| if (ignoreRatelimit) | |||||
| { | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Ignoring ratelimit"); | |||||
| #endif | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| ThrowRetryLimit(request); | ThrowRetryLimit(request); | ||||
| @@ -223,7 +308,7 @@ namespace Discord.Net.Queue | |||||
| } | } | ||||
| } | } | ||||
| private void UpdateRateLimit(int id, RestRequest request, RateLimitInfo info, bool is429, bool redirected = false) | |||||
| private void UpdateRateLimit(int id, IRequest request, RateLimitInfo info, bool is429, bool redirected = false) | |||||
| { | { | ||||
| if (WindowCount == 0) | if (WindowCount == 0) | ||||
| return; | return; | ||||
| @@ -316,6 +401,23 @@ namespace Discord.Net.Queue | |||||
| Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(Id).WindowSeconds * 1000} ms)"); | Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(Id).WindowSeconds * 1000} ms)"); | ||||
| #endif | #endif | ||||
| } | } | ||||
| else if (request.Options.IsGatewayBucket && request.Options.BucketId != null) | |||||
| { | |||||
| resetTick = DateTimeOffset.UtcNow.AddSeconds(GatewayBucket.Get(request.Options.BucketId).WindowSeconds); | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Gateway Bucket ({GatewayBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)"); | |||||
| #endif | |||||
| if (!hasQueuedReset) | |||||
| { | |||||
| _resetTick = resetTick; | |||||
| LastAttemptAt = resetTick.Value; | |||||
| #if DEBUG_LIMITS | |||||
| Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).Value.TotalMilliseconds)} ms"); | |||||
| #endif | |||||
| var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request); | |||||
| } | |||||
| return; | |||||
| } | |||||
| if (resetTick == null) | if (resetTick == null) | ||||
| { | { | ||||
| @@ -336,12 +438,12 @@ namespace Discord.Net.Queue | |||||
| if (!hasQueuedReset) | if (!hasQueuedReset) | ||||
| { | { | ||||
| var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds)); | |||||
| var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| private async Task QueueReset(int id, int millis) | |||||
| private async Task QueueReset(int id, int millis, IRequest request) | |||||
| { | { | ||||
| while (true) | while (true) | ||||
| { | { | ||||
| @@ -363,7 +465,7 @@ namespace Discord.Net.Queue | |||||
| } | } | ||||
| } | } | ||||
| private void ThrowRetryLimit(RestRequest request) | |||||
| private void ThrowRetryLimit(IRequest request) | |||||
| { | { | ||||
| if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0) | if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0) | ||||
| throw new RateLimitedException(request); | throw new RateLimitedException(request); | ||||
| @@ -9,22 +9,22 @@ namespace Discord.Net.Queue | |||||
| public class WebSocketRequest : IRequest | public class WebSocketRequest : IRequest | ||||
| { | { | ||||
| public IWebSocketClient Client { get; } | public IWebSocketClient Client { get; } | ||||
| public string BucketId { get; } | |||||
| public byte[] Data { get; } | public byte[] Data { get; } | ||||
| public bool IsText { get; } | public bool IsText { get; } | ||||
| public bool IgnoreLimit { get; } | |||||
| public DateTimeOffset? TimeoutAt { get; } | public DateTimeOffset? TimeoutAt { get; } | ||||
| public TaskCompletionSource<Stream> Promise { get; } | public TaskCompletionSource<Stream> Promise { get; } | ||||
| public RequestOptions Options { get; } | public RequestOptions Options { get; } | ||||
| public CancellationToken CancelToken { get; internal set; } | public CancellationToken CancelToken { get; internal set; } | ||||
| public WebSocketRequest(IWebSocketClient client, string bucketId, byte[] data, bool isText, RequestOptions options) | |||||
| public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, bool ignoreLimit, RequestOptions options) | |||||
| { | { | ||||
| Preconditions.NotNull(options, nameof(options)); | Preconditions.NotNull(options, nameof(options)); | ||||
| Client = client; | Client = client; | ||||
| BucketId = bucketId; | |||||
| Data = data; | Data = data; | ||||
| IsText = isText; | IsText = isText; | ||||
| IgnoreLimit = ignoreLimit; | |||||
| Options = options; | Options = options; | ||||
| TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; | TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; | ||||
| Promise = new TaskCompletionSource<Stream>(); | Promise = new TaskCompletionSource<Stream>(); | ||||
| @@ -12,12 +12,14 @@ namespace Discord.WebSocket | |||||
| public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient | public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient | ||||
| { | { | ||||
| private readonly DiscordSocketConfig _baseConfig; | private readonly DiscordSocketConfig _baseConfig; | ||||
| private readonly SemaphoreSlim _connectionGroupLock; | |||||
| private readonly Dictionary<int, int> _shardIdsToIndex; | private readonly Dictionary<int, int> _shardIdsToIndex; | ||||
| private readonly bool _automaticShards; | private readonly bool _automaticShards; | ||||
| private int[] _shardIds; | private int[] _shardIds; | ||||
| private DiscordSocketClient[] _shards; | private DiscordSocketClient[] _shards; | ||||
| private int _totalShards; | private int _totalShards; | ||||
| private SemaphoreSlim[] _identifySemaphores; | |||||
| private object _semaphoreResetLock; | |||||
| private Task _semaphoreResetTask; | |||||
| private bool _isDisposed; | private bool _isDisposed; | ||||
| @@ -62,10 +64,10 @@ namespace Discord.WebSocket | |||||
| if (ids != null && config.TotalShards == null) | if (ids != null && config.TotalShards == null) | ||||
| throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified."); | throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified."); | ||||
| _semaphoreResetLock = new object(); | |||||
| _shardIdsToIndex = new Dictionary<int, int>(); | _shardIdsToIndex = new Dictionary<int, int>(); | ||||
| config.DisplayInitialLog = false; | config.DisplayInitialLog = false; | ||||
| _baseConfig = config; | _baseConfig = config; | ||||
| _connectionGroupLock = new SemaphoreSlim(1, 1); | |||||
| if (config.TotalShards == null) | if (config.TotalShards == null) | ||||
| _automaticShards = true; | _automaticShards = true; | ||||
| @@ -74,12 +76,15 @@ namespace Discord.WebSocket | |||||
| _totalShards = config.TotalShards.Value; | _totalShards = config.TotalShards.Value; | ||||
| _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray(); | _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray(); | ||||
| _shards = new DiscordSocketClient[_shardIds.Length]; | _shards = new DiscordSocketClient[_shardIds.Length]; | ||||
| _identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency]; | |||||
| for (int i = 0; i < config.IdentifyMaxConcurrency; i++) | |||||
| _identifySemaphores[i] = new SemaphoreSlim(1, 1); | |||||
| for (int i = 0; i < _shardIds.Length; i++) | for (int i = 0; i < _shardIds.Length; i++) | ||||
| { | { | ||||
| _shardIdsToIndex.Add(_shardIds[i], i); | _shardIdsToIndex.Add(_shardIds[i], i); | ||||
| var newConfig = config.Clone(); | var newConfig = config.Clone(); | ||||
| newConfig.ShardId = _shardIds[i]; | newConfig.ShardId = _shardIds[i]; | ||||
| _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); | |||||
| _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); | |||||
| RegisterEvents(_shards[i], i == 0); | RegisterEvents(_shards[i], i == 0); | ||||
| } | } | ||||
| } | } | ||||
| @@ -88,21 +93,53 @@ namespace Discord.WebSocket | |||||
| => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, | => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, | ||||
| rateLimitPrecision: config.RateLimitPrecision); | rateLimitPrecision: config.RateLimitPrecision); | ||||
| internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token) | |||||
| { | |||||
| int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency; | |||||
| await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false); | |||||
| } | |||||
| internal void ReleaseIdentifyLock() | |||||
| { | |||||
| lock (_semaphoreResetLock) | |||||
| { | |||||
| if (_semaphoreResetTask == null) | |||||
| _semaphoreResetTask = ResetSemaphoresAsync(); | |||||
| } | |||||
| } | |||||
| private async Task ResetSemaphoresAsync() | |||||
| { | |||||
| await Task.Delay(5000).ConfigureAwait(false); | |||||
| lock (_semaphoreResetLock) | |||||
| { | |||||
| foreach (var semaphore in _identifySemaphores) | |||||
| if (semaphore.CurrentCount == 0) | |||||
| semaphore.Release(); | |||||
| _semaphoreResetTask = null; | |||||
| } | |||||
| } | |||||
| internal override async Task OnLoginAsync(TokenType tokenType, string token) | internal override async Task OnLoginAsync(TokenType tokenType, string token) | ||||
| { | { | ||||
| if (_automaticShards) | if (_automaticShards) | ||||
| { | { | ||||
| var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false); | |||||
| _shardIds = Enumerable.Range(0, shardCount).ToArray(); | |||||
| var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); | |||||
| _shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); | |||||
| _totalShards = _shardIds.Length; | _totalShards = _shardIds.Length; | ||||
| _shards = new DiscordSocketClient[_shardIds.Length]; | _shards = new DiscordSocketClient[_shardIds.Length]; | ||||
| int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; | |||||
| _baseConfig.IdentifyMaxConcurrency = maxConcurrency; | |||||
| _identifySemaphores = new SemaphoreSlim[maxConcurrency]; | |||||
| for (int i = 0; i < maxConcurrency; i++) | |||||
| _identifySemaphores[i] = new SemaphoreSlim(1, 1); | |||||
| for (int i = 0; i < _shardIds.Length; i++) | for (int i = 0; i < _shardIds.Length; i++) | ||||
| { | { | ||||
| _shardIdsToIndex.Add(_shardIds[i], i); | _shardIdsToIndex.Add(_shardIds[i], i); | ||||
| var newConfig = _baseConfig.Clone(); | var newConfig = _baseConfig.Clone(); | ||||
| newConfig.ShardId = _shardIds[i]; | newConfig.ShardId = _shardIds[i]; | ||||
| newConfig.TotalShards = _totalShards; | newConfig.TotalShards = _totalShards; | ||||
| _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); | |||||
| _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null); | |||||
| RegisterEvents(_shards[i], i == 0); | RegisterEvents(_shards[i], i == 0); | ||||
| } | } | ||||
| } | } | ||||
| @@ -398,7 +435,6 @@ namespace Discord.WebSocket | |||||
| foreach (var client in _shards) | foreach (var client in _shards) | ||||
| client?.Dispose(); | client?.Dispose(); | ||||
| } | } | ||||
| _connectionGroupLock?.Dispose(); | |||||
| } | } | ||||
| _isDisposed = true; | _isDisposed = true; | ||||
| @@ -132,6 +132,8 @@ namespace Discord.API | |||||
| if (WebSocketClient == null) | if (WebSocketClient == null) | ||||
| throw new NotSupportedException("This client is not configured with WebSocket support."); | throw new NotSupportedException("This client is not configured with WebSocket support."); | ||||
| RequestQueue.ClearGatewayBuckets(); | |||||
| //Re-create streams to reset the zlib state | //Re-create streams to reset the zlib state | ||||
| _compressed?.Dispose(); | _compressed?.Dispose(); | ||||
| _decompressor?.Dispose(); | _decompressor?.Dispose(); | ||||
| @@ -205,7 +207,11 @@ namespace Discord.API | |||||
| payload = new SocketFrame { Operation = (int)opCode, Payload = payload }; | payload = new SocketFrame { Operation = (int)opCode, Payload = payload }; | ||||
| if (payload != null) | if (payload != null) | ||||
| bytes = Encoding.UTF8.GetBytes(SerializeJson(payload)); | bytes = Encoding.UTF8.GetBytes(SerializeJson(payload)); | ||||
| await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false); | |||||
| options.IsGatewayBucket = true; | |||||
| if (options.BucketId == null) | |||||
| options.BucketId = GatewayBucket.Get(GatewayBucketType.Unbucketed).Id; | |||||
| await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, opCode == GatewayOpCode.Heartbeat, options)).ConfigureAwait(false); | |||||
| await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); | await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); | ||||
| } | } | ||||
| @@ -225,6 +231,8 @@ namespace Discord.API | |||||
| if (totalShards > 1) | if (totalShards > 1) | ||||
| msg.ShardingParams = new int[] { shardID, totalShards }; | msg.ShardingParams = new int[] { shardID, totalShards }; | ||||
| options.BucketId = GatewayBucket.Get(GatewayBucketType.Identify).Id; | |||||
| if (gatewayIntents.HasValue) | if (gatewayIntents.HasValue) | ||||
| msg.Intents = (int)gatewayIntents.Value; | msg.Intents = (int)gatewayIntents.Value; | ||||
| else | else | ||||
| @@ -258,6 +266,7 @@ namespace Discord.API | |||||
| IsAFK = isAFK, | IsAFK = isAFK, | ||||
| Game = game | Game = game | ||||
| }; | }; | ||||
| options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id; | |||||
| await SendGatewayAsync(GatewayOpCode.StatusUpdate, args, options: options).ConfigureAwait(false); | await SendGatewayAsync(GatewayOpCode.StatusUpdate, args, options: options).ConfigureAwait(false); | ||||
| } | } | ||||
| public async Task SendRequestMembersAsync(IEnumerable<ulong> guildIds, RequestOptions options = null) | public async Task SendRequestMembersAsync(IEnumerable<ulong> guildIds, RequestOptions options = null) | ||||
| @@ -26,7 +26,7 @@ namespace Discord.WebSocket | |||||
| { | { | ||||
| private readonly ConcurrentQueue<ulong> _largeGuilds; | private readonly ConcurrentQueue<ulong> _largeGuilds; | ||||
| private readonly JsonSerializer _serializer; | private readonly JsonSerializer _serializer; | ||||
| private readonly SemaphoreSlim _connectionGroupLock; | |||||
| private readonly DiscordShardedClient _shardedClient; | |||||
| private readonly DiscordSocketClient _parentClient; | private readonly DiscordSocketClient _parentClient; | ||||
| private readonly ConcurrentQueue<long> _heartbeatTimes; | private readonly ConcurrentQueue<long> _heartbeatTimes; | ||||
| private readonly ConnectionManager _connection; | private readonly ConnectionManager _connection; | ||||
| @@ -120,9 +120,9 @@ namespace Discord.WebSocket | |||||
| /// <param name="config">The configuration to be used with the client.</param> | /// <param name="config">The configuration to be used with the client.</param> | ||||
| #pragma warning disable IDISP004 | #pragma warning disable IDISP004 | ||||
| public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } | public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } | ||||
| internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { } | |||||
| internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { } | |||||
| #pragma warning restore IDISP004 | #pragma warning restore IDISP004 | ||||
| private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient) | |||||
| private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) | |||||
| : base(config, client) | : base(config, client) | ||||
| { | { | ||||
| ShardId = config.ShardId ?? 0; | ShardId = config.ShardId ?? 0; | ||||
| @@ -148,7 +148,7 @@ namespace Discord.WebSocket | |||||
| _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); | _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); | ||||
| _nextAudioId = 1; | _nextAudioId = 1; | ||||
| _connectionGroupLock = groupLock; | |||||
| _shardedClient = shardedClient; | |||||
| _parentClient = parentClient; | _parentClient = parentClient; | ||||
| _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; | _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; | ||||
| @@ -229,8 +229,12 @@ namespace Discord.WebSocket | |||||
| private async Task OnConnectingAsync() | private async Task OnConnectingAsync() | ||||
| { | { | ||||
| if (_connectionGroupLock != null) | |||||
| await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false); | |||||
| bool locked = false; | |||||
| if (_shardedClient != null && _sessionId == null) | |||||
| { | |||||
| await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false); | |||||
| locked = true; | |||||
| } | |||||
| try | try | ||||
| { | { | ||||
| await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); | await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); | ||||
| @@ -255,11 +259,8 @@ namespace Discord.WebSocket | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| if (_connectionGroupLock != null) | |||||
| { | |||||
| await Task.Delay(5000).ConfigureAwait(false); | |||||
| _connectionGroupLock.Release(); | |||||
| } | |||||
| if (locked) | |||||
| _shardedClient.ReleaseIdentifyLock(); | |||||
| } | } | ||||
| } | } | ||||
| private async Task OnDisconnectingAsync(Exception ex) | private async Task OnDisconnectingAsync(Exception ex) | ||||
| @@ -519,7 +520,15 @@ namespace Discord.WebSocket | |||||
| _sessionId = null; | _sessionId = null; | ||||
| _lastSeq = 0; | _lastSeq = 0; | ||||
| await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); | |||||
| await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false); | |||||
| try | |||||
| { | |||||
| await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); | |||||
| } | |||||
| finally | |||||
| { | |||||
| _shardedClient.ReleaseIdentifyLock(); | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| case GatewayOpCode.Reconnect: | case GatewayOpCode.Reconnect: | ||||
| @@ -126,6 +126,14 @@ namespace Discord.WebSocket | |||||
| public bool GuildSubscriptions { get; set; } = true; | public bool GuildSubscriptions { get; set; } = true; | ||||
| /// <summary> | /// <summary> | ||||
| /// Gets or sets the maximum identify concurrency. | |||||
| /// </summary> | |||||
| /// <remarks> | |||||
| /// This information is provided by Discord. | |||||
| /// It is only used when using a <see cref="DiscordShardedClient"/> and auto-sharding is disabled. | |||||
| /// </remarks> | |||||
| public int IdentifyMaxConcurrency { get; set; } = 1; | |||||
| /// Gets or sets the maximum wait time in milliseconds between GUILD_AVAILABLE events before firing READY. | /// Gets or sets the maximum wait time in milliseconds between GUILD_AVAILABLE events before firing READY. | ||||
| /// | /// | ||||
| /// If zero, READY will fire as soon as it is received and all guilds will be unavailable. | /// If zero, READY will fire as soon as it is received and all guilds will be unavailable. | ||||