diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index ec2212741..7596890ae 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -23,8 +23,8 @@ namespace Discord.Net.Queue private CancellationToken _requestCancelToken; //Parent token + Clear token private DateTimeOffset _waitUntil; - private readonly Semaphore _masterIdentifySemaphore; - private readonly Semaphore _identifySemaphore; + private readonly SemaphoreSlim _masterIdentifySemaphore; + private readonly SemaphoreSlim _identifySemaphore; private readonly int _identifySemaphoreMaxConcurrency; private Task _cleanupTask; @@ -43,11 +43,11 @@ namespace Discord.Net.Queue _cleanupTask = RunCleanup(); } - public RequestQueue(string masterIdentifySemaphoreName, string slaveIdentifySemaphoreName, int slaveIdentifySemaphoreMaxConcurrency) + public RequestQueue(SemaphoreSlim masterIdentifySemaphore, SemaphoreSlim slaveIdentifySemaphore, int slaveIdentifySemaphoreMaxConcurrency) : this () { - _masterIdentifySemaphore = new Semaphore(1, 1, masterIdentifySemaphoreName); - _identifySemaphore = new Semaphore(0, GatewayBucket.Get(GatewayBucketType.Identify).WindowCount, slaveIdentifySemaphoreName); + _masterIdentifySemaphore = masterIdentifySemaphore; + _identifySemaphore = slaveIdentifySemaphore; _identifySemaphoreMaxConcurrency = slaveIdentifySemaphoreMaxConcurrency; } @@ -140,14 +140,19 @@ namespace Discord.Net.Queue //Identify is per-account so we won't trigger global until we can actually go for it if (requestBucket.Type == GatewayBucketType.Identify) { - if (_masterIdentifySemaphore == null || _identifySemaphore == null) + if (_masterIdentifySemaphore == null) throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); - bool master; - while (!(master = _masterIdentifySemaphore.WaitOne(0)) && !_identifySemaphore.WaitOne(0)) //To not block the thread - await Task.Delay(100, request.CancelToken); - if (master && _identifySemaphoreMaxConcurrency > 1) - _identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1); + if (_identifySemaphore == null) + await _masterIdentifySemaphore.WaitAsync(request.CancelToken); + else + { + bool master; + while (!(master = _masterIdentifySemaphore.Wait(0)) && !_identifySemaphore.Wait(0)) //To not block the thread + await Task.Delay(100, request.CancelToken); + if (master && _identifySemaphoreMaxConcurrency > 1) + _identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1); + } #if DEBUG_LIMITS Debug.WriteLine($"[{id}] Acquired identify ticket"); #endif @@ -163,10 +168,10 @@ namespace Discord.Net.Queue } internal void ReleaseIdentifySemaphore(int id) { - if (_masterIdentifySemaphore == null || _identifySemaphore == null) + if (_masterIdentifySemaphore == null) throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); - while (_identifySemaphore.WaitOne(0)) //exhaust all tickets before releasing master + while (_identifySemaphore?.Wait(0) == true) //exhaust all tickets before releasing master { } _masterIdentifySemaphore.Release(); #if DEBUG_LIMITS diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 652654d73..cceee2557 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; using Discord.API; using Discord.Rest; @@ -79,8 +80,9 @@ namespace Discord.WebSocket internal BaseSocketClient(DiscordSocketConfig config, DiscordRestApiClient client) : base(config, client) => BaseConfig = config; - private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config, + private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) + => new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, + identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency, rateLimitPrecision: config.RateLimitPrecision, useSystemClock: config.UseSystemClock); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index f759457ca..497572fd6 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -74,21 +74,28 @@ namespace Discord.WebSocket _totalShards = config.TotalShards.Value; _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray(); _shards = new DiscordSocketClient[_shardIds.Length]; + var masterIdentifySemaphore = new SemaphoreSlim(1, 1); + SemaphoreSlim[] identifySemaphores = null; + if (config.IdentifyMaxConcurrency > 1) + { + int maxSemaphores = (_shardIds.Length + config.IdentifyMaxConcurrency - 1) / config.IdentifyMaxConcurrency; + identifySemaphores = new SemaphoreSlim[maxSemaphores]; + for (int i = 0; i < maxSemaphores; i++) + identifySemaphores[i] = new SemaphoreSlim(0, config.IdentifyMaxConcurrency); + } for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - if (config.IdentifyMaxConcurrency != 1) - newConfig.IdentifySemaphoreName += $"_{i / config.IdentifyMaxConcurrency}"; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency > 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency); RegisterEvents(_shards[i], i == 0); } } } private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config, - rateLimitPrecision: config.RateLimitPrecision); + => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, + null, null, 0, rateLimitPrecision: config.RateLimitPrecision); internal override async Task OnLoginAsync(TokenType tokenType, string token) { @@ -100,15 +107,22 @@ namespace Discord.WebSocket _shards = new DiscordSocketClient[_shardIds.Length]; int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; _baseConfig.IdentifyMaxConcurrency = maxConcurrency; + var masterIdentifySemaphore = new SemaphoreSlim(1, 1); + SemaphoreSlim[] identifySemaphores = null; + if (maxConcurrency > 1) + { + int maxSemaphores = (_shardIds.Length + maxConcurrency - 1) / maxConcurrency; + identifySemaphores = new SemaphoreSlim[maxSemaphores]; + for (int i = 0; i < maxSemaphores; i++) + identifySemaphores[i] = new SemaphoreSlim(0, maxConcurrency); + } for (int i = 0; i < _shardIds.Length; i++) { _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - if (maxConcurrency != 1) - newConfig.IdentifySemaphoreName += $"_{i / maxConcurrency}"; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency > 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency); RegisterEvents(_shards[i], i == 0); } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index afe477c7a..01c2d7f2a 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -37,11 +37,12 @@ namespace Discord.API public ConnectionState ConnectionState { get; private set; } - public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, DiscordSocketConfig config, + public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, + SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency, string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) - : base(restClientProvider, userAgent, new RequestQueue(config.IdentifyMasterSemaphoreName, config.IdentifySemaphoreName, config.IdentifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) + : base(restClientProvider, userAgent, new RequestQueue(identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) { _gatewayUrl = url; if (url != null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 9c219bcb6..4dce80399 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -118,8 +118,8 @@ namespace Discord.WebSocket /// /// The configuration to be used with the client. #pragma warning disable IDISP004 - public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } - internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { } + public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config, new SemaphoreSlim(1, 1), null, 1), null, null) { } + internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) : this(config, CreateApiClient(config, identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), groupLock, parentClient) { } #pragma warning restore IDISP004 private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : base(config, client) @@ -177,8 +177,9 @@ namespace Discord.WebSocket _voiceRegions = ImmutableDictionary.Create(); _largeGuilds = new ConcurrentQueue(); } - private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) - => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config, config.GatewayHost, + private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) + => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, + identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency, config.GatewayHost, rateLimitPrecision: config.RateLimitPrecision); /// internal override void Dispose(bool disposing) diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index f59761c26..ad8b06066 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -125,26 +125,6 @@ namespace Discord.WebSocket /// public bool GuildSubscriptions { get; set; } = true; - /// - /// Gets or sets the name of the master - /// used by identify. - /// - /// - /// It is used to define what slave - /// is free to run for concurrent identify requests. - /// - public string IdentifyMasterSemaphoreName { get; set; } = Guid.NewGuid().ToString(); - - /// - /// Gets or sets the name of the slave - /// used by identify. - /// - /// - /// If the maximum concurrency is higher than one and you are using the sharded client, - /// it will be dinamilly renamed to fit the necessary needs. - /// - public string IdentifySemaphoreName { get; set; } = Guid.NewGuid().ToString(); - /// /// Gets or sets the maximum identify concurrency. ///