Browse Source

Added MessageQueue

tags/1.0-rc
RogueException 9 years ago
parent
commit
66097b3fd7
17 changed files with 618 additions and 135 deletions
  1. +92
    -106
      src/Discord.Net/API/DiscordRawClient.cs
  2. +8
    -1
      src/Discord.Net/Discord.Net.csproj
  3. +1
    -0
      src/Discord.Net/Logging/ILogger.cs
  4. +25
    -15
      src/Discord.Net/Logging/LogManager.cs
  5. +15
    -0
      src/Discord.Net/Net/RateLimitException.cs
  6. +5
    -0
      src/Discord.Net/Net/Rest/DefaultRestClient.cs
  7. +0
    -7
      src/Discord.Net/Net/Rest/IMessageQueue.cs
  8. +8
    -0
      src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs
  9. +8
    -0
      src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs
  10. +10
    -0
      src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs
  11. +10
    -0
      src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs
  12. +163
    -0
      src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs
  13. +223
    -0
      src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs
  14. +38
    -0
      src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs
  15. +1
    -1
      src/Discord.Net/Rest/DiscordClient.cs
  16. +4
    -4
      src/Discord.Net/Rest/Entities/Channels/TextChannel.cs
  17. +7
    -1
      src/Discord.Net/Rest/Entities/Message.cs

+ 92
- 106
src/Discord.Net/API/DiscordRawClient.cs View File

@@ -21,6 +21,7 @@ namespace Discord.API
{
internal event EventHandler<SentRequestEventArgs> SentRequest;

private readonly RequestQueue _requestQueue;
private readonly IRestClient _restClient;
private readonly CancellationToken _cancelToken;
private readonly JsonSerializer _serializer;
@@ -46,6 +47,7 @@ namespace Discord.API
_restClient = restClientProvider(DiscordConfig.ClientAPIUrl, cancelToken);
_restClient.SetHeader("authorization", authToken);
_restClient.SetHeader("user-agent", DiscordConfig.UserAgent);
_requestQueue = new RequestQueue(_restClient);

_serializer = new JsonSerializer();
_serializer.Converters.Add(new ChannelTypeConverter());
@@ -60,113 +62,73 @@ namespace Discord.API
}

//Core
public async Task<TResponse> Send<TResponse>(string method, string endpoint)
public Task Send(string method, string endpoint, GlobalBucket bucket = GlobalBucket.General)
=> SendInternal(method, endpoint, null, bucket);
public Task Send(string method, string endpoint, object payload, GlobalBucket bucket = GlobalBucket.General)
=> SendInternal(method, endpoint, payload, bucket);
public Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs, GlobalBucket bucket = GlobalBucket.General)
=> SendInternal(method, endpoint, multipartArgs, bucket);
public async Task<TResponse> Send<TResponse>(string method, string endpoint, GlobalBucket bucket = GlobalBucket.General)
where TResponse : class
{
var stopwatch = Stopwatch.StartNew();
Stream responseStream;
try
{
responseStream = await _restClient.Send(method, endpoint, (string)null).ConfigureAwait(false);
}
catch (HttpException ex)
{
if (!HandleException(ex))
throw;
return null;
}
int bytes = (int)responseStream.Length;
stopwatch.Stop();
var response = Deserialize<TResponse>(responseStream);

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds));
=> Deserialize<TResponse>(await SendInternal(method, endpoint, null, bucket).ConfigureAwait(false));
public async Task<TResponse> Send<TResponse>(string method, string endpoint, object payload, GlobalBucket bucket = GlobalBucket.General)
where TResponse : class
=> Deserialize<TResponse>(await SendInternal(method, endpoint, payload, bucket).ConfigureAwait(false));
public async Task<TResponse> Send<TResponse>(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs, GlobalBucket bucket = GlobalBucket.General)
where TResponse : class
=> Deserialize<TResponse>(await SendInternal(method, endpoint, multipartArgs, bucket).ConfigureAwait(false));

public Task Send(string method, string endpoint, GuildBucket bucket, ulong guildId)
=> SendInternal(method, endpoint, null, bucket, guildId);
public Task Send(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId)
=> SendInternal(method, endpoint, payload, bucket, guildId);
public Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs, GuildBucket bucket, ulong guildId)
=> SendInternal(method, endpoint, multipartArgs, bucket, guildId);
public async Task<TResponse> Send<TResponse>(string method, string endpoint, GuildBucket bucket, ulong guildId)
where TResponse : class
=> Deserialize<TResponse>(await SendInternal(method, endpoint, null, bucket, guildId).ConfigureAwait(false));
public async Task<TResponse> Send<TResponse>(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId)
where TResponse : class
=> Deserialize<TResponse>(await SendInternal(method, endpoint, payload, bucket, guildId).ConfigureAwait(false));
public async Task<TResponse> Send<TResponse>(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs, GuildBucket bucket, ulong guildId)
where TResponse : class
=> Deserialize<TResponse>(await SendInternal(method, endpoint, multipartArgs, bucket, guildId).ConfigureAwait(false));

