From 5b65b99eb171b70efe11c941bf355d8b54ee5c51 Mon Sep 17 00:00:00 2001 From: Paulo Date: Tue, 17 Nov 2020 22:35:23 -0300 Subject: [PATCH] Fixing merge, ignore limit for heartbeat, and dispose --- .../Net/Queue/GatewayBucket.cs | 4 +- .../Net/Queue/RequestQueue.cs | 8 ++++ .../Net/Queue/RequestQueueBucket.cs | 25 +++++++++- .../Net/Queue/Requests/WebSocketRequest.cs | 4 +- .../DiscordShardedClient.cs | 7 +-- .../DiscordSocketApiClient.cs | 4 +- .../DiscordSocketClient.cs | 48 +++++++------------ 7 files changed, 60 insertions(+), 40 deletions(-) diff --git a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs index e7d41474c..aa849018a 100644 --- a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs @@ -1,4 +1,3 @@ -using Discord.Rest; using System.Collections.Immutable; namespace Discord.Net.Queue @@ -18,7 +17,8 @@ namespace Discord.Net.Queue { var buckets = new[] { - new GatewayBucket(GatewayBucketType.Unbucketed, BucketId.Create(null, "", null), 120, 60), + // Limit is 120/60s, but 3 will be reserved for heartbeats (2 for possible heartbeats in the same timeframe and a possible failure) + new GatewayBucket(GatewayBucketType.Unbucketed, BucketId.Create(null, "", null), 117, 60), new GatewayBucket(GatewayBucketType.Identify, BucketId.Create(null, "", null), 1, 5), new GatewayBucket(GatewayBucketType.PresenceUpdate, BucketId.Create(null, "", null), 5, 60), }; diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs index 541d0bd90..0ecbfc547 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs @@ -206,6 +206,12 @@ namespace Discord.Net.Queue return (null, null); } + public void ClearGatewayBuckets() + { + foreach (var gwBucket in (GatewayBucketType[])Enum.GetValues(typeof(GatewayBucketType))) + _buckets.TryRemove(GatewayBucket.Get(gwBucket).Id, out _); + } + private async Task RunCleanup() { try @@ -236,6 +242,8 @@ namespace Discord.Net.Queue _tokenLock?.Dispose(); _clearToken?.Dispose(); _requestCancelTokenSource?.Dispose(); + _masterIdentifySemaphore?.Dispose(); + _identifySemaphore?.Dispose(); } } } diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs index 008000668..ece18b819 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs @@ -248,8 +248,31 @@ namespace Discord.Net.Queue { if (!isRateLimited) { + bool ignoreRatelimit = false; isRateLimited = true; - await _queue.RaiseRateLimitTriggered(Id, null, $"{request.Method} {request.Endpoint}").ConfigureAwait(false); + switch (request) + { + case RestRequest restRequest: + await _queue.RaiseRateLimitTriggered(Id, null, $"{restRequest.Method} {restRequest.Endpoint}").ConfigureAwait(false); + break; + case WebSocketRequest webSocketRequest: + if (webSocketRequest.IgnoreLimit) + { + ignoreRatelimit = true; + break; + } + await _queue.RaiseRateLimitTriggered(Id, null, Id.Endpoint).ConfigureAwait(false); + break; + default: + throw new InvalidOperationException("Unknown request type"); + } + if (ignoreRatelimit) + { +#if DEBUG_LIMITS + Debug.WriteLine($"[{id}] Ignoring ratelimit"); +#endif + break; + } } ThrowRetryLimit(request); diff --git a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs index a153f7c20..ebebd7bef 100644 --- a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs +++ b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs @@ -11,18 +11,20 @@ namespace Discord.Net.Queue public IWebSocketClient Client { get; } public byte[] Data { get; } public bool IsText { get; } + public bool IgnoreLimit { get; } public DateTimeOffset? TimeoutAt { get; } public TaskCompletionSource Promise { get; } public RequestOptions Options { get; } public CancellationToken CancelToken { get; internal set; } - public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, RequestOptions options) + public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, bool ignoreLimit, RequestOptions options) { Preconditions.NotNull(options, nameof(options)); Client = client; Data = data; IsText = isText; + IgnoreLimit = ignoreLimit; Options = options; TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; Promise = new TaskCompletionSource(); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index e5ca5e428..662f67e40 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -12,7 +12,6 @@ namespace Discord.WebSocket public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient { private readonly DiscordSocketConfig _baseConfig; - private readonly SemaphoreSlim _connectionGroupLock; private readonly Dictionary _shardIdsToIndex; private readonly bool _automaticShards; private int[] _shardIds; @@ -65,7 +64,6 @@ namespace Discord.WebSocket _shardIdsToIndex = new Dictionary(); config.DisplayInitialLog = false; _baseConfig = config; - _connectionGroupLock = new SemaphoreSlim(1, 1); if (config.TotalShards == null) _automaticShards = true; @@ -88,7 +86,7 @@ namespace Discord.WebSocket _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency > 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency); + _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, config.IdentifyMaxConcurrency > 1 ? null : identifySemaphores[i / config.IdentifyMaxConcurrency], config.IdentifyMaxConcurrency); RegisterEvents(_shards[i], i == 0); } } @@ -122,7 +120,7 @@ namespace Discord.WebSocket var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency > 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency); + _shards[i] = new DiscordSocketClient(newConfig, i != 0 ? _shards[0] : null, masterIdentifySemaphore, maxConcurrency > 1 ? null : identifySemaphores[i / maxConcurrency], maxConcurrency); RegisterEvents(_shards[i], i == 0); } } @@ -418,7 +416,6 @@ namespace Discord.WebSocket foreach (var client in _shards) client?.Dispose(); } - _connectionGroupLock?.Dispose(); } _isDisposed = true; diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 4d191544b..47a7def29 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -133,6 +133,8 @@ namespace Discord.API if (WebSocketClient == null) throw new NotSupportedException("This client is not configured with WebSocket support."); + RequestQueue.ClearGatewayBuckets(); + //Re-create streams to reset the zlib state _compressed?.Dispose(); _decompressor?.Dispose(); @@ -210,7 +212,7 @@ namespace Discord.API options.IsGatewayBucket = true; if (options.BucketId == null) options.BucketId = GatewayBucket.Get(GatewayBucketType.Unbucketed).Id; - await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, options)).ConfigureAwait(false); + await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, opCode == GatewayOpCode.Heartbeat, options)).ConfigureAwait(false); await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); } diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 13b99016b..cd83699a7 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -26,7 +26,6 @@ namespace Discord.WebSocket { private readonly ConcurrentQueue _largeGuilds; private readonly JsonSerializer _serializer; - private readonly SemaphoreSlim _connectionGroupLock; private readonly DiscordSocketClient _parentClient; private readonly ConcurrentQueue _heartbeatTimes; private readonly ConnectionManager _connection; @@ -119,10 +118,10 @@ namespace Discord.WebSocket /// /// The configuration to be used with the client. #pragma warning disable IDISP004 - public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config, new SemaphoreSlim(1, 1), null, 1), null, null) { } - internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) : this(config, CreateApiClient(config, identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), groupLock, parentClient) { } + public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config, new SemaphoreSlim(1, 1), null, 1), null) { } + internal DiscordSocketClient(DiscordSocketConfig config, DiscordSocketClient parentClient, SemaphoreSlim identifyMasterSemaphore, SemaphoreSlim identifySemaphore, int identifyMaxConcurrency) : this(config, CreateApiClient(config, identifyMasterSemaphore, identifySemaphore, identifyMaxConcurrency), parentClient) { } #pragma warning restore IDISP004 - private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient) + private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordSocketClient parentClient) : base(config, client) { ShardId = config.ShardId ?? 0; @@ -148,7 +147,6 @@ namespace Discord.WebSocket _connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex); _nextAudioId = 1; - _connectionGroupLock = groupLock; _parentClient = parentClient; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; @@ -230,35 +228,25 @@ namespace Discord.WebSocket private async Task OnConnectingAsync() { - if (_connectionGroupLock != null) - await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false); - try - { - await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); - await ApiClient.ConnectAsync().ConfigureAwait(false); - - if (_sessionId != null) - { - await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); - await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); - } - else - { - await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); - } - - //Wait for READY - await _connection.WaitAsync().ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); + await ApiClient.ConnectAsync().ConfigureAwait(false); - await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); - await SendStatusAsync().ConfigureAwait(false); + if (_sessionId != null) + { + await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); + await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); } - finally + else { - if (_connectionGroupLock != null) - _connectionGroupLock.Release(); + await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false); } + + //Wait for READY + await _connection.WaitAsync().ConfigureAwait(false); + + await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); + await SendStatusAsync().ConfigureAwait(false); } private async Task OnDisconnectingAsync(Exception ex) {