diff --git a/src/Discord.Net.Rest/Entities/Gateway/GatewayLimit.cs b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimit.cs new file mode 100644 index 000000000..a687be514 --- /dev/null +++ b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimit.cs @@ -0,0 +1,23 @@ +namespace Discord.Rest +{ + /// + /// Represents the limits for a gateway request. + /// + public struct GatewayLimit + { + /// + /// The maximum amount of this type of request in a time window, that is set by . + /// + public int Count { get; set; } + /// + /// The amount of seconds until the rate limiter resets the remaining requests . + /// + public int Seconds { get; set; } + + internal GatewayLimit(int count, int seconds) + { + Count = count; + Seconds = seconds; + } + } +} diff --git a/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs new file mode 100644 index 000000000..481831e1c --- /dev/null +++ b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs @@ -0,0 +1,29 @@ +namespace Discord.Rest +{ + /// + /// Contains the rate limits for the gateway. + /// + public class GatewayLimits + { + /// + /// Gets or sets the global limits for the gateway rate limiter. + /// + /// + /// It includes all the other limits, like Identify. + /// + public GatewayLimit Global { get; set; } + /// + /// Gets or sets the limits of Identify requests. + /// + public GatewayLimit Identify { get; set; } + + public GatewayLimits() + { + Global = new GatewayLimit(120, 60); + Identify = new GatewayLimit(1, 5); + } + + internal static GatewayLimits GetOrCreate(GatewayLimits limits) + => limits ?? new GatewayLimits(); + } +} diff --git a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs index a1b2b9a7a..2177031df 100644 --- a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs @@ -1,3 +1,4 @@ +using Discord.Rest; using System.Collections.Immutable; namespace Discord.Net.Queue @@ -9,15 +10,29 @@ namespace Discord.Net.Queue } internal struct GatewayBucket { - private static readonly ImmutableDictionary DefsByType; - private static readonly ImmutableDictionary DefsById; + private static ImmutableDictionary DefsByType; + private static ImmutableDictionary DefsById; static GatewayBucket() { + SetLimits(GatewayLimits.GetOrCreate(null)); + } + + public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; + public static GatewayBucket Get(string id) => DefsById[id]; + + public static void SetLimits(GatewayLimits limits) + { + limits = GatewayLimits.GetOrCreate(limits); + Preconditions.GreaterThan(limits.Global.Count, 0, nameof(limits.Global.Count), "Global count must be greater than zero."); + Preconditions.GreaterThan(limits.Global.Seconds, 0, nameof(limits.Global.Seconds), "Global seconds must be greater than zero."); + Preconditions.GreaterThan(limits.Identify.Count, 0, nameof(limits.Identify.Count), "Identify count must be greater than zero."); + Preconditions.GreaterThan(limits.Identify.Seconds, 0, nameof(limits.Identify.Seconds), "Identify seconds must be greater than zero."); + var buckets = new[] { - new GatewayBucket(GatewayBucketType.Unbucketed, "", 120, 60), - new GatewayBucket(GatewayBucketType.Identify, "", 1, 5) + new GatewayBucket(GatewayBucketType.Unbucketed, "", limits.Global.Count, limits.Global.Seconds), + new GatewayBucket(GatewayBucketType.Identify, "", limits.Identify.Count, limits.Identify.Seconds) }; var builder = ImmutableDictionary.CreateBuilder(); @@ -31,13 +46,10 @@ namespace Discord.Net.Queue DefsById = builder2.ToImmutable(); } - public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; - public static GatewayBucket Get(string id) => DefsById[id]; - public GatewayBucketType Type { get; } public string Id { get; } - public int WindowCount { get; } - public int WindowSeconds { get; } + public int WindowCount { get; set; } + public int WindowSeconds { get; set; } public GatewayBucket(GatewayBucketType type, string id, int count, int seconds) { diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index 143ec3e22..be7dd8b38 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -103,7 +103,7 @@ namespace Discord.Net.Queue createdTokenSource?.Dispose(); } - internal async Task EnterGlobalAsync(int id, IRequest request) + internal async Task EnterGlobalAsync(int id, RestRequest request) { int millis = (int)Math.Ceiling((_waitUntil - DateTimeOffset.UtcNow).TotalMilliseconds); if (millis > 0) @@ -118,6 +118,19 @@ namespace Discord.Net.Queue { _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0)); } + internal async Task EnterGlobalAsync(int id, WebSocketRequest request) + { + var requestBucket = GatewayBucket.Get(request.Options.BucketId); + if (requestBucket.Type == GatewayBucketType.Unbucketed) + return; + + var globalBucketType = GatewayBucket.Get(GatewayBucketType.Unbucketed); + var options = RequestOptions.CreateOrClone(request.Options); + options.BucketId = globalBucketType.Id; + var globalRequest = new WebSocketRequest(null, null, false, options); + var globalBucket = GetOrCreateBucket(globalBucketType.Id, globalRequest); + await globalBucket.TriggerAsync(id, globalRequest); + } private RequestBucket GetOrCreateBucket(string id, IRequest request) { diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs index 758ba6783..ef5b247fd 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs @@ -203,6 +203,15 @@ namespace Discord.Net.Queue } } + internal async Task TriggerAsync(int id, IRequest request) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Trigger Bucket"); +#endif + await EnterAsync(id, request).ConfigureAwait(false); + UpdateRateLimit(id, request, default(RateLimitInfo), false); + } + private async Task EnterAsync(int id, IRequest request) { int windowCount; diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 4eb539d0c..edbc15fef 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -67,6 +67,7 @@ namespace Discord.WebSocket config.DisplayInitialLog = false; _baseConfig = config; _connectionGroupLock = new SemaphoreSlim(1, 1); + GatewayBucket.SetLimits(GatewayLimits.GetOrCreate(config.GatewayLimits)); if (config.TotalShards == null) _automaticShards = true; diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 61d15cf19..b046ce03c 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -2,6 +2,7 @@ using Discord.API; using Discord.API.Gateway; using Discord.Logging; using Discord.Net.Converters; +using Discord.Net.Queue; using Discord.Net.Udp; using Discord.Net.WebSockets; using Discord.Rest; @@ -120,6 +121,8 @@ namespace Discord.WebSocket #pragma warning disable IDISP004 public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { + GatewayBucket.SetLimits(GatewayLimits.GetOrCreate(config.GatewayLimits)); + ApiClient.WebSocketRequestQueue.RateLimitTriggered += async (id, info) => { if (info == null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 91b597bbf..ed745d7de 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -125,6 +125,14 @@ namespace Discord.WebSocket /// public bool GuildSubscriptions { get; set; } = true; + /// + /// Gets or sets the gateway limits. + /// + /// It should only be changed for bots that have special limits provided by Discord. + /// + /// + public GatewayLimits GatewayLimits { get; set; } = new GatewayLimits(); + internal RequestQueue _websocketRequestQueue; ///