Browse Source

Implement gateway ratelimit

pull/1537/head
Paulo 5 years ago
parent
commit
bcb8b53849
6 changed files with 137 additions and 14 deletions
  1. +1
    -0
      src/Discord.Net.Core/RequestOptions.cs
  2. +50
    -0
      src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs
  3. +16
    -5
      src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
  4. +65
    -5
      src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
  5. +1
    -3
      src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs
  6. +4
    -1
      src/Discord.Net.WebSocket/DiscordSocketApiClient.cs

+ 1
- 0
src/Discord.Net.Core/RequestOptions.cs View File

@@ -60,6 +60,7 @@ namespace Discord
internal string BucketId { get; set; } internal string BucketId { get; set; }
internal bool IsClientBucket { get; set; } internal bool IsClientBucket { get; set; }
internal bool IsReactionBucket { get; set; } internal bool IsReactionBucket { get; set; }
internal bool IsGatewayBucket { get; set; }


internal static RequestOptions CreateOrClone(RequestOptions options) internal static RequestOptions CreateOrClone(RequestOptions options)
{ {


+ 50
- 0
src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs View File

@@ -0,0 +1,50 @@
using System.Collections.Immutable;

namespace Discord.Net.Queue
{
public enum GatewayBucketType
{
Unbucketed = 0,
Identify = 1
}
internal struct GatewayBucket
{
private static readonly ImmutableDictionary<GatewayBucketType, GatewayBucket> DefsByType;
private static readonly ImmutableDictionary<string, GatewayBucket> DefsById;

static GatewayBucket()
{
var buckets = new[]
{
new GatewayBucket(GatewayBucketType.Unbucketed, "<unbucketed>", 120, 60),
new GatewayBucket(GatewayBucketType.Identify, "<identify>", 1, 5)
};

var builder = ImmutableDictionary.CreateBuilder<GatewayBucketType, GatewayBucket>();
foreach (var bucket in buckets)
builder.Add(bucket.Type, bucket);
DefsByType = builder.ToImmutable();

var builder2 = ImmutableDictionary.CreateBuilder<string, GatewayBucket>();
foreach (var bucket in buckets)
builder2.Add(bucket.Id, bucket);
DefsById = builder2.ToImmutable();
}

public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type];
public static GatewayBucket Get(string id) => DefsById[id];

public GatewayBucketType Type { get; }
public string Id { get; }
public int WindowCount { get; }
public int WindowSeconds { get; }

public GatewayBucket(GatewayBucketType type, string id, int count, int seconds)
{
Type = type;
Id = id;
WindowCount = count;
WindowSeconds = seconds;
}
}
}

+ 16
- 5
src/Discord.Net.Rest/Net/Queue/RequestQueue.cs View File

@@ -89,12 +89,23 @@ namespace Discord.Net.Queue
} }
public async Task SendAsync(WebSocketRequest request) public async Task SendAsync(WebSocketRequest request)
{ {
//TODO: Re-impl websocket buckets
request.CancelToken = _requestCancelToken;
await request.SendAsync().ConfigureAwait(false);
CancellationTokenSource createdTokenSource = null;
if (request.Options.CancelToken.CanBeCanceled)
{
createdTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_requestCancelToken, request.Options.CancelToken);
request.Options.CancelToken = createdTokenSource.Token;
}
else
request.Options.CancelToken = _requestCancelToken;

var bucket = GetOrCreateBucket(request.Options.BucketId, request);
await bucket.SendAsync(request).ConfigureAwait(false);
createdTokenSource?.Dispose();
//request.CancelToken = _requestCancelToken;
//await request.SendAsync().ConfigureAwait(false);
} }


internal async Task EnterGlobalAsync(int id, RestRequest request)
internal async Task EnterGlobalAsync(int id, IRequest request)
{ {
int millis = (int)Math.Ceiling((_waitUntil - DateTimeOffset.UtcNow).TotalMilliseconds); int millis = (int)Math.Ceiling((_waitUntil - DateTimeOffset.UtcNow).TotalMilliseconds);
if (millis > 0) if (millis > 0)
@@ -110,7 +121,7 @@ namespace Discord.Net.Queue
_waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0)); _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0));
} }


private RequestBucket GetOrCreateBucket(string id, RestRequest request)
private RequestBucket GetOrCreateBucket(string id, IRequest request)
{ {
return _buckets.GetOrAdd(id, x => new RequestBucket(this, request, x)); return _buckets.GetOrAdd(id, x => new RequestBucket(this, request, x));
} }


+ 65
- 5
src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs View File

@@ -22,7 +22,7 @@ namespace Discord.Net.Queue
public int WindowCount { get; private set; } public int WindowCount { get; private set; }
public DateTimeOffset LastAttemptAt { get; private set; } public DateTimeOffset LastAttemptAt { get; private set; }


