diff --git a/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs new file mode 100644 index 000000000..a6f1b0989 --- /dev/null +++ b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs @@ -0,0 +1,18 @@ +namespace Discord +{ + public class BotGateway + { + /// + /// The WSS URL that can be used for connecting to the gateway. + /// + public string Url { get; internal set; } + /// + /// The recommended number of shards to use when connecting. + /// + public int Shards { get; internal set; } + /// + /// Information on the current session start limit. + /// + public SessionStartLimit SessionStartLimit { get; internal set; } + } +} diff --git a/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs new file mode 100644 index 000000000..40c9d6dd2 --- /dev/null +++ b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs @@ -0,0 +1,22 @@ +namespace Discord +{ + public class SessionStartLimit + { + /// + /// The total number of session starts the current user is allowed. + /// + public int Total { get; internal set; } + /// + /// The remaining number of session starts the current user is allowed. + /// + public int Remaining { get; internal set; } + /// + /// The number of milliseconds after which the limit resets. + /// + public int ResetAfter { get; internal set; } + /// + /// The maximum concurrent identify requests in a time window. + /// + public int MaxConcurrency { get; internal set; } + } +} diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs index f972cd71d..d7d6d2856 100644 --- a/src/Discord.Net.Core/IDiscordClient.cs +++ b/src/Discord.Net.Core/IDiscordClient.cs @@ -274,5 +274,15 @@ namespace Discord /// that represents the number of shards that should be used with this account. /// Task GetRecommendedShardCountAsync(RequestOptions options = null); + + /// + /// Gets the gateway information related to the bot. + /// + /// The options to be used when sending the request. + /// + /// A task that represents the asynchronous get operation. The task result contains a + /// that represents the gateway information related to the bot. + /// + Task GetBotGatewayAsync(RequestOptions options = null); } } diff --git a/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs new file mode 100644 index 000000000..29d5ddf85 --- /dev/null +++ b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; + +namespace Discord.API.Rest +{ + internal class SessionStartLimit + { + [JsonProperty("total")] + public int Total { get; set; } + [JsonProperty("remaining")] + public int Remaining { get; set; } + [JsonProperty("reset_after")] + public int ResetAfter { get; set; } + [JsonProperty("max_concurrency")] + public int MaxConcurrency { get; set; } + } +} diff --git a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs index 111fcf3db..d3285051b 100644 --- a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs +++ b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs @@ -1,4 +1,4 @@ -#pragma warning disable CS1591 +#pragma warning disable CS1591 using Newtonsoft.Json; namespace Discord.API.Rest @@ -9,5 +9,7 @@ namespace Discord.API.Rest public string Url { get; set; } [JsonProperty("shards")] public int Shards { get; set; } + [JsonProperty("session_start_limit")] + public SessionStartLimit SessionStartLimit { get; set; } } } diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs index 1837e38c0..399299173 100644 --- a/src/Discord.Net.Rest/BaseDiscordClient.cs +++ b/src/Discord.Net.Rest/BaseDiscordClient.cs @@ -152,6 +152,10 @@ namespace Discord.Rest public Task GetRecommendedShardCountAsync(RequestOptions options = null) => ClientHelper.GetRecommendShardCountAsync(this, options); + /// + public Task GetBotGatewayAsync(RequestOptions options = null) + => ClientHelper.GetBotGatewayAsync(this, options); + //IDiscordClient /// ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index a8f6b58ef..c9f0fc368 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -176,5 +176,22 @@ namespace Discord.Rest var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); return response.Shards; } + + public static async Task GetBotGatewayAsync(BaseDiscordClient client, RequestOptions options) + { + var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); + return new BotGateway + { + Url = response.Url, + Shards = response.Shards, + SessionStartLimit = new SessionStartLimit + { + Total = response.SessionStartLimit.Total, + Remaining = response.SessionStartLimit.Remaining, + ResetAfter = response.SessionStartLimit.ResetAfter, + MaxConcurrency = response.SessionStartLimit.MaxConcurrency + } + }; + } } } diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs index a726ef75d..5ef635870 100644 --- a/src/Discord.Net.Rest/DiscordRestApiClient.cs +++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs @@ -51,7 +51,7 @@ namespace Discord.API internal JsonSerializer Serializer => _serializer; /// Unknown OAuth token type. - public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, + public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RequestQueue requestQueue, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) { _restClientProvider = restClientProvider; @@ -61,7 +61,7 @@ namespace Discord.API RateLimitPrecision = rateLimitPrecision; UseSystemClock = useSystemClock; - RequestQueue = new RequestQueue(); + RequestQueue = requestQueue ?? new RequestQueue(); _stateLock = new SemaphoreSlim(1, 1); SetBaseUrl(DiscordConfig.APIUrl); diff --git a/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs index 61b9318d6..14819722b 100644 --- a/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs +++ b/src/Discord.Net.Rest/Entities/Gateway/GatewayLimits.cs @@ -7,6 +7,11 @@ namespace Discord.Rest /// public class GatewayLimits { + /// + /// Creates a new with the default values. + /// + public static GatewayLimits Default => new GatewayLimits(); + /// /// Gets or sets the global limits for the gateway rate limiter. /// @@ -15,6 +20,7 @@ namespace Discord.Rest /// and it is per websocket. /// public GatewayLimit Global { get; set; } + /// /// Gets or sets the limits of Identify requests. /// @@ -23,6 +29,7 @@ namespace Discord.Rest /// also per account. /// public GatewayLimit Identify { get; set; } + /// /// Gets or sets the limits of Presence Update requests. /// @@ -31,11 +38,35 @@ namespace Discord.Rest /// and status (online, idle, etc) /// public GatewayLimit PresenceUpdate { get; set; } + + /// + /// Gets or sets the name of the master + /// used by identify. + /// + /// + /// It is used to define what slave + /// is free to run for concurrent identify requests. + /// + public string IdentifyMasterSemaphoreName { get; set; } + /// - /// Gets or sets the name of the used by identify. + /// Gets or sets the name of the slave + /// used by identify. /// + /// + /// If the maximum concurrency is higher than one and you are using the sharded client, + /// it will be dinamilly renamed to fit the necessary needs. + /// public string IdentifySemaphoreName { get; set; } + /// + /// Gets or sets the maximum identify concurrency. + /// + /// + /// This limit is provided by Discord. + /// + public int IdentifyMaxConcurrency { get; set; } + /// /// Initializes a new with the default values. /// @@ -44,10 +75,26 @@ namespace Discord.Rest Global = new GatewayLimit(120, 60); Identify = new GatewayLimit(1, 5); PresenceUpdate = new GatewayLimit(5, 60); + IdentifyMasterSemaphoreName = Guid.NewGuid().ToString(); IdentifySemaphoreName = Guid.NewGuid().ToString(); + IdentifyMaxConcurrency = 1; } - internal static GatewayLimits GetOrCreate(GatewayLimits limits) - => limits ?? new GatewayLimits(); + internal GatewayLimits(GatewayLimits limits) + { + Global = new GatewayLimit(limits.Global.Count, limits.Global.Seconds); + Identify = new GatewayLimit(limits.Identify.Count, limits.Identify.Seconds); + PresenceUpdate = new GatewayLimit(limits.PresenceUpdate.Count, limits.PresenceUpdate.Seconds); + IdentifyMasterSemaphoreName = limits.IdentifyMasterSemaphoreName; + IdentifySemaphoreName = limits.IdentifySemaphoreName; + IdentifyMaxConcurrency = limits.IdentifyMaxConcurrency; + } + + + internal static GatewayLimits GetOrCreate(GatewayLimits? limits) + => limits ?? Default; + + public GatewayLimits Clone() + => new GatewayLimits(this); } } diff --git a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs index b1f6aae0e..65d652656 100644 --- a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs @@ -13,7 +13,6 @@ namespace Discord.Net.Queue { private static ImmutableDictionary DefsByType; private static ImmutableDictionary DefsById; - private static string IdentifySemaphoreName; static GatewayBucket() { @@ -22,7 +21,6 @@ namespace Discord.Net.Queue public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; public static GatewayBucket Get(string id) => DefsById[id]; - public static string GetIdentifySemaphoreName() => IdentifySemaphoreName; public static void SetLimits(GatewayLimits limits) { @@ -50,8 +48,6 @@ namespace Discord.Net.Queue foreach (var bucket in buckets) builder2.Add(bucket.Id, bucket); DefsById = builder2.ToImmutable(); - - IdentifySemaphoreName = limits.IdentifySemaphoreName; } public GatewayBucketType Type { get; } diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index 507bce80e..f011c7825 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -23,15 +23,16 @@ namespace Discord.Net.Queue private CancellationTokenSource _requestCancelTokenSource; private CancellationToken _requestCancelToken; //Parent token + Clear token private DateTimeOffset _waitUntil; - private Semaphore _identifySemaphore; + + private readonly Semaphore _masterIdentifySemaphore; + private readonly Semaphore _identifySemaphore; + private readonly int _identifySemaphoreMaxConcurrency; private Task _cleanupTask; public RequestQueue() { _tokenLock = new SemaphoreSlim(1, 1); - int semaphoreCount = GatewayBucket.Get(GatewayBucketType.Identify).WindowCount; - _identifySemaphore = new Semaphore(semaphoreCount, semaphoreCount, GatewayBucket.GetIdentifySemaphoreName()); _clearToken = new CancellationTokenSource(); _cancelTokenSource = new CancellationTokenSource(); @@ -43,6 +44,14 @@ namespace Discord.Net.Queue _cleanupTask = RunCleanup(); } + public RequestQueue(string masterIdentifySemaphoreName, string slaveIdentifySemaphoreName, int slaveIdentifySemaphoreMaxConcurrency) + : this () + { + _masterIdentifySemaphore = new Semaphore(1, 1, masterIdentifySemaphoreName); + _identifySemaphore = new Semaphore(0, GatewayBucket.Get(GatewayBucketType.Identify).WindowCount, slaveIdentifySemaphoreName); + _identifySemaphoreMaxConcurrency = slaveIdentifySemaphoreMaxConcurrency; + } + public async Task SetCancelTokenAsync(CancellationToken cancelToken) { await _tokenLock.WaitAsync().ConfigureAwait(false); @@ -132,8 +141,14 @@ namespace Discord.Net.Queue //Identify is per-account so we won't trigger global until we can actually go for it if (requestBucket.Type == GatewayBucketType.Identify) { - while (!_identifySemaphore.WaitOne(0)) //To not block the thread + if (_masterIdentifySemaphore == null || _identifySemaphore == null) + throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); + + bool master; + while (!(master = _masterIdentifySemaphore.WaitOne(0)) && !_identifySemaphore.WaitOne(0)) //To not block the thread await Task.Delay(100, request.CancelToken); + if (master && _identifySemaphoreMaxConcurrency > 1) + _identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1); #if DEBUG_LIMITS Debug.WriteLine($"[{id}] Acquired identify ticket"); #endif @@ -149,7 +164,12 @@ namespace Discord.Net.Queue } internal void ReleaseIdentifySemaphore(int id) { - _identifySemaphore.Release(); + if (_masterIdentifySemaphore == null || _identifySemaphore == null) + throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); + + while (_identifySemaphore.WaitOne(0)) //exhaust all tickets before releasing master + { } + _masterIdentifySemaphore.Release(); #if DEBUG_LIMITS Debug.WriteLine($"[{id}] Released identify ticket"); #endif diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 548bb75bf..24d574bd4 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -80,7 +80,7 @@ namespace Discord.WebSocket internal BaseSocketClient(DiscordSocketConfig config, DiscordRestApiClient client) : base(config, client) => BaseConfig = config; private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, + => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, rateLimitPrecision: config.RateLimitPrecision, useSystemClock: config.UseSystemClock); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index e5d31e5c3..3e6b2f024 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -81,29 +81,35 @@ namespace Discord.WebSocket _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; + if (config.GatewayLimits.IdentifyMaxConcurrency != 1) + newConfig.GatewayLimits.IdentifySemaphoreName += $"_{i / config.GatewayLimits.IdentifyMaxConcurrency}"; _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } } } private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, + => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, rateLimitPrecision: config.RateLimitPrecision); internal override async Task OnLoginAsync(TokenType tokenType, string token) { if (_automaticShards) { - var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false); - _shardIds = Enumerable.Range(0, shardCount).ToArray(); + var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); + _shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); _totalShards = _shardIds.Length; _shards = new DiscordSocketClient[_shardIds.Length]; + int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; + _baseConfig.GatewayLimits.IdentifyMaxConcurrency = maxConcurrency; for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; + if (maxConcurrency != 1) + newConfig.GatewayLimits.IdentifySemaphoreName += $"_{i / maxConcurrency}"; _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i], i == 0); } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 86c297070..a04ec965e 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -3,6 +3,7 @@ using Discord.API.Gateway; using Discord.Net.Queue; using Discord.Net.Rest; using Discord.Net.WebSockets; +using Discord.Rest; using Discord.WebSocket; using Newtonsoft.Json; using System; @@ -37,11 +38,11 @@ namespace Discord.API public ConnectionState ConnectionState { get; private set; } - public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, + public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, GatewayLimits limits, string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) - : base(restClientProvider, userAgent, defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) + : base(restClientProvider, userAgent, new RequestQueue(limits.IdentifyMasterSemaphoreName, limits.IdentifySemaphoreName, limits.IdentifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) { _gatewayUrl = url; if (url != null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index c3979ebb4..ac9d8da8e 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -182,7 +182,7 @@ namespace Discord.WebSocket _largeGuilds = new ConcurrentQueue(); } private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost, + => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, config.GatewayHost, rateLimitPrecision: config.RateLimitPrecision); /// internal override void Dispose(bool disposing) diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 4df080f91..180a8bf49 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -133,7 +133,7 @@ namespace Discord.WebSocket /// This property should only be changed for bots that have special limits provided by Discord. /// /// - public GatewayLimits GatewayLimits { get; set; } = new GatewayLimits(); + public GatewayLimits GatewayLimits { get; set; } = GatewayLimits.Default; /// /// Initializes a default configuration. @@ -144,6 +144,11 @@ namespace Discord.WebSocket UdpSocketProvider = DefaultUdpSocketProvider.Instance; } - internal DiscordSocketConfig Clone() => MemberwiseClone() as DiscordSocketConfig; + internal DiscordSocketConfig Clone() + { + var clone = MemberwiseClone() as DiscordSocketConfig; + clone.GatewayLimits = GatewayLimits.Clone(); + return clone; + } } }