return response;
}
public async Task Send(string method, string endpoint)
{
var stopwatch = Stopwatch.StartNew();
try
{
await _restClient.Send(method, endpoint, (string)null).ConfigureAwait(false);
}
catch (HttpException ex)
{
if (!HandleException(ex))
throw;
return;
}
stopwatch.Stop();
private Task<Stream> SendInternal(string method, string endpoint, object payload, GlobalBucket bucket)
=> SendInternal(method, endpoint, payload, BucketGroup.Global, (int)bucket, 0);
private Task<Stream> SendInternal(string method, string endpoint, object payload, GuildBucket bucket, ulong guildId)
=> SendInternal(method, endpoint, payload, BucketGroup.Guild, (int)bucket, guildId);
private Task<Stream> SendInternal(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, GlobalBucket bucket)
=> SendInternal(method, endpoint, multipartArgs, BucketGroup.Global, (int)bucket, 0);
private Task<Stream> SendInternal(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, GuildBucket bucket, ulong guildId)
=> SendInternal(method, endpoint, multipartArgs, BucketGroup.Guild, (int)bucket, guildId);

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds));
}
public async Task<TResponse> Send<TResponse>(string method, string endpoint, object payload)
where TResponse : class
private async Task<Stream> SendInternal(string method, string endpoint, object payload, BucketGroup group, int bucketId, ulong guildId)
{
string requestStream = Serialize(payload);
var stopwatch = Stopwatch.StartNew();
Stream responseStream;
try
{
responseStream = await _restClient.Send(method, endpoint, requestStream).ConfigureAwait(false);
}
catch (HttpException ex)
{
if (!HandleException(ex))
throw;
return null;
}
string json = null;
if (payload != null)
json = Serialize(payload);
var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, json), group, bucketId, guildId).ConfigureAwait(false);
int bytes = (int)responseStream.Length;
stopwatch.Stop();
var response = Deserialize<TResponse>(responseStream);

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds));

return response;
return responseStream;
}
public async Task Send(string method, string endpoint, object payload)
private async Task<Stream> SendInternal(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, BucketGroup group, int bucketId, ulong guildId)
{
string requestStream = Serialize(payload);
var stopwatch = Stopwatch.StartNew();
try
{
await _restClient.Send(method, endpoint, requestStream).ConfigureAwait(false);
}
catch (HttpException ex)
{
if (!HandleException(ex))
throw;
return;
}
stopwatch.Stop();

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds));
}
public async Task<TResponse> Send<TResponse>(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs)
where TResponse : class
{
var stopwatch = Stopwatch.StartNew();
var responseStream = await _restClient.Send(method, endpoint).ConfigureAwait(false);
var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, multipartArgs), group, bucketId, guildId).ConfigureAwait(false);
int bytes = (int)responseStream.Length;
stopwatch.Stop();
var response = Deserialize<TResponse>(responseStream);

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, (int)responseStream.Length, milliseconds));

return response;
}
public async Task Send(string method, string endpoint, Stream file, IReadOnlyDictionary<string, string> multipartArgs)
{
var stopwatch = Stopwatch.StartNew();
await _restClient.Send(method, endpoint).ConfigureAwait(false);
stopwatch.Stop();
SentRequest(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds));

double milliseconds = ToMilliseconds(stopwatch);
SentRequest(this, new SentRequestEventArgs(method, endpoint, 0, milliseconds));
return responseStream;
}

//Gateway
@@ -623,29 +585,50 @@ namespace Discord.API
else
return result[0];
}
public async Task<Message> CreateMessage(ulong channelId, CreateMessageParams args)
public Task<Message> CreateMessage(ulong channelId, CreateMessageParams args)
=> CreateMessage(0, channelId, args);
public async Task<Message> CreateMessage(ulong guildId, ulong channelId, CreateMessageParams args)
{
if (args == null) throw new ArgumentNullException(nameof(args));
if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId));

return await Send<Message>("POST", $"channels/{channelId}/messages", args).ConfigureAwait(false);
if (guildId != 0)
return await Send<Message>("POST", $"channels/{channelId}/messages", args, GuildBucket.SendEditMessage, guildId).ConfigureAwait(false);
else
return await Send<Message>("POST", $"channels/{channelId}/messages", args, GlobalBucket.DirectMessage).ConfigureAwait(false);
}
public async Task<Message> UploadFile(ulong channelId, Stream file, UploadFileParams args)
public Task<Message> UploadFile(ulong channelId, Stream file, UploadFileParams args)
=> UploadFile(0, channelId, file, args);
public async Task<Message> UploadFile(ulong guildId, ulong channelId, Stream file, UploadFileParams args)
{
if (args == null) throw new ArgumentNullException(nameof(args));
//if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId));
if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId));

