From 66097b3fd77a163bf99df62fcc9984214d598bc7 Mon Sep 17 00:00:00 2001 From: RogueException Date: Thu, 12 May 2016 09:25:38 -0300 Subject: [PATCH] Added MessageQueue --- src/Discord.Net/API/DiscordRawClient.cs | 198 ++++++++-------- src/Discord.Net/Discord.Net.csproj | 9 +- src/Discord.Net/Logging/ILogger.cs | 1 + src/Discord.Net/Logging/LogManager.cs | 40 ++-- src/Discord.Net/Net/RateLimitException.cs | 15 ++ src/Discord.Net/Net/Rest/DefaultRestClient.cs | 5 + src/Discord.Net/Net/Rest/IMessageQueue.cs | 7 - .../Net/Rest/RequestQueue/BucketGroup.cs | 8 + .../Net/Rest/RequestQueue/GlobalBucket.cs | 8 + .../Net/Rest/RequestQueue/GuildBucket.cs | 10 + .../Net/Rest/RequestQueue/IRequestQueue.cs | 10 + .../Net/Rest/RequestQueue/RequestQueue.cs | 163 +++++++++++++ .../Rest/RequestQueue/RequestQueueBucket.cs | 223 ++++++++++++++++++ .../Net/Rest/RequestQueue/RestRequest.cs | 38 +++ src/Discord.Net/Rest/DiscordClient.cs | 2 +- .../Rest/Entities/Channels/TextChannel.cs | 8 +- src/Discord.Net/Rest/Entities/Message.cs | 8 +- 17 files changed, 618 insertions(+), 135 deletions(-) create mode 100644 src/Discord.Net/Net/RateLimitException.cs delete mode 100644 src/Discord.Net/Net/Rest/IMessageQueue.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs create mode 100644 src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs diff --git a/src/Discord.Net/API/DiscordRawClient.cs b/src/Discord.Net/API/DiscordRawClient.cs index 575ece0c8..49cc122f4 100644 --- a/src/Discord.Net/API/DiscordRawClient.cs +++ b/src/Discord.Net/API/DiscordRawClient.cs @@ -21,6 +21,7 @@ namespace Discord.API { internal event EventHandler SentRequest; + private readonly RequestQueue _requestQueue; private readonly IRestClient _restClient; private readonly CancellationToken _cancelToken; private readonly JsonSerializer _serializer; @@ -46,6 +47,7 @@ namespace Discord.API _restClient = restClientProvider(DiscordConfig.ClientAPIUrl, cancelToken); _restClient.SetHeader("authorization", authToken); _restClient.SetHeader("user-agent", DiscordConfig.UserAgent); + _requestQueue = new RequestQueue(_restClient); _serializer = new JsonSerializer(); _serializer.Converters.Add(new ChannelTypeConverter()); @@ -60,113 +62,73 @@ namespace Discord.API } //Core - public async Task Send(string method, string endpoint) + public Task Send(string method, string endpoint, GlobalBucket bucket = GlobalBucket.General) + => SendInternal(method, endpoint, null, bucket); + public Task Send(string method, string endpoint, object payload, GlobalBucket bucket = GlobalBucket.General) + => SendInternal(method, endpoint, payload, bucket); + public Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs, GlobalBucket bucket = GlobalBucket.General) + => SendInternal(method, endpoint, multipartArgs, bucket); + public async Task Send(string method, string endpoint, GlobalBucket bucket = GlobalBucket.General) where TResponse : class - { - var stopwatch = Stopwatch.StartNew(); - Stream responseStream; - try - { - responseStream = await _restClient.Send(method, endpoint, (string)null).ConfigureAwait(false); - } - catch (HttpException ex) - { - if (!HandleException(ex)) - throw; - return null; - } - int bytes = (int)responseStream.Length; - stopwatch.Stop(); - var response = Deserialize(responseStream); - - double milliseconds = ToMilliseconds(stopwatch); - SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds)); + => Deserialize(await SendInternal(method, endpoint, null, bucket).ConfigureAwait(false)); + public async Task Send(string method, string endpoint, object payload, GlobalBucket bucket = GlobalBucket.General) + where TResponse : class + => Deserialize(await SendInternal(method, endpoint, payload, bucket).ConfigureAwait(false)); + public async Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs, GlobalBucket bucket = GlobalBucket.General) + where TResponse : class + => Deserialize(await SendInternal(method, endpoint, multipartArgs, bucket).ConfigureAwait(false)); + + public Task Send(string method, string endpoint, GuildBucket bucket, ulong guildId) + => SendInternal(method, endpoint, null, bucket, guildId); + public Task Send(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId) + => SendInternal(method, endpoint, payload, bucket, guildId); + public Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs, GuildBucket bucket, ulong guildId) + => SendInternal(method, endpoint, multipartArgs, bucket, guildId); + public async Task Send(string method, string endpoint, GuildBucket bucket, ulong guildId) + where TResponse : class + => Deserialize(await SendInternal(method, endpoint, null, bucket, guildId).ConfigureAwait(false)); + public async Task Send(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId) + where TResponse : class + => Deserialize(await SendInternal(method, endpoint, payload, bucket, guildId).ConfigureAwait(false)); + public async Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs, GuildBucket bucket, ulong guildId) + where TResponse : class + => Deserialize(await SendInternal(method, endpoint, multipartArgs, bucket, guildId).ConfigureAwait(false)); - return response; - } - public async Task Send(string method, string endpoint) - { - var stopwatch = Stopwatch.StartNew(); - try - { - await _restClient.Send(method, endpoint, (string)null).ConfigureAwait(false); - } - catch (HttpException ex) - { - if (!HandleException(ex)) - throw; - return; - } - stopwatch.Stop(); + private Task SendInternal(string method, string endpoint, object payload, GlobalBucket bucket) + => SendInternal(method, endpoint, payload, BucketGroup.Global, (int)bucket, 0); + private Task SendInternal(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId) + => SendInternal(method, endpoint, payload, BucketGroup.Guild, (int)bucket, guildId); + private Task SendInternal(string method, string endpoint, IReadOnlyDictionary multipartArgs, GlobalBucket bucket) + => SendInternal(method, endpoint, multipartArgs, BucketGroup.Global, (int)bucket, 0); + private Task SendInternal(string method, string endpoint, IReadOnlyDictionary multipartArgs, GuildBucket bucket, ulong guildId) + => SendInternal(method, endpoint, multipartArgs, BucketGroup.Guild, (int)bucket, guildId); - double milliseconds = ToMilliseconds(stopwatch); - SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds)); - } - public async Task Send(string method, string endpoint, object payload) - where TResponse : class + private async Task SendInternal(string method, string endpoint, object payload, BucketGroup group, int bucketId, ulong guildId) { - string requestStream = Serialize(payload); var stopwatch = Stopwatch.StartNew(); - Stream responseStream; - try - { - responseStream = await _restClient.Send(method, endpoint, requestStream).ConfigureAwait(false); - } - catch (HttpException ex) - { - if (!HandleException(ex)) - throw; - return null; - } + string json = null; + if (payload != null) + json = Serialize(payload); + var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, json), group, bucketId, guildId).ConfigureAwait(false); int bytes = (int)responseStream.Length; stopwatch.Stop(); - var response = Deserialize(responseStream); double milliseconds = ToMilliseconds(stopwatch); SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds)); - return response; + return responseStream; } - public async Task Send(string method, string endpoint, object payload) + private async Task SendInternal(string method, string endpoint, IReadOnlyDictionary multipartArgs, BucketGroup group, int bucketId, ulong guildId) { - string requestStream = Serialize(payload); var stopwatch = Stopwatch.StartNew(); - try - { - await _restClient.Send(method, endpoint, requestStream).ConfigureAwait(false); - } - catch (HttpException ex) - { - if (!HandleException(ex)) - throw; - return; - } - stopwatch.Stop(); - - double milliseconds = ToMilliseconds(stopwatch); - SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds)); - } - public async Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs) - where TResponse : class - { - var stopwatch = Stopwatch.StartNew(); - var responseStream = await _restClient.Send(method, endpoint).ConfigureAwait(false); + var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, multipartArgs), group, bucketId, guildId).ConfigureAwait(false); + int bytes = (int)responseStream.Length; stopwatch.Stop(); - var response = Deserialize(responseStream); double milliseconds = ToMilliseconds(stopwatch); - SentRequest(this, new SentRequestEventArgs(method, endpoint, (int)responseStream.Length, milliseconds)); - - return response; - } - public async Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary multipartArgs) - { - var stopwatch = Stopwatch.StartNew(); - await _restClient.Send(method, endpoint).ConfigureAwait(false); - stopwatch.Stop(); + SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds)); - double milliseconds = ToMilliseconds(stopwatch); - SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds)); + return responseStream; } //Gateway @@ -623,29 +585,50 @@ namespace Discord.API else return result[0]; } - public async Task CreateMessage(ulong channelId, CreateMessageParams args) + public Task CreateMessage(ulong channelId, CreateMessageParams args) + => CreateMessage(0, channelId, args); + public async Task CreateMessage(ulong guildId, ulong channelId, CreateMessageParams args) { if (args == null) throw new ArgumentNullException(nameof(args)); if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId)); - return await Send("POST", $"channels/{channelId}/messages", args).ConfigureAwait(false); + if (guildId != 0) + return await Send("POST", $"channels/{channelId}/messages", args, GuildBucket.SendEditMessage, guildId).ConfigureAwait(false); + else + return await Send("POST", $"channels/{channelId}/messages", args, GlobalBucket.DirectMessage).ConfigureAwait(false); } - public async Task UploadFile(ulong channelId, Stream file, UploadFileParams args) + public Task UploadFile(ulong channelId, Stream file, UploadFileParams args) + => UploadFile(0, channelId, file, args); + public async Task UploadFile(ulong guildId, ulong channelId, Stream file, UploadFileParams args) { if (args == null) throw new ArgumentNullException(nameof(args)); + //if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId)); if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId)); - return await Send("POST", $"channels/{channelId}/messages", file, args.ToDictionary()).ConfigureAwait(false); + if (guildId != 0) + return await Send("POST", $"channels/{channelId}/messages", file, args.ToDictionary(), GuildBucket.SendEditMessage, guildId).ConfigureAwait(false); + else + return await Send("POST", $"channels/{channelId}/messages", file, args.ToDictionary()).ConfigureAwait(false); } - public async Task DeleteMessage(ulong channelId, ulong messageId) + public Task DeleteMessage(ulong channelId, ulong messageId) + => DeleteMessage(0, channelId, messageId); + public async Task DeleteMessage(ulong guildId, ulong channelId, ulong messageId) { + //if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId)); if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId)); if (messageId == 0) throw new ArgumentOutOfRangeException(nameof(messageId)); - await Send("DELETE", $"channels/{channelId}/messages/{messageId}").ConfigureAwait(false); + if (guildId != 0) + await Send("DELETE", $"channels/{channelId}/messages/{messageId}", GuildBucket.DeleteMessage, guildId).ConfigureAwait(false); + else + await Send("DELETE", $"channels/{channelId}/messages/{messageId}").ConfigureAwait(false); } - public async Task DeleteMessages(ulong channelId, DeleteMessagesParam args) + public Task DeleteMessages(ulong channelId, DeleteMessagesParam args) + => DeleteMessages(0, channelId, args); + public async Task DeleteMessages(ulong guildId, ulong channelId, DeleteMessagesParam args) { + //if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId)); + if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId)); if (args == null) throw new ArgumentNullException(nameof(args)); if (args.MessageIds == null) throw new ArgumentNullException(nameof(args.MessageIds)); @@ -655,20 +638,29 @@ namespace Discord.API case 0: throw new ArgumentOutOfRangeException(nameof(args.MessageIds)); case 1: - await DeleteMessage(channelId, messageIds[0]).ConfigureAwait(false); + await DeleteMessage(guildId, channelId, messageIds[0]).ConfigureAwait(false); break; default: - await Send("POST", $"channels/{channelId}/messages/bulk_delete", args).ConfigureAwait(false); + if (guildId != 0) + await Send("POST", $"channels/{channelId}/messages/bulk_delete", args, GuildBucket.DeleteMessages, guildId).ConfigureAwait(false); + else + await Send("POST", $"channels/{channelId}/messages/bulk_delete", args).ConfigureAwait(false); break; } } - public async Task ModifyMessage(ulong channelId, ulong messageId, ModifyMessageParams args) + public Task ModifyMessage(ulong channelId, ulong messageId, ModifyMessageParams args) + => ModifyMessage(0, channelId, messageId, args); + public async Task ModifyMessage(ulong guildId, ulong channelId, ulong messageId, ModifyMessageParams args) { if (args == null) throw new ArgumentNullException(nameof(args)); + //if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId)); if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId)); if (messageId == 0) throw new ArgumentOutOfRangeException(nameof(messageId)); - return await Send("PATCH", $"channels/{channelId}/messages/{messageId}", args).ConfigureAwait(false); + if (guildId != 0) + return await Send("PATCH", $"channels/{channelId}/messages/{messageId}", args, GuildBucket.SendEditMessage, guildId).ConfigureAwait(false); + else + return await Send("PATCH", $"channels/{channelId}/messages/{messageId}", args).ConfigureAwait(false); } public async Task AckMessage(ulong channelId, ulong messageId) { @@ -775,11 +767,5 @@ namespace Discord.API using (JsonReader reader = new JsonTextReader(text)) return _serializer.Deserialize(reader); } - - private bool HandleException(Exception ex) - { - //TODO: Implement... maybe via SentRequest? Need to bubble this up to DiscordClient or a MessageQueue - return false; - } } } diff --git a/src/Discord.Net/Discord.Net.csproj b/src/Discord.Net/Discord.Net.csproj index a178e7f8e..f7c9e7a8c 100644 --- a/src/Discord.Net/Discord.Net.csproj +++ b/src/Discord.Net/Discord.Net.csproj @@ -95,6 +95,13 @@ + + + + + + + @@ -181,7 +188,7 @@ - + diff --git a/src/Discord.Net/Logging/ILogger.cs b/src/Discord.Net/Logging/ILogger.cs index 787965786..f8679d0ec 100644 --- a/src/Discord.Net/Logging/ILogger.cs +++ b/src/Discord.Net/Logging/ILogger.cs @@ -8,6 +8,7 @@ namespace Discord.Logging void Log(LogSeverity severity, string message, Exception exception = null); void Log(LogSeverity severity, FormattableString message, Exception exception = null); + void Log(LogSeverity severity, Exception exception); void Error(string message, Exception exception = null); void Error(FormattableString message, Exception exception = null); diff --git a/src/Discord.Net/Logging/LogManager.cs b/src/Discord.Net/Logging/LogManager.cs index 5b1e4d14b..0c183071d 100644 --- a/src/Discord.Net/Logging/LogManager.cs +++ b/src/Discord.Net/Logging/LogManager.cs @@ -23,6 +23,11 @@ namespace Discord.Logging if (severity <= Level) Message(this, new LogMessageEventArgs(severity, source, message.ToString(), ex)); } + public void Log(LogSeverity severity, string source, Exception ex) + { + if (severity <= Level) + Message(this, new LogMessageEventArgs(severity, source, null, ex)); + } void ILogger.Log(LogSeverity severity, string message, Exception ex) { if (severity <= Level) @@ -33,71 +38,76 @@ namespace Discord.Logging if (severity <= Level) Message(this, new LogMessageEventArgs(severity, "Discord", message.ToString(), ex)); } + void ILogger.Log(LogSeverity severity, Exception ex) + { + if (severity <= Level) + Message(this, new LogMessageEventArgs(severity, "Discord", null, ex)); + } public void Error(string source, string message, Exception ex = null) => Log(LogSeverity.Error, source, message, ex); public void Error(string source, FormattableString message, Exception ex = null) => Log(LogSeverity.Error, source, message, ex); - public void Error(string source, Exception ex = null) - => Log(LogSeverity.Error, source, (string)null, ex); + public void Error(string source, Exception ex) + => Log(LogSeverity.Error, source, ex); void ILogger.Error(string message, Exception ex) => Log(LogSeverity.Error, "Discord", message, ex); void ILogger.Error(FormattableString message, Exception ex) => Log(LogSeverity.Error, "Discord", message, ex); void ILogger.Error(Exception ex) - => Log(LogSeverity.Error, "Discord", (string)null, ex); + => Log(LogSeverity.Error, "Discord", ex); public void Warning(string source, string message, Exception ex = null) => Log(LogSeverity.Warning, source, message, ex); public void Warning(string source, FormattableString message, Exception ex = null) => Log(LogSeverity.Warning, source, message, ex); - public void Warning(string source, Exception ex = null) - => Log(LogSeverity.Warning, source, (string)null, ex); + public void Warning(string source, Exception ex) + => Log(LogSeverity.Warning, source, ex); void ILogger.Warning(string message, Exception ex) => Log(LogSeverity.Warning, "Discord", message, ex); void ILogger.Warning(FormattableString message, Exception ex) => Log(LogSeverity.Warning, "Discord", message, ex); void ILogger.Warning(Exception ex) - => Log(LogSeverity.Warning, "Discord", (string)null, ex); + => Log(LogSeverity.Warning, "Discord", ex); public void Info(string source, string message, Exception ex = null) => Log(LogSeverity.Info, source, message, ex); public void Info(string source, FormattableString message, Exception ex = null) => Log(LogSeverity.Info, source, message, ex); - public void Info(string source, Exception ex = null) - => Log(LogSeverity.Info, source, (string)null, ex); + public void Info(string source, Exception ex) + => Log(LogSeverity.Info, source, ex); void ILogger.Info(string message, Exception ex) => Log(LogSeverity.Info, "Discord", message, ex); void ILogger.Info(FormattableString message, Exception ex) => Log(LogSeverity.Info, "Discord", message, ex); void ILogger.Info(Exception ex) - => Log(LogSeverity.Info, "Discord", (string)null, ex); + => Log(LogSeverity.Info, "Discord", ex); public void Verbose(string source, string message, Exception ex = null) => Log(LogSeverity.Verbose, source, message, ex); public void Verbose(string source, FormattableString message, Exception ex = null) => Log(LogSeverity.Verbose, source, message, ex); - public void Verbose(string source, Exception ex = null) - => Log(LogSeverity.Verbose, source, (string)null, ex); + public void Verbose(string source, Exception ex) + => Log(LogSeverity.Verbose, source, ex); void ILogger.Verbose(string message, Exception ex) => Log(LogSeverity.Verbose, "Discord", message, ex); void ILogger.Verbose(FormattableString message, Exception ex) => Log(LogSeverity.Verbose, "Discord", message, ex); void ILogger.Verbose(Exception ex) - => Log(LogSeverity.Verbose, "Discord", (string)null, ex); + => Log(LogSeverity.Verbose, "Discord", ex); public void Debug(string source, string message, Exception ex = null) => Log(LogSeverity.Debug, source, message, ex); public void Debug(string source, FormattableString message, Exception ex = null) => Log(LogSeverity.Debug, source, message, ex); - public void Debug(string source, Exception ex = null) - => Log(LogSeverity.Debug, source, (string)null, ex); + public void Debug(string source, Exception ex) + => Log(LogSeverity.Debug, source, ex); void ILogger.Debug(string message, Exception ex) => Log(LogSeverity.Debug, "Discord", message, ex); void ILogger.Debug(FormattableString message, Exception ex) => Log(LogSeverity.Debug, "Discord", message, ex); void ILogger.Debug(Exception ex) - => Log(LogSeverity.Debug, "Discord", (string)null, ex); + => Log(LogSeverity.Debug, "Discord", ex); internal Logger CreateLogger(string name) => new Logger(this, name); } diff --git a/src/Discord.Net/Net/RateLimitException.cs b/src/Discord.Net/Net/RateLimitException.cs new file mode 100644 index 000000000..a07e90760 --- /dev/null +++ b/src/Discord.Net/Net/RateLimitException.cs @@ -0,0 +1,15 @@ +using System.Net; + +namespace Discord.Net +{ + public class HttpRateLimitException : HttpException + { + public int RetryAfterMilliseconds { get; } + + public HttpRateLimitException(int retryAfterMilliseconds) + : base((HttpStatusCode)429) + { + RetryAfterMilliseconds = retryAfterMilliseconds; + } + } +} diff --git a/src/Discord.Net/Net/Rest/DefaultRestClient.cs b/src/Discord.Net/Net/Rest/DefaultRestClient.cs index 3c72f7258..0aec73597 100644 --- a/src/Discord.Net/Net/Rest/DefaultRestClient.cs +++ b/src/Discord.Net/Net/Rest/DefaultRestClient.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Globalization; using System.IO; +using System.Linq; using System.Net; using System.Net.Http; using System.Text; @@ -118,7 +119,11 @@ namespace Discord.Net.Rest int statusCode = (int)response.StatusCode; if (statusCode < 200 || statusCode >= 300) //2xx = Success + { + if (statusCode == 429) + throw new HttpRateLimitException(int.Parse(response.Headers.GetValues("retry-after").First())); throw new HttpException(response.StatusCode); + } return await response.Content.ReadAsStreamAsync().ConfigureAwait(false); } diff --git a/src/Discord.Net/Net/Rest/IMessageQueue.cs b/src/Discord.Net/Net/Rest/IMessageQueue.cs deleted file mode 100644 index a61131ed8..000000000 --- a/src/Discord.Net/Net/Rest/IMessageQueue.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Discord.Net.Rest -{ - public interface IMessageQueue - { - int Count { get; } - } -} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs b/src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs new file mode 100644 index 000000000..54c3e717d --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs @@ -0,0 +1,8 @@ +namespace Discord.Net.Rest +{ + internal enum BucketGroup + { + Global, + Guild + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs b/src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs new file mode 100644 index 000000000..4e7126f5e --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs @@ -0,0 +1,8 @@ +namespace Discord.Net.Rest +{ + public enum GlobalBucket + { + General, + DirectMessage + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs b/src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs new file mode 100644 index 000000000..ccb3fa994 --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs @@ -0,0 +1,10 @@ +namespace Discord.Net.Rest +{ + public enum GuildBucket + { + SendEditMessage, + DeleteMessage, + DeleteMessages, + Nickname + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs b/src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs new file mode 100644 index 000000000..27231a334 --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs @@ -0,0 +1,10 @@ +using System.Threading.Tasks; + +namespace Discord.Net.Rest +{ + public interface IRequestQueue + { + Task Clear(GlobalBucket type); + Task Clear(GuildBucket type, ulong guildId); + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs b/src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs new file mode 100644 index 000000000..155b683e7 --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs @@ -0,0 +1,163 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Discord.Net.Rest +{ + public class RequestQueue : IRequestQueue + { + private SemaphoreSlim _lock; + private RequestQueueBucket[] _globalBuckets; + private Dictionary[] _guildBuckets; + + public IRestClient RestClient { get; } + + public RequestQueue(IRestClient restClient) + { + RestClient = restClient; + + _lock = new SemaphoreSlim(1, 1); + _globalBuckets = new RequestQueueBucket[Enum.GetValues(typeof(GlobalBucket)).Length]; + _guildBuckets = new Dictionary[Enum.GetValues(typeof(GuildBucket)).Length]; + } + + internal async Task Send(RestRequest request, BucketGroup group, int bucketId, ulong guildId) + { + RequestQueueBucket bucket; + + await Lock().ConfigureAwait(false); + try + { + bucket = GetBucket(group, bucketId, guildId); + bucket.Queue(request); + } + finally { Unlock(); } + + //There is a chance the bucket will send this request on its own, but this will simply become a noop then. + var _ = bucket.ProcessQueue(acquireLock: true).ConfigureAwait(false); + + return await request.Promise.Task.ConfigureAwait(false); + } + + private RequestQueueBucket CreateBucket(GlobalBucket bucket) + { + switch (bucket) + { + //Globals + case GlobalBucket.General: return new RequestQueueBucket(this, bucket, int.MaxValue, 0); //Catch-all + case GlobalBucket.DirectMessage: return new RequestQueueBucket(this, bucket, 5, 5); + + default: throw new ArgumentException($"Unknown global bucket: {bucket}", nameof(bucket)); + } + } + private RequestQueueBucket CreateBucket(GuildBucket bucket, ulong guildId) + { + switch (bucket) + { + //Per Guild + case GuildBucket.SendEditMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 5); + case GuildBucket.DeleteMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 1); + case GuildBucket.DeleteMessages: return new RequestQueueBucket(this, bucket, guildId, 1, 1); + case GuildBucket.Nickname: return new RequestQueueBucket(this, bucket, guildId, 1, 1); + + default: throw new ArgumentException($"Unknown guild bucket: {bucket}", nameof(bucket)); + } + } + + private RequestQueueBucket GetBucket(BucketGroup group, int bucketId, ulong guildId) + { + switch (group) + { + case BucketGroup.Global: + return GetGlobalBucket((GlobalBucket)bucketId); + case BucketGroup.Guild: + return GetGuildBucket((GuildBucket)bucketId, guildId); + default: + throw new ArgumentException($"Unknown bucket group: {group}", nameof(group)); + } + } + private RequestQueueBucket GetGlobalBucket(GlobalBucket type) + { + var bucket = _globalBuckets[(int)type]; + if (bucket == null) + { + bucket = CreateBucket(type); + _globalBuckets[(int)type] = bucket; + } + return bucket; + } + private RequestQueueBucket GetGuildBucket(GuildBucket type, ulong guildId) + { + var bucketGroup = _guildBuckets[(int)type]; + if (bucketGroup == null) + { + bucketGroup = new Dictionary(); + _guildBuckets[(int)type] = bucketGroup; + } + RequestQueueBucket bucket; + if (!bucketGroup.TryGetValue(guildId, out bucket)) + { + bucket = CreateBucket(type, guildId); + bucketGroup[guildId] = bucket; + } + return bucket; + } + + internal void DestroyGlobalBucket(GlobalBucket type) + { + //Assume this object is locked + + _globalBuckets[(int)type] = null; + } + internal void DestroyGuildBucket(GuildBucket type, ulong guildId) + { + //Assume this object is locked + + var bucketGroup = _guildBuckets[(int)type]; + if (bucketGroup != null) + bucketGroup.Remove(guildId); + } + + public async Task Lock() + { + await _lock.WaitAsync(); + } + public void Unlock() + { + _lock.Release(); + } + + public async Task Clear(GlobalBucket type) + { + var bucket = _globalBuckets[(int)type]; + if (bucket != null) + { + try + { + await bucket.Lock(); + bucket.Clear(); + } + finally { bucket.Unlock(); } + } + } + public async Task Clear(GuildBucket type, ulong guildId) + { + var bucketGroup = _guildBuckets[(int)type]; + if (bucketGroup != null) + { + RequestQueueBucket bucket; + if (bucketGroup.TryGetValue(guildId, out bucket)) + { + try + { + await bucket.Lock(); + bucket.Clear(); + } + finally { bucket.Unlock(); } + } + } + } + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs b/src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs new file mode 100644 index 000000000..a3788a5b0 --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs @@ -0,0 +1,223 @@ +using System; +using System.Collections.Concurrent; +using System.IO; +using System.Net; +using System.Threading; +using System.Threading.Tasks; + +namespace Discord.Net.Rest +{ + internal class RequestQueueBucket + { + private readonly RequestQueue _parent; + private readonly BucketGroup _bucketGroup; + private readonly int _bucketId; + private readonly ulong _guildId; + private readonly ConcurrentQueue _queue; + private readonly SemaphoreSlim _lock; + private Task _resetTask; + private DateTime? _retryAfter; + private bool _waitingToProcess; + + public int WindowMaxCount { get; } + public int WindowSeconds { get; } + public int WindowCount { get; private set; } + + public RequestQueueBucket(RequestQueue parent, GlobalBucket bucket, int windowMaxCount, int windowSeconds) + : this(parent, windowMaxCount, windowSeconds) + { + _bucketGroup = BucketGroup.Global; + _bucketId = (int)bucket; + _guildId = 0; + } + public RequestQueueBucket(RequestQueue parent, GuildBucket bucket, ulong guildId, int windowMaxCount, int windowSeconds) + : this(parent, windowMaxCount, windowSeconds) + { + _bucketGroup = BucketGroup.Guild; + _bucketId = (int)bucket; + _guildId = guildId; + } + private RequestQueueBucket(RequestQueue parent, int windowMaxCount, int windowSeconds) + { + _parent = parent; + WindowMaxCount = windowMaxCount; + WindowSeconds = windowSeconds; + _queue = new ConcurrentQueue(); + _lock = new SemaphoreSlim(1, 1); + } + + public void Queue(RestRequest request) + { + //Assume this obj's parent is under lock + + _queue.Enqueue(request); + Debug($"Request queued ({WindowCount}/{WindowMaxCount} + {_queue.Count})"); + } + public async Task ProcessQueue(bool acquireLock = false) + { + //Assume this obj is under lock + + int nextRetry = 1000; + + //If we have another ProcessQueue waiting to run, dont bother with this one + if (_waitingToProcess) return; + _waitingToProcess = true; + + if (acquireLock) + await Lock().ConfigureAwait(false); + try + { + _waitingToProcess = false; + while (true) + { + RestRequest request; + + //If we're waiting to reset (due to a rate limit exception, or preemptive check), abort + if (WindowCount == WindowMaxCount) return; + //Get next request, return if queue is empty + if (!_queue.TryPeek(out request)) return; + + try + { + Stream stream; + if (request.IsMultipart) + stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.MultipartParams).ConfigureAwait(false); + else + stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.Json).ConfigureAwait(false); + request.Promise.SetResult(stream); + } + catch (HttpRateLimitException ex) //Preemptive check failed, use Discord's time instead of our own + { + if (_resetTask == null) + { + //No reset has been queued yet, lets create one as if this *was* preemptive + _resetTask = ResetAfter(ex.RetryAfterMilliseconds); + Debug($"External rate limit: Reset in {ex.RetryAfterMilliseconds} ms"); + } + else + { + //A preemptive reset is already queued, set RetryAfter to extend it + _retryAfter = DateTime.UtcNow.AddMilliseconds(ex.RetryAfterMilliseconds); + Debug($"External rate limit: Extended to {ex.RetryAfterMilliseconds} ms"); + } + return; + } + catch (HttpException ex) + { + if (ex.StatusCode == HttpStatusCode.BadGateway) //Gateway unavailable, retry + { + await Task.Delay(nextRetry).ConfigureAwait(false); + nextRetry *= 2; + if (nextRetry > 30000) + nextRetry = 30000; + continue; + } + else + { + //We dont need to throw this here, pass the exception via the promise + request.Promise.SetException(ex); + } + } + + //Request completed or had an error other than 429 + _queue.TryDequeue(out request); + WindowCount++; + nextRetry = 1000; + Debug($"Request succeeded ({WindowCount}/{WindowMaxCount} + {_queue.Count})"); + + if (WindowCount == 1 && WindowSeconds > 0) + { + //First request for this window, schedule a reset + _resetTask = ResetAfter(WindowSeconds * 1000); + Debug($"Internal rate limit: Reset in {WindowSeconds * 1000} ms"); + } + } + } + finally + { + if (acquireLock) + Unlock(); + } + } + public void Clear() + { + //Assume this obj is under lock + RestRequest request; + + while (_queue.TryDequeue(out request)) { } + } + + private async Task ResetAfter(int milliseconds) + { + if (milliseconds > 0) + await Task.Delay(milliseconds).ConfigureAwait(false); + try + { + await Lock().ConfigureAwait(false); + + //If an extension has been planned, start a new wait task + if (_retryAfter != null) + { + _resetTask = ResetAfter((int)(_retryAfter.Value - DateTime.UtcNow).TotalMilliseconds); + _retryAfter = null; + return; + } + + Debug($"Reset"); + //Reset the current window count and set our state back to normal + WindowCount = 0; + _resetTask = null; + + //Wait is over, work through the current queue + await ProcessQueue().ConfigureAwait(false); + + //If queue is empty and non-global, remove this bucket + if (_bucketGroup == BucketGroup.Guild && _queue.IsEmpty) + { + try + { + await _parent.Lock().ConfigureAwait(false); + if (_queue.IsEmpty) //Double check, in case a request was queued before we got both locks + _parent.DestroyGuildBucket((GuildBucket)_bucketId, _guildId); + } + finally + { + _parent.Unlock(); + } + } + } + finally + { + Unlock(); + } + } + + public async Task Lock() + { + await _lock.WaitAsync(); + } + public void Unlock() + { + _lock.Release(); + } + + //TODO: Remove + private void Debug(string text) + { + string name; + switch (_bucketGroup) + { + case BucketGroup.Global: + name = ((GlobalBucket)_bucketId).ToString(); + break; + case BucketGroup.Guild: + name = ((GuildBucket)_bucketId).ToString(); + break; + default: + name = "Unknown"; + break; + } + System.Diagnostics.Debug.WriteLine($"[{name}] {text}"); + } + } +} diff --git a/src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs b/src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs new file mode 100644 index 000000000..86c7ca962 --- /dev/null +++ b/src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs @@ -0,0 +1,38 @@ +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; + +namespace Discord.Net.Rest +{ + internal struct RestRequest + { + public string Method { get; } + public string Endpoint { get; } + public string Json { get; } + public IReadOnlyDictionary MultipartParams { get; } + public TaskCompletionSource Promise { get; } + + public bool IsMultipart => MultipartParams != null; + + public RestRequest(string method, string endpoint, string json) + : this(method, endpoint) + { + Json = json; + } + + public RestRequest(string method, string endpoint, IReadOnlyDictionary multipartParams) + : this(method, endpoint) + { + MultipartParams = multipartParams; + } + + private RestRequest(string method, string endpoint) + { + Method = method; + Endpoint = endpoint; + Json = null; + MultipartParams = null; + Promise = new TaskCompletionSource(); + } + } +} diff --git a/src/Discord.Net/Rest/DiscordClient.cs b/src/Discord.Net/Rest/DiscordClient.cs index 5b0e0d119..ac1bc0762 100644 --- a/src/Discord.Net/Rest/DiscordClient.cs +++ b/src/Discord.Net/Rest/DiscordClient.cs @@ -60,7 +60,7 @@ namespace Discord.Rest var cancelTokenSource = new CancellationTokenSource(); BaseClient = new API.DiscordRawClient(_restClientProvider, cancelTokenSource.Token, tokenType, token); - BaseClient.SentRequest += (s, e) => _log.Verbose($"{e.Method} {e.Endpoint}: {e.Milliseconds} ms"); + BaseClient.SentRequest += (s, e) => _log.Verbose("Rest", $"{e.Method} {e.Endpoint}: {e.Milliseconds} ms"); //MessageQueue = new MessageQueue(RestClient, _restLogger); //await MessageQueue.Start(_cancelTokenSource.Token).ConfigureAwait(false); diff --git a/src/Discord.Net/Rest/Entities/Channels/TextChannel.cs b/src/Discord.Net/Rest/Entities/Channels/TextChannel.cs index 24ce2d7e9..e15d7578a 100644 --- a/src/Discord.Net/Rest/Entities/Channels/TextChannel.cs +++ b/src/Discord.Net/Rest/Entities/Channels/TextChannel.cs @@ -64,7 +64,7 @@ namespace Discord.Rest public async Task SendMessage(string text, bool isTTS = false) { var args = new CreateMessageParams { Content = text, IsTTS = isTTS }; - var model = await Discord.BaseClient.CreateMessage(Id, args).ConfigureAwait(false); + var model = await Discord.BaseClient.CreateMessage(Guild.Id, Id, args).ConfigureAwait(false); return new Message(this, model); } /// @@ -74,7 +74,7 @@ namespace Discord.Rest using (var file = File.OpenRead(filePath)) { var args = new UploadFileParams { Filename = filename, Content = text, IsTTS = isTTS }; - var model = await Discord.BaseClient.UploadFile(Id, file, args).ConfigureAwait(false); + var model = await Discord.BaseClient.UploadFile(Guild.Id, Id, file, args).ConfigureAwait(false); return new Message(this, model); } } @@ -82,14 +82,14 @@ namespace Discord.Rest public async Task SendFile(Stream stream, string filename, string text = null, bool isTTS = false) { var args = new UploadFileParams { Filename = filename, Content = text, IsTTS = isTTS }; - var model = await Discord.BaseClient.UploadFile(Id, stream, args).ConfigureAwait(false); + var model = await Discord.BaseClient.UploadFile(Guild.Id, Id, stream, args).ConfigureAwait(false); return new Message(this, model); } /// public async Task DeleteMessages(IEnumerable messages) { - await Discord.BaseClient.DeleteMessages(Id, new DeleteMessagesParam { MessageIds = messages.Select(x => x.Id) }).ConfigureAwait(false); + await Discord.BaseClient.DeleteMessages(Guild.Id, Id, new DeleteMessagesParam { MessageIds = messages.Select(x => x.Id) }).ConfigureAwait(false); } /// diff --git a/src/Discord.Net/Rest/Entities/Message.cs b/src/Discord.Net/Rest/Entities/Message.cs index 43abbebc7..4f232bed6 100644 --- a/src/Discord.Net/Rest/Entities/Message.cs +++ b/src/Discord.Net/Rest/Entities/Message.cs @@ -121,7 +121,13 @@ namespace Discord.Rest var args = new ModifyMessageParams(); func(args); - var model = await Discord.BaseClient.ModifyMessage(Channel.Id, Id, args).ConfigureAwait(false); + var guildChannel = Channel as GuildChannel; + + Model model; + if (guildChannel != null) + model = await Discord.BaseClient.ModifyMessage(guildChannel.Guild.Id, Channel.Id, Id, args).ConfigureAwait(false); + else + model = await Discord.BaseClient.ModifyMessage(Channel.Id, Id, args).ConfigureAwait(false); Update(model); }