public RequestBucket(RequestQueue queue, RestRequest request, string id)
public RequestBucket(RequestQueue queue, IRequest request, string id)
{ {
_queue = queue; _queue = queue;
Id = id; Id = id;
@@ -31,13 +31,15 @@ namespace Discord.Net.Queue


if (request.Options.IsClientBucket) if (request.Options.IsClientBucket)
WindowCount = ClientBucket.Get(request.Options.BucketId).WindowCount; WindowCount = ClientBucket.Get(request.Options.BucketId).WindowCount;
else if (request.Options.IsGatewayBucket)
WindowCount = GatewayBucket.Get(request.Options.BucketId).WindowCount;
else else
WindowCount = 1; //Only allow one request until we get a header back WindowCount = 1; //Only allow one request until we get a header back
_semaphore = WindowCount; _semaphore = WindowCount;
_resetTick = null; _resetTick = null;
LastAttemptAt = DateTimeOffset.UtcNow; LastAttemptAt = DateTimeOffset.UtcNow;
} }
static int nextId = 0; static int nextId = 0;
public async Task<Stream> SendAsync(RestRequest request) public async Task<Stream> SendAsync(RestRequest request)
{ {
@@ -149,8 +151,59 @@ namespace Discord.Net.Queue
} }
} }
} }
public async Task SendAsync(WebSocketRequest request)
{
int id = Interlocked.Increment(ref nextId);
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Start");
#endif
LastAttemptAt = DateTimeOffset.UtcNow;
while (true)
{
await _queue.EnterGlobalAsync(id, request).ConfigureAwait(false);
await EnterAsync(id, request).ConfigureAwait(false);


private async Task EnterAsync(int id, RestRequest request)
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Sending...");
#endif
try
{
await request.SendAsync().ConfigureAwait(false);
return;
}
catch (TimeoutException)
{
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Timeout");
#endif
if ((request.Options.RetryMode & RetryMode.RetryTimeouts) == 0)
throw;

await Task.Delay(500).ConfigureAwait(false);
continue; //Retry
}
/*catch (Exception)
{
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Error");
#endif
if ((request.Options.RetryMode & RetryMode.RetryErrors) == 0)
throw;

await Task.Delay(500);
continue; //Retry
}*/
finally
{
UpdateRateLimit(id, request, default(RateLimitInfo), false);
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Stop");
#endif
}
}
}

private async Task EnterAsync(int id, IRequest request)
{ {
int windowCount; int windowCount;
DateTimeOffset? resetAt; DateTimeOffset? resetAt;
@@ -213,7 +266,7 @@ namespace Discord.Net.Queue
} }
} }


private void UpdateRateLimit(int id, RestRequest request, RateLimitInfo info, bool is429)
private void UpdateRateLimit(int id, IRequest request, RateLimitInfo info, bool is429)
{ {
if (WindowCount == 0) if (WindowCount == 0)
return; return;
@@ -273,6 +326,13 @@ namespace Discord.Net.Queue
Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)"); Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)");
#endif #endif
} }
else if (request.Options.IsGatewayBucket && request.Options.BucketId != null)
{
resetTick = DateTimeOffset.UtcNow.AddSeconds(GatewayBucket.Get(request.Options.BucketId).WindowSeconds);
#if DEBUG_LIMITS
Debug.WriteLine($"[{id}] Gateway Bucket ({GatewayBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)");
#endif
}


if (resetTick == null) if (resetTick == null)
{ {
@@ -320,7 +380,7 @@ namespace Discord.Net.Queue
} }
} }


private void ThrowRetryLimit(RestRequest request)
private void ThrowRetryLimit(IRequest request)
{ {
if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0) if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0)
throw new RateLimitedException(request); throw new RateLimitedException(request);


+ 1
- 3
src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs View File

@@ -9,7 +9,6 @@ namespace Discord.Net.Queue
public class WebSocketRequest : IRequest public class WebSocketRequest : IRequest
{ {
public IWebSocketClient Client { get; } public IWebSocketClient Client { get; }
public string BucketId { get; }
public byte[] Data { get; } public byte[] Data { get; }
public bool IsText { get; } public bool IsText { get; }
public DateTimeOffset? TimeoutAt { get; } public DateTimeOffset? TimeoutAt { get; }
@@ -17,12 +16,11 @@ namespace Discord.Net.Queue
public RequestOptions Options { get; } public RequestOptions Options { get; }
public CancellationToken CancelToken { get; internal set; } public CancellationToken CancelToken { get; internal set; }


public WebSocketRequest(IWebSocketClient client, string bucketId, byte[] data, bool isText, RequestOptions options)
public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, RequestOptions options)
{ {
Preconditions.NotNull(options, nameof(options)); Preconditions.NotNull(options, nameof(options));


Client = client; Client = client;
BucketId = bucketId;
Data = data; Data = data;
IsText = isText; IsText = isText;
Options = options; Options = options;


+ 4
- 1
src/Discord.Net.WebSocket/DiscordSocketApiClient.cs View File

@@ -205,7 +205,10 @@ namespace Discord.API
payload = new SocketFrame { Operation = (int)opCode, Payload = payload }; payload = new SocketFrame { Operation = (int)opCode, Payload = payload };
if (payload != null) if (payload != null)
bytes = Encoding.UTF8.GetBytes(SerializeJson(payload)); bytes = Encoding.UTF8.GetBytes(SerializeJson(payload));
await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false);

options.IsGatewayBucket = true;
options.BucketId = GatewayBucket.Get(opCode == GatewayOpCode.Identify ? GatewayBucketType.Identify : GatewayBucketType.Unbucketed).Id;
await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, options)).ConfigureAwait(false);
await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false);
} }




Loading…
Cancel
Save