return await Send<Message>("POST", $"channels/{channelId}/messages", file, args.ToDictionary()).ConfigureAwait(false);
if (guildId != 0)
return await Send<Message>("POST", $"channels/{channelId}/messages", file, args.ToDictionary(), GuildBucket.SendEditMessage, guildId).ConfigureAwait(false);
else
return await Send<Message>("POST", $"channels/{channelId}/messages", file, args.ToDictionary()).ConfigureAwait(false);
}
public async Task DeleteMessage(ulong channelId, ulong messageId)
public Task DeleteMessage(ulong channelId, ulong messageId)
=> DeleteMessage(0, channelId, messageId);
public async Task DeleteMessage(ulong guildId, ulong channelId, ulong messageId)
{
//if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId));
if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId));
if (messageId == 0) throw new ArgumentOutOfRangeException(nameof(messageId));

await Send("DELETE", $"channels/{channelId}/messages/{messageId}").ConfigureAwait(false);
if (guildId != 0)
await Send("DELETE", $"channels/{channelId}/messages/{messageId}", GuildBucket.DeleteMessage, guildId).ConfigureAwait(false);
else
await Send("DELETE", $"channels/{channelId}/messages/{messageId}").ConfigureAwait(false);
}
public async Task DeleteMessages(ulong channelId, DeleteMessagesParam args)
public Task DeleteMessages(ulong channelId, DeleteMessagesParam args)
=> DeleteMessages(0, channelId, args);
public async Task DeleteMessages(ulong guildId, ulong channelId, DeleteMessagesParam args)
{
//if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId));
if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId));
if (args == null) throw new ArgumentNullException(nameof(args));
if (args.MessageIds == null) throw new ArgumentNullException(nameof(args.MessageIds));

@@ -655,20 +638,29 @@ namespace Discord.API
case 0:
throw new ArgumentOutOfRangeException(nameof(args.MessageIds));
case 1:
await DeleteMessage(channelId, messageIds[0]).ConfigureAwait(false);
await DeleteMessage(guildId, channelId, messageIds[0]).ConfigureAwait(false);
break;
default:
await Send("POST", $"channels/{channelId}/messages/bulk_delete", args).ConfigureAwait(false);
if (guildId != 0)
await Send("POST", $"channels/{channelId}/messages/bulk_delete", args, GuildBucket.DeleteMessages, guildId).ConfigureAwait(false);
else
await Send("POST", $"channels/{channelId}/messages/bulk_delete", args).ConfigureAwait(false);
break;
}
}
public async Task<Message> ModifyMessage(ulong channelId, ulong messageId, ModifyMessageParams args)
public Task<Message> ModifyMessage(ulong channelId, ulong messageId, ModifyMessageParams args)
=> ModifyMessage(0, channelId, messageId, args);
public async Task<Message> ModifyMessage(ulong guildId, ulong channelId, ulong messageId, ModifyMessageParams args)
{
if (args == null) throw new ArgumentNullException(nameof(args));
//if (guildId == 0) throw new ArgumentOutOfRangeException(nameof(guildId));
if (channelId == 0) throw new ArgumentOutOfRangeException(nameof(channelId));
if (messageId == 0) throw new ArgumentOutOfRangeException(nameof(messageId));

return await Send<Message>("PATCH", $"channels/{channelId}/messages/{messageId}", args).ConfigureAwait(false);
if (guildId != 0)
return await Send<Message>("PATCH", $"channels/{channelId}/messages/{messageId}", args, GuildBucket.SendEditMessage, guildId).ConfigureAwait(false);
else
return await Send<Message>("PATCH", $"channels/{channelId}/messages/{messageId}", args).ConfigureAwait(false);
}
public async Task AckMessage(ulong channelId, ulong messageId)
{
@@ -775,11 +767,5 @@ namespace Discord.API
using (JsonReader reader = new JsonTextReader(text))
return _serializer.Deserialize<T>(reader);
}

private bool HandleException(Exception ex)
{
//TODO: Implement... maybe via SentRequest? Need to bubble this up to DiscordClient or a MessageQueue
return false;
}
}
}

+ 8
- 1
src/Discord.Net/Discord.Net.csproj View File

@@ -95,6 +95,13 @@
<Compile Include="API\Rest\ModifyVoiceChannelParams.cs" />
<Compile Include="DiscordConfig.cs" />
<Compile Include="API\DiscordRawClient.cs" />
<Compile Include="Net\RateLimitException.cs" />
<Compile Include="Net\Rest\RequestQueue\BucketGroup.cs" />
<Compile Include="Net\Rest\RequestQueue\GlobalBucket.cs" />
<Compile Include="Net\Rest\RequestQueue\GuildBucket.cs" />
<Compile Include="Net\Rest\RequestQueue\RequestQueue.cs" />
<Compile Include="Net\Rest\RequestQueue\RequestQueueBucket.cs" />
<Compile Include="Net\Rest\RequestQueue\RestRequest.cs" />
<Compile Include="Rest\DiscordClient.cs" />
<Compile Include="Common\Entities\Guilds\IGuildEmbed.cs" />
<Compile Include="Common\Entities\Guilds\IIntegrationAccount.cs" />
@@ -181,7 +188,7 @@
<Compile Include="Net\Converters\UserStatusConverter.cs" />
<Compile Include="Net\HttpException.cs" />
<Compile Include="Net\Rest\DefaultRestClient.cs" />
<Compile Include="Net\Rest\IMessageQueue.cs" />
<Compile Include="Net\Rest\RequestQueue\IRequestQueue.cs" />
<Compile Include="Net\Rest\IRestClient.cs" />
<Compile Include="Net\Rest\MultipartFile.cs" />
<Compile Include="Net\Rest\RestClientProvider.cs" />


+ 1
- 0
src/Discord.Net/Logging/ILogger.cs View File

@@ -8,6 +8,7 @@ namespace Discord.Logging

void Log(LogSeverity severity, string message, Exception exception = null);
void Log(LogSeverity severity, FormattableString message, Exception exception = null);
void Log(LogSeverity severity, Exception exception);

void Error(string message, Exception exception = null);
void Error(FormattableString message, Exception exception = null);


+ 25
- 15
src/Discord.Net/Logging/LogManager.cs View File

@@ -23,6 +23,11 @@ namespace Discord.Logging
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, source, message.ToString(), ex));
}
public void Log(LogSeverity severity, string source, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, source, null, ex));
}
void ILogger.Log(LogSeverity severity, string message, Exception ex)
{
if (severity <= Level)
@@ -33,71 +38,76 @@ namespace Discord.Logging
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, "Discord", message.ToString(), ex));
}
void ILogger.Log(LogSeverity severity, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, "Discord", null, ex));
}

public void Error(string source, string message, Exception ex = null)
=> Log(LogSeverity.Error, source, message, ex);
public void Error(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Error, source, message, ex);
public void Error(string source, Exception ex = null)
=> Log(LogSeverity.Error, source, (string)null, ex);
public void Error(string source, Exception ex)
=> Log(LogSeverity.Error, source, ex);
void ILogger.Error(string message, Exception ex)
=> Log(LogSeverity.Error, "Discord", message, ex);
void ILogger.Error(FormattableString message, Exception ex)
=> Log(LogSeverity.Error, "Discord", message, ex);
void ILogger.Error(Exception ex)
=> Log(LogSeverity.Error, "Discord", (string)null, ex);
=> Log(LogSeverity.Error, "Discord", ex);

public void Warning(string source, string message, Exception ex = null)
=> Log(LogSeverity.Warning, source, message, ex);
public void Warning(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Warning, source, message, ex);
public void Warning(string source, Exception ex = null)
=> Log(LogSeverity.Warning, source, (string)null, ex);
public void Warning(string source, Exception ex)
=> Log(LogSeverity.Warning, source, ex);
void ILogger.Warning(string message, Exception ex)
=> Log(LogSeverity.Warning, "Discord", message, ex);
void ILogger.Warning(FormattableString message, Exception ex)
=> Log(LogSeverity.Warning, "Discord", message, ex);
void ILogger.Warning(Exception ex)
=> Log(LogSeverity.Warning, "Discord", (string)null, ex);
=> Log(LogSeverity.Warning, "Discord", ex);

public void Info(string source, string message, Exception ex = null)
=> Log(LogSeverity.Info, source, message, ex);
public void Info(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Info, source, message, ex);
public void Info(string source, Exception ex = null)
=> Log(LogSeverity.Info, source, (string)null, ex);
public void Info(string source, Exception ex)
=> Log(LogSeverity.Info, source, ex);
void ILogger.Info(string message, Exception ex)
=> Log(LogSeverity.Info, "Discord", message, ex);
void ILogger.Info(FormattableString message, Exception ex)
=> Log(LogSeverity.Info, "Discord", message, ex);
void ILogger.Info(Exception ex)
=> Log(LogSeverity.Info, "Discord", (string)null, ex);
=> Log(LogSeverity.Info, "Discord", ex);

public void Verbose(string source, string message, Exception ex = null)
=> Log(LogSeverity.Verbose, source, message, ex);
public void Verbose(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Verbose, source, message, ex);
public void Verbose(string source, Exception ex = null)
=> Log(LogSeverity.Verbose, source, (string)null, ex);
public void Verbose(string source, Exception ex)
=> Log(LogSeverity.Verbose, source, ex);
void ILogger.Verbose(string message, Exception ex)
=> Log(LogSeverity.Verbose, "Discord", message, ex);
void ILogger.Verbose(FormattableString message, Exception ex)
=> Log(LogSeverity.Verbose, "Discord", message, ex);
void ILogger.Verbose(Exception ex)
=> Log(LogSeverity.Verbose, "Discord", (string)null, ex);
=> Log(LogSeverity.Verbose, "Discord", ex);

public void Debug(string source, string message, Exception ex = null)
=> Log(LogSeverity.Debug, source, message, ex);
public void Debug(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Debug, source, message, ex);
public void Debug(string source, Exception ex = null)
=> Log(LogSeverity.Debug, source, (string)null, ex);
public void Debug(string source, Exception ex)
=> Log(LogSeverity.Debug, source, ex);
void ILogger.Debug(string message, Exception ex)
=> Log(LogSeverity.Debug, "Discord", message, ex);
void ILogger.Debug(FormattableString message, Exception ex)
=> Log(LogSeverity.Debug, "Discord", message, ex);
void ILogger.Debug(Exception ex)
=> Log(LogSeverity.Debug, "Discord", (string)null, ex);
=> Log(LogSeverity.Debug, "Discord", ex);

internal Logger CreateLogger(string name) => new Logger(this, name);
}


+ 15
- 0
src/Discord.Net/Net/RateLimitException.cs View File

@@ -0,0 +1,15 @@
using System.Net;

namespace Discord.Net
{
public class HttpRateLimitException : HttpException
{
public int RetryAfterMilliseconds { get; }

public HttpRateLimitException(int retryAfterMilliseconds)
: base((HttpStatusCode)429)
{
RetryAfterMilliseconds = retryAfterMilliseconds;
}
}
}

+ 5
- 0
src/Discord.Net/Net/Rest/DefaultRestClient.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
@@ -118,7 +119,11 @@ namespace Discord.Net.Rest

int statusCode = (int)response.StatusCode;
if (statusCode < 200 || statusCode >= 300) //2xx = Success
{
if (statusCode == 429)
throw new HttpRateLimitException(int.Parse(response.Headers.GetValues("retry-after").First()));
throw new HttpException(response.StatusCode);
}

return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
}


+ 0
- 7
src/Discord.Net/Net/Rest/IMessageQueue.cs View File

@@ -1,7 +0,0 @@
namespace Discord.Net.Rest
{
public interface IMessageQueue
{
int Count { get; }
}
}

+ 8
- 0
src/Discord.Net/Net/Rest/RequestQueue/BucketGroup.cs View File

@@ -0,0 +1,8 @@
namespace Discord.Net.Rest
{
internal enum BucketGroup
{
Global,
Guild
}
}

+ 8
- 0
src/Discord.Net/Net/Rest/RequestQueue/GlobalBucket.cs View File

@@ -0,0 +1,8 @@
namespace Discord.Net.Rest
{
public enum GlobalBucket
{
General,
DirectMessage
}
}

+ 10
- 0
src/Discord.Net/Net/Rest/RequestQueue/GuildBucket.cs View File

@@ -0,0 +1,10 @@
namespace Discord.Net.Rest
{
public enum GuildBucket
{
SendEditMessage,
DeleteMessage,
DeleteMessages,
Nickname
}
}

+ 10
- 0
src/Discord.Net/Net/Rest/RequestQueue/IRequestQueue.cs View File

@@ -0,0 +1,10 @@
using System.Threading.Tasks;

namespace Discord.Net.Rest
{
public interface IRequestQueue
{
Task Clear(GlobalBucket type);
Task Clear(GuildBucket type, ulong guildId);
}
}

+ 163
- 0
src/Discord.Net/Net/Rest/RequestQueue/RequestQueue.cs View File

@@ -0,0 +1,163 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Net.Rest
{
public class RequestQueue : IRequestQueue
{
private SemaphoreSlim _lock;
private RequestQueueBucket[] _globalBuckets;
private Dictionary<ulong, RequestQueueBucket>[] _guildBuckets;

public IRestClient RestClient { get; }

public RequestQueue(IRestClient restClient)
{
RestClient = restClient;

_lock = new SemaphoreSlim(1, 1);
_globalBuckets = new RequestQueueBucket[Enum.GetValues(typeof(GlobalBucket)).Length];
_guildBuckets = new Dictionary<ulong, RequestQueueBucket>[Enum.GetValues(typeof(GuildBucket)).Length];
}
internal async Task<Stream> Send(RestRequest request, BucketGroup group, int bucketId, ulong guildId)
{
RequestQueueBucket bucket;

await Lock().ConfigureAwait(false);
try
{
bucket = GetBucket(group, bucketId, guildId);
bucket.Queue(request);
}
finally { Unlock(); }

//There is a chance the bucket will send this request on its own, but this will simply become a noop then.
var _ = bucket.ProcessQueue(acquireLock: true).ConfigureAwait(false);

return await request.Promise.Task.ConfigureAwait(false);
}

private RequestQueueBucket CreateBucket(GlobalBucket bucket)
{
switch (bucket)
{
//Globals
case GlobalBucket.General: return new RequestQueueBucket(this, bucket, int.MaxValue, 0); //Catch-all
case GlobalBucket.DirectMessage: return new RequestQueueBucket(this, bucket, 5, 5);

default: throw new ArgumentException($"Unknown global bucket: {bucket}", nameof(bucket));
}
}
private RequestQueueBucket CreateBucket(GuildBucket bucket, ulong guildId)
{
switch (bucket)
{
//Per Guild
case GuildBucket.SendEditMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 5);
case GuildBucket.DeleteMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 1);
case GuildBucket.DeleteMessages: return new RequestQueueBucket(this, bucket, guildId, 1, 1);
case GuildBucket.Nickname: return new RequestQueueBucket(this, bucket, guildId, 1, 1);

default: throw new ArgumentException($"Unknown guild bucket: {bucket}", nameof(bucket));
}
}

private RequestQueueBucket GetBucket(BucketGroup group, int bucketId, ulong guildId)
{
switch (group)
{
case BucketGroup.Global:
return GetGlobalBucket((GlobalBucket)bucketId);
case BucketGroup.Guild:
return GetGuildBucket((GuildBucket)bucketId, guildId);
default:
throw new ArgumentException($"Unknown bucket group: {group}", nameof(group));
}
}
private RequestQueueBucket GetGlobalBucket(GlobalBucket type)
{
var bucket = _globalBuckets[(int)type];
if (bucket == null)
{
bucket = CreateBucket(type);
_globalBuckets[(int)type] = bucket;
}
return bucket;
}
private RequestQueueBucket GetGuildBucket(GuildBucket type, ulong guildId)
{
var bucketGroup = _guildBuckets[(int)type];
if (bucketGroup == null)
{
bucketGroup = new Dictionary<ulong, RequestQueueBucket>();
_guildBuckets[(int)type] = bucketGroup;
}
RequestQueueBucket bucket;
if (!bucketGroup.TryGetValue(guildId, out bucket))
{
bucket = CreateBucket(type, guildId);
bucketGroup[guildId] = bucket;
}
return bucket;
}

internal void DestroyGlobalBucket(GlobalBucket type)
{
//Assume this object is locked

_globalBuckets[(int)type] = null;
}
internal void DestroyGuildBucket(GuildBucket type, ulong guildId)
{
//Assume this object is locked

var bucketGroup = _guildBuckets[(int)type];
if (bucketGroup != null)
bucketGroup.Remove(guildId);
}

public async Task Lock()
{
await _lock.WaitAsync();
}
public void Unlock()
{
_lock.Release();
}

public async Task Clear(GlobalBucket type)
{
var bucket = _globalBuckets[(int)type];
if (bucket != null)
{
try
{
await bucket.Lock();
bucket.Clear();
}
finally { bucket.Unlock(); }
}
}
public async Task Clear(GuildBucket type, ulong guildId)
{
var bucketGroup = _guildBuckets[(int)type];
if (bucketGroup != null)
{
RequestQueueBucket bucket;
if (bucketGroup.TryGetValue(guildId, out bucket))
{
try
{
await bucket.Lock();
bucket.Clear();
}
finally { bucket.Unlock(); }
}
}
}
}
}

+ 223
- 0
src/Discord.Net/Net/Rest/RequestQueue/RequestQueueBucket.cs View File

@@ -0,0 +1,223 @@
using System;
using System.Collections.Concurrent;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Net.Rest
{
internal class RequestQueueBucket
{
private readonly RequestQueue _parent;
private readonly BucketGroup _bucketGroup;
private readonly int _bucketId;
private readonly ulong _guildId;
private readonly ConcurrentQueue<RestRequest> _queue;
private readonly SemaphoreSlim _lock;
private Task _resetTask;
private DateTime? _retryAfter;
private bool _waitingToProcess;

public int WindowMaxCount { get; }
public int WindowSeconds { get; }
public int WindowCount { get; private set; }

public RequestQueueBucket(RequestQueue parent, GlobalBucket bucket, int windowMaxCount, int windowSeconds)
: this(parent, windowMaxCount, windowSeconds)
{
_bucketGroup = BucketGroup.Global;
_bucketId = (int)bucket;
_guildId = 0;
}
public RequestQueueBucket(RequestQueue parent, GuildBucket bucket, ulong guildId, int windowMaxCount, int windowSeconds)
: this(parent, windowMaxCount, windowSeconds)
{
_bucketGroup = BucketGroup.Guild;
_bucketId = (int)bucket;
_guildId = guildId;
}
private RequestQueueBucket(RequestQueue parent, int windowMaxCount, int windowSeconds)
{
_parent = parent;
WindowMaxCount = windowMaxCount;
WindowSeconds = windowSeconds;
_queue = new ConcurrentQueue<RestRequest>();
_lock = new SemaphoreSlim(1, 1);
}

public void Queue(RestRequest request)
{
//Assume this obj's parent is under lock

_queue.Enqueue(request);
Debug($"Request queued ({WindowCount}/{WindowMaxCount} + {_queue.Count})");
}
public async Task ProcessQueue(bool acquireLock = false)
{
//Assume this obj is under lock

int nextRetry = 1000;

//If we have another ProcessQueue waiting to run, dont bother with this one
if (_waitingToProcess) return;
_waitingToProcess = true;

if (acquireLock)
await Lock().ConfigureAwait(false);
try
{
_waitingToProcess = false;
while (true)
{
RestRequest request;

//If we're waiting to reset (due to a rate limit exception, or preemptive check), abort
if (WindowCount == WindowMaxCount) return;
//Get next request, return if queue is empty
if (!_queue.TryPeek(out request)) return;

try
{
Stream stream;
if (request.IsMultipart)
stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.MultipartParams).ConfigureAwait(false);
else
stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.Json).ConfigureAwait(false);
request.Promise.SetResult(stream);
}
catch (HttpRateLimitException ex) //Preemptive check failed, use Discord's time instead of our own
{
if (_resetTask == null)
{
//No reset has been queued yet, lets create one as if this *was* preemptive
_resetTask = ResetAfter(ex.RetryAfterMilliseconds);
Debug($"External rate limit: Reset in {ex.RetryAfterMilliseconds} ms");
}
else
{
//A preemptive reset is already queued, set RetryAfter to extend it
_retryAfter = DateTime.UtcNow.AddMilliseconds(ex.RetryAfterMilliseconds);
Debug($"External rate limit: Extended to {ex.RetryAfterMilliseconds} ms");
}
return;
}
catch (HttpException ex)
{
if (ex.StatusCode == HttpStatusCode.BadGateway) //Gateway unavailable, retry
{
await Task.Delay(nextRetry).ConfigureAwait(false);
nextRetry *= 2;
if (nextRetry > 30000)
nextRetry = 30000;
continue;
}
else
{
//We dont need to throw this here, pass the exception via the promise
request.Promise.SetException(ex);
}
}

//Request completed or had an error other than 429
_queue.TryDequeue(out request);
WindowCount++;
nextRetry = 1000;
Debug($"Request succeeded ({WindowCount}/{WindowMaxCount} + {_queue.Count})");

if (WindowCount == 1 && WindowSeconds > 0)
{
//First request for this window, schedule a reset
_resetTask = ResetAfter(WindowSeconds * 1000);
Debug($"Internal rate limit: Reset in {WindowSeconds * 1000} ms");
}
}
}
finally
{
if (acquireLock)
Unlock();
}
}
public void Clear()
{
//Assume this obj is under lock
RestRequest request;

while (_queue.TryDequeue(out request)) { }
}

private async Task ResetAfter(int milliseconds)
{
if (milliseconds > 0)
await Task.Delay(milliseconds).ConfigureAwait(false);
try
{
await Lock().ConfigureAwait(false);

//If an extension has been planned, start a new wait task
if (_retryAfter != null)
{
_resetTask = ResetAfter((int)(_retryAfter.Value - DateTime.UtcNow).TotalMilliseconds);
_retryAfter = null;
return;
}

Debug($"Reset");
//Reset the current window count and set our state back to normal
WindowCount = 0;
_resetTask = null;

//Wait is over, work through the current queue
await ProcessQueue().ConfigureAwait(false);
//If queue is empty and non-global, remove this bucket
if (_bucketGroup == BucketGroup.Guild && _queue.IsEmpty)
{
try
{
await _parent.Lock().ConfigureAwait(false);
if (_queue.IsEmpty) //Double check, in case a request was queued before we got both locks
_parent.DestroyGuildBucket((GuildBucket)_bucketId, _guildId);
}
finally
{
_parent.Unlock();
}
}
}
finally
{
Unlock();
}
}

public async Task Lock()
{
await _lock.WaitAsync();
}
public void Unlock()
{
_lock.Release();
}

//TODO: Remove
private void Debug(string text)
{
string name;
switch (_bucketGroup)
{
case BucketGroup.Global:
name = ((GlobalBucket)_bucketId).ToString();
break;
case BucketGroup.Guild:
name = ((GuildBucket)_bucketId).ToString();
break;
default:
name = "Unknown";
break;
}
System.Diagnostics.Debug.WriteLine($"[{name}] {text}");
}
}
}

+ 38
- 0
src/Discord.Net/Net/Rest/RequestQueue/RestRequest.cs View File

@@ -0,0 +1,38 @@
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;

namespace Discord.Net.Rest
{
internal struct RestRequest
{
public string Method { get; }
public string Endpoint { get; }
public string Json { get; }
public IReadOnlyDictionary<string, object> MultipartParams { get; }
public TaskCompletionSource<Stream> Promise { get; }

public bool IsMultipart => MultipartParams != null;

public RestRequest(string method, string endpoint, string json)
: this(method, endpoint)
{
Json = json;
}

public RestRequest(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams)
: this(method, endpoint)
{
MultipartParams = multipartParams;
}

private RestRequest(string method, string endpoint)
{
Method = method;
Endpoint = endpoint;
Json = null;
MultipartParams = null;
Promise = new TaskCompletionSource<Stream>();
}
}
}

+ 1
- 1
src/Discord.Net/Rest/DiscordClient.cs View File

@@ -60,7 +60,7 @@ namespace Discord.Rest
var cancelTokenSource = new CancellationTokenSource();
BaseClient = new API.DiscordRawClient(_restClientProvider, cancelTokenSource.Token, tokenType, token);
BaseClient.SentRequest += (s, e) => _log.Verbose($"{e.Method} {e.Endpoint}: {e.Milliseconds} ms");
BaseClient.SentRequest += (s, e) => _log.Verbose("Rest", $"{e.Method} {e.Endpoint}: {e.Milliseconds} ms");

//MessageQueue = new MessageQueue(RestClient, _restLogger);
//await MessageQueue.Start(_cancelTokenSource.Token).ConfigureAwait(false);


+ 4
- 4
src/Discord.Net/Rest/Entities/Channels/TextChannel.cs View File

@@ -64,7 +64,7 @@ namespace Discord.Rest
public async Task<Message> SendMessage(string text, bool isTTS = false)
{
var args = new CreateMessageParams { Content = text, IsTTS = isTTS };
var model = await Discord.BaseClient.CreateMessage(Id, args).ConfigureAwait(false);
var model = await Discord.BaseClient.CreateMessage(Guild.Id, Id, args).ConfigureAwait(false);
return new Message(this, model);
}
/// <inheritdoc />
@@ -74,7 +74,7 @@ namespace Discord.Rest
using (var file = File.OpenRead(filePath))
{
var args = new UploadFileParams { Filename = filename, Content = text, IsTTS = isTTS };
var model = await Discord.BaseClient.UploadFile(Id, file, args).ConfigureAwait(false);
var model = await Discord.BaseClient.UploadFile(Guild.Id, Id, file, args).ConfigureAwait(false);
return new Message(this, model);
}
}
@@ -82,14 +82,14 @@ namespace Discord.Rest
public async Task<Message> SendFile(Stream stream, string filename, string text = null, bool isTTS = false)
{
var args = new UploadFileParams { Filename = filename, Content = text, IsTTS = isTTS };
var model = await Discord.BaseClient.UploadFile(Id, stream, args).ConfigureAwait(false);
var model = await Discord.BaseClient.UploadFile(Guild.Id, Id, stream, args).ConfigureAwait(false);
return new Message(this, model);
}

/// <inheritdoc />
public async Task DeleteMessages(IEnumerable<IMessage> messages)
{
await Discord.BaseClient.DeleteMessages(Id, new DeleteMessagesParam { MessageIds = messages.Select(x => x.Id) }).ConfigureAwait(false);
await Discord.BaseClient.DeleteMessages(Guild.Id, Id, new DeleteMessagesParam { MessageIds = messages.Select(x => x.Id) }).ConfigureAwait(false);
}

/// <inheritdoc />


+ 7
- 1
src/Discord.Net/Rest/Entities/Message.cs View File

@@ -121,7 +121,13 @@ namespace Discord.Rest

var args = new ModifyMessageParams();
func(args);
var model = await Discord.BaseClient.ModifyMessage(Channel.Id, Id, args).ConfigureAwait(false);
var guildChannel = Channel as GuildChannel;

Model model;
if (guildChannel != null)
model = await Discord.BaseClient.ModifyMessage(guildChannel.Guild.Id, Channel.Id, Id, args).ConfigureAwait(false);
else
model = await Discord.BaseClient.ModifyMessage(Channel.Id, Id, args).ConfigureAwait(false);
Update(model);
}



Loading…
Cancel
Save