From 606d4e7ba95568b14ce8e68f2c229f8ff3df5c1c Mon Sep 17 00:00:00 2001 From: RogueException Date: Wed, 12 Oct 2016 18:12:45 -0300 Subject: [PATCH] Cleaned up net code, readded user token rate limits --- .../API/DiscordRestApiClient.cs | 109 ++++++++++++------ .../Net/Queue/ClientBucket.cs | 6 +- .../Net/Queue/RequestQueue.cs | 8 +- .../Net/Queue/RequestQueueBucket.cs | 42 ++++--- .../Net/Queue/Requests/IRequest.cs | 12 -- .../Net/Queue/Requests/JsonRestRequest.cs | 6 +- .../Queue/Requests/MultipartRestRequest.cs | 6 +- .../Net/Queue/Requests/RestRequest.cs | 9 +- .../Net/Queue/Requests/WebSocketRequest.cs | 2 +- .../Net/Rest/DefaultRestClient.cs | 18 +-- src/Discord.Net.Core/Net/Rest/IRestClient.cs | 7 +- src/Discord.Net.Core/RequestOptions.cs | 4 +- 12 files changed, 124 insertions(+), 105 deletions(-) delete mode 100644 src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs diff --git a/src/Discord.Net.Core/API/DiscordRestApiClient.cs b/src/Discord.Net.Core/API/DiscordRestApiClient.cs index 93dbabef4..6611f233a 100644 --- a/src/Discord.Net.Core/API/DiscordRestApiClient.cs +++ b/src/Discord.Net.Core/API/DiscordRestApiClient.cs @@ -117,7 +117,6 @@ namespace Discord.API _restClient.SetCancelToken(_loginCancelToken.Token); AuthTokenType = tokenType; - RequestQueue.TokenType = tokenType; _authToken = token; _restClient.SetHeader("authorization", GetPrefixedToken(AuthTokenType, _authToken)); @@ -165,61 +164,95 @@ namespace Discord.API internal virtual Task DisconnectInternalAsync() => Task.CompletedTask; //Core - public async Task SendAsync(string method, string endpoint, string bucketId, RequestOptions options) + internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) + => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendAsync(string method, string endpoint, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) { + options = options ?? new RequestOptions(); options.HeaderOnly = true; - var request = new RestRequest(_restClient, method, endpoint, bucketId, options); + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + + var request = new RestRequest(_restClient, method, endpoint, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendJsonAsync(string method, string endpoint, string bucketId, object payload, RequestOptions options) + + internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) + => SendJsonAsync(method, GetEndpoint(endpointExpr), payload, GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendJsonAsync(string method, string endpoint, object payload, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) { + options = options ?? new RequestOptions(); options.HeaderOnly = true; + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + var json = payload != null ? SerializeJson(payload) : null; - var request = new JsonRestRequest(_restClient, method, endpoint, bucketId, json, options); + var request = new JsonRestRequest(_restClient, method, endpoint, json, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendMultipartAsync(string method, string endpoint, string bucketId, IReadOnlyDictionary multipartArgs, RequestOptions options) + + internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) + => SendMultipartAsync(method, GetEndpoint(endpointExpr), multipartArgs, GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary multipartArgs, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) { + options = options ?? new RequestOptions(); options.HeaderOnly = true; - var request = new MultipartRestRequest(_restClient, method, endpoint, bucketId, multipartArgs, options); + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + + var request = new MultipartRestRequest(_restClient, method, endpoint, multipartArgs, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendAsync(string method, string endpoint, string bucketId, RequestOptions options) where TResponse : class + + internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) where TResponse : class + => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendAsync(string method, string endpoint, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) where TResponse : class { - var request = new RestRequest(_restClient, method, endpoint, bucketId, options); + options = options ?? new RequestOptions(); + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + + var request = new RestRequest(_restClient, method, endpoint, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - public async Task SendJsonAsync(string method, string endpoint, string bucketId, object payload, RequestOptions options) where TResponse : class + + internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) where TResponse : class + => SendJsonAsync(method, GetEndpoint(endpointExpr), payload, GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendJsonAsync(string method, string endpoint, object payload, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) where TResponse : class { + options = options ?? new RequestOptions(); + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + var json = payload != null ? SerializeJson(payload) : null; - var request = new JsonRestRequest(_restClient, method, endpoint, bucketId, json, options); + var request = new JsonRestRequest(_restClient, method, endpoint, json, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - public async Task SendMultipartAsync(string method, string endpoint, string bucketId, IReadOnlyDictionary multipartArgs, RequestOptions options) + + internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, + string clientBucketId = null, RequestOptions options = null, [CallerMemberName] string funcName = null) + => SendMultipartAsync(method, GetEndpoint(endpointExpr), multipartArgs, GetBucketId(ids, endpointExpr, AuthTokenType, funcName), clientBucketId, options); + public async Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary multipartArgs, + string bucketId = null, string clientBucketId = null, RequestOptions options = null) { - var request = new MultipartRestRequest(_restClient, method, endpoint, bucketId, multipartArgs, options); + options = options ?? new RequestOptions(); + options.BucketId = bucketId; + options.ClientBucketId = AuthTokenType == TokenType.User ? clientBucketId : null; + + var request = new MultipartRestRequest(_restClient, method, endpoint, multipartArgs, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) - => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), options); - internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) - => SendJsonAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), payload, options); - internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) - => SendMultipartAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), multipartArgs, options); - internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) where TResponse : class - => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), options); - internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) where TResponse : class - => SendJsonAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), payload, options); - internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, - RequestOptions options, [CallerMemberName] string funcName = null) - => SendMultipartAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), multipartArgs, options); - private async Task SendInternalAsync(string method, string endpoint, RestRequest request) { if (!request.Options.IgnoreState) @@ -412,7 +445,7 @@ namespace Discord.API options = RequestOptions.CreateOrClone(options); var ids = new BucketIds(channelId: channelId); - return await SendJsonAsync("POST", () => $"channels/{channelId}/messages", args, ids, options: options).ConfigureAwait(false); + return await SendJsonAsync("POST", () => $"channels/{channelId}/messages", args, ids, clientBucketId: ClientBucket.SendEditId, options: options).ConfigureAwait(false); } public async Task UploadFileAsync(ulong channelId, UploadFileParams args, RequestOptions options = null) { @@ -431,7 +464,7 @@ namespace Discord.API } var ids = new BucketIds(channelId: channelId); - return await SendMultipartAsync("POST", () => $"channels/{channelId}/messages", args.ToDictionary(), ids, options: options).ConfigureAwait(false); + return await SendMultipartAsync("POST", () => $"channels/{channelId}/messages", args.ToDictionary(), ids, clientBucketId: ClientBucket.SendEditId, options: options).ConfigureAwait(false); } public async Task DeleteMessageAsync(ulong channelId, ulong messageId, RequestOptions options = null) { @@ -477,7 +510,7 @@ namespace Discord.API options = RequestOptions.CreateOrClone(options); var ids = new BucketIds(channelId: channelId); - return await SendJsonAsync("PATCH", () => $"channels/{channelId}/messages/{messageId}", args, ids, options: options).ConfigureAwait(false); + return await SendJsonAsync("PATCH", () => $"channels/{channelId}/messages/{messageId}", args, ids, clientBucketId: ClientBucket.SendEditId, options: options).ConfigureAwait(false); } public async Task AckMessageAsync(ulong channelId, ulong messageId, RequestOptions options = null) { @@ -1042,8 +1075,8 @@ namespace Discord.API internal class BucketIds { - public ulong GuildId { get; } - public ulong ChannelId { get; } + public ulong GuildId { get; internal set; } + public ulong ChannelId { get; internal set; } internal BucketIds(ulong guildId = 0, ulong channelId = 0) { @@ -1069,7 +1102,7 @@ namespace Discord.API { return endpointExpr.Compile()(); } - private static string GetBucketId(BucketIds ids, Expression> endpointExpr, string callingMethod) + private static string GetBucketId(BucketIds ids, Expression> endpointExpr, TokenType tokenType, string callingMethod) { return _bucketIdGenerators.GetOrAdd(callingMethod, x => CreateBucketId(endpointExpr))(ids); } diff --git a/src/Discord.Net.Core/Net/Queue/ClientBucket.cs b/src/Discord.Net.Core/Net/Queue/ClientBucket.cs index 93e5cfd23..14d3c3207 100644 --- a/src/Discord.Net.Core/Net/Queue/ClientBucket.cs +++ b/src/Discord.Net.Core/Net/Queue/ClientBucket.cs @@ -4,15 +4,17 @@ namespace Discord.Net.Queue { public struct ClientBucket { + public const string SendEditId = ""; + private static readonly ImmutableDictionary _defs; static ClientBucket() { var builder = ImmutableDictionary.CreateBuilder(); - builder.Add("", new ClientBucket(5, 5)); + builder.Add(SendEditId, new ClientBucket(10, 10)); _defs = builder.ToImmutable(); } - public static ClientBucket Get(string id) => _defs[id]; + public static ClientBucket Get(string id) =>_defs[id]; public int WindowCount { get; } public int WindowSeconds { get; } diff --git a/src/Discord.Net.Core/Net/Queue/RequestQueue.cs b/src/Discord.Net.Core/Net/Queue/RequestQueue.cs index 28caca1c2..ab20d7c18 100644 --- a/src/Discord.Net.Core/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Core/Net/Queue/RequestQueue.cs @@ -12,8 +12,6 @@ namespace Discord.Net.Queue { public event Func RateLimitTriggered; - internal TokenType TokenType { get; set; } - private readonly ConcurrentDictionary _buckets; private readonly SemaphoreSlim _tokenLock; private CancellationTokenSource _clearToken; @@ -66,7 +64,7 @@ namespace Discord.Net.Queue public async Task SendAsync(RestRequest request) { request.CancelToken = _requestCancelToken; - var bucket = GetOrCreateBucket(request.BucketId); + var bucket = GetOrCreateBucket(request.Options.BucketId, request); return await bucket.SendAsync(request).ConfigureAwait(false); } public async Task SendAsync(WebSocketRequest request) @@ -90,9 +88,9 @@ namespace Discord.Net.Queue _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + lag.TotalMilliseconds); } - private RequestBucket GetOrCreateBucket(string id) + private RequestBucket GetOrCreateBucket(string id, RestRequest request) { - return _buckets.GetOrAdd(id, x => new RequestBucket(this, x)); + return _buckets.GetOrAdd(id, x => new RequestBucket(this, request, x)); } internal async Task RaiseRateLimitTriggered(string bucketId, RateLimitInfo? info) { diff --git a/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs index 211a68eab..8af87ee06 100644 --- a/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs @@ -20,15 +20,15 @@ namespace Discord.Net.Queue public int WindowCount { get; private set; } public DateTimeOffset LastAttemptAt { get; private set; } - public RequestBucket(RequestQueue queue, string id) + public RequestBucket(RequestQueue queue, RestRequest request, string id) { _queue = queue; Id = id; _lock = new object(); - if (queue.TokenType == TokenType.User) - WindowCount = ClientBucket.Get(Id).WindowCount; + if (request.Options.ClientBucketId != null) + WindowCount = ClientBucket.Get(request.Options.ClientBucketId).WindowCount; else WindowCount = 1; //Only allow one request until we get a header back _semaphore = WindowCount; @@ -65,7 +65,7 @@ namespace Discord.Net.Queue else { Debug.WriteLine($"[{id}] (!) 429"); - Update(id, info, lag); + UpdateRateLimit(id, request, info, lag, true); } await _queue.RaiseRateLimitTriggered(Id, info).ConfigureAwait(false); continue; //Retry @@ -93,7 +93,7 @@ namespace Discord.Net.Queue else { Debug.WriteLine($"[{id}] Success"); - Update(id, info, lag); + UpdateRateLimit(id, request, info, lag, false); Debug.WriteLine($"[{id}] Stop"); return response.Stream; } @@ -151,26 +151,23 @@ namespace Discord.Net.Queue } } - private void Update(int id, RateLimitInfo info, TimeSpan lag) + private void UpdateRateLimit(int id, RestRequest request, RateLimitInfo info, TimeSpan lag, bool is429) { + if (WindowCount == 0) + return; + lock (_lock) { - if (!info.Limit.HasValue && _queue.TokenType != TokenType.User) - { - WindowCount = 0; - return; - } - bool hasQueuedReset = _resetTick != null; if (info.Limit.HasValue && WindowCount != info.Limit.Value) { WindowCount = info.Limit.Value; _semaphore = info.Remaining.Value; - Debug.WriteLine($"[{id}] Upgraded Semaphore to {info.Remaining.Value}/{WindowCount} "); + Debug.WriteLine($"[{id}] Upgraded Semaphore to {info.Remaining.Value}/{WindowCount}"); } var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - DateTimeOffset resetTick; + DateTimeOffset? resetTick = null; //Using X-RateLimit-Remaining causes a race condition /*if (info.Remaining.HasValue) @@ -187,26 +184,27 @@ namespace Discord.Net.Queue else if (info.Reset.HasValue) { resetTick = info.Reset.Value.AddSeconds(/*1.0 +*/ lag.TotalSeconds); - int diff = (int)(resetTick - DateTimeOffset.UtcNow).TotalMilliseconds; + int diff = (int)(resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds; Debug.WriteLine($"[{id}] X-RateLimit-Reset: {info.Reset.Value.ToUnixTimeSeconds()} ({diff} ms, {lag.TotalMilliseconds} ms lag)"); } - else if (_queue.TokenType == TokenType.User) + else if (request.Options.ClientBucketId != null) { - resetTick = DateTimeOffset.UtcNow.AddSeconds(ClientBucket.Get(Id).WindowSeconds); - Debug.WriteLine($"[{id}] Client Bucket: " + ClientBucket.Get(Id).WindowSeconds); + resetTick = DateTimeOffset.UtcNow.AddSeconds(ClientBucket.Get(request.Options.ClientBucketId).WindowSeconds); + Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(request.Options.ClientBucketId).WindowSeconds * 1000} ms)"); } if (resetTick == null) { - resetTick = DateTimeOffset.UtcNow.AddSeconds(1.0); //Forcibly reset in a second - Debug.WriteLine($"[{id}] Unknown Retry Time!"); + WindowCount = 0; //No rate limit info, disable limits on this bucket (should only ever happen with a user token) + Debug.WriteLine($"[{id}] Disabled Semaphore"); + return; } if (!hasQueuedReset || resetTick > _resetTick) { _resetTick = resetTick; - LastAttemptAt = resetTick; //Make sure we dont destroy this until after its been reset - Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).TotalMilliseconds)} ms"); + LastAttemptAt = resetTick.Value; //Make sure we dont destroy this until after its been reset + Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).Value.TotalMilliseconds)} ms"); if (!hasQueuedReset) { diff --git a/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs deleted file mode 100644 index c8d861a11..000000000 --- a/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Threading; - -namespace Discord.Net.Queue -{ - public interface IRequest - { - CancellationToken CancelToken { get; } - DateTimeOffset? TimeoutAt { get; } - string BucketId { get; } - } -} diff --git a/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs index d328a3e26..75869d52a 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs @@ -7,15 +7,15 @@ namespace Discord.Net.Queue { public string Json { get; } - public JsonRestRequest(IRestClient client, string method, string endpoint, string bucket, string json, RequestOptions options) - : base(client, method, endpoint, bucket, options) + public JsonRestRequest(IRestClient client, string method, string endpoint, string json, RequestOptions options) + : base(client, method, endpoint, options) { Json = json; } public override async Task SendAsync() { - return await Client.SendAsync(Method, Endpoint, Json, Options).ConfigureAwait(false); + return await Client.SendAsync(Method, Endpoint, Json, CancelToken, Options.HeaderOnly).ConfigureAwait(false); } } } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs index e27bb92a0..d132ef395 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs @@ -8,15 +8,15 @@ namespace Discord.Net.Queue { public IReadOnlyDictionary MultipartParams { get; } - public MultipartRestRequest(IRestClient client, string method, string endpoint, string bucket, IReadOnlyDictionary multipartParams, RequestOptions options) - : base(client, method, endpoint, bucket, options) + public MultipartRestRequest(IRestClient client, string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options) + : base(client, method, endpoint, options) { MultipartParams = multipartParams; } public override async Task SendAsync() { - return await Client.SendAsync(Method, Endpoint, MultipartParams, Options).ConfigureAwait(false); + return await Client.SendAsync(Method, Endpoint, MultipartParams, CancelToken, Options.HeaderOnly).ConfigureAwait(false); } } } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs index 8382003c8..5d5bc1e59 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs @@ -6,33 +6,32 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class RestRequest : IRequest + public class RestRequest { public IRestClient Client { get; } public string Method { get; } public string Endpoint { get; } - public string BucketId { get; } public DateTimeOffset? TimeoutAt { get; } public TaskCompletionSource Promise { get; } public RequestOptions Options { get; } public CancellationToken CancelToken { get; internal set; } - public RestRequest(IRestClient client, string method, string endpoint, string bucketId, RequestOptions options) + public RestRequest(IRestClient client, string method, string endpoint, RequestOptions options) { Preconditions.NotNull(options, nameof(options)); Client = client; Method = method; Endpoint = endpoint; - BucketId = bucketId; Options = options; + CancelToken = CancellationToken.None; TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; Promise = new TaskCompletionSource(); } public virtual async Task SendAsync() { - return await Client.SendAsync(Method, Endpoint, Options).ConfigureAwait(false); + return await Client.SendAsync(Method, Endpoint, CancelToken, Options.HeaderOnly).ConfigureAwait(false); } } } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs index 08cdb192c..478289b59 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class WebSocketRequest : IRequest + public class WebSocketRequest { public IWebSocketClient Client { get; } public string BucketId { get; } diff --git a/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs b/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs index 02c356efd..5ec30c750 100644 --- a/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs +++ b/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs @@ -66,22 +66,22 @@ namespace Discord.Net.Rest _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; } - public async Task SendAsync(string method, string endpoint, RequestOptions options) + public async Task SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) - return await SendInternalAsync(restRequest, options).ConfigureAwait(false); + return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false); } - public async Task SendAsync(string method, string endpoint, string json, RequestOptions options) + public async Task SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) { restRequest.Content = new StringContent(json, Encoding.UTF8, "application/json"); - return await SendInternalAsync(restRequest, options).ConfigureAwait(false); + return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false); } } - public async Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options) + public async Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, CancellationToken cancelToken, bool headerOnly) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) @@ -109,19 +109,19 @@ namespace Discord.Net.Rest } } restRequest.Content = content; - return await SendInternalAsync(restRequest, options).ConfigureAwait(false); + return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false); } } - private async Task SendInternalAsync(HttpRequestMessage request, RequestOptions options) + private async Task SendInternalAsync(HttpRequestMessage request, CancellationToken cancelToken, bool headerOnly) { while (true) { - var cancelToken = _cancelToken; //It's okay if another thread changes this, causes a retry to abort + cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken, cancelToken).Token; HttpResponseMessage response = await _client.SendAsync(request, cancelToken).ConfigureAwait(false); var headers = response.Headers.ToDictionary(x => x.Key, x => x.Value.FirstOrDefault()); - var stream = !options.HeaderOnly ? await response.Content.ReadAsStreamAsync().ConfigureAwait(false) : null; + var stream = !headerOnly ? await response.Content.ReadAsStreamAsync().ConfigureAwait(false) : null; return new RestResponse(response.StatusCode, headers, stream); } diff --git a/src/Discord.Net.Core/Net/Rest/IRestClient.cs b/src/Discord.Net.Core/Net/Rest/IRestClient.cs index 16cfbe62d..b5f136cb0 100644 --- a/src/Discord.Net.Core/Net/Rest/IRestClient.cs +++ b/src/Discord.Net.Core/Net/Rest/IRestClient.cs @@ -1,4 +1,3 @@ -using Discord.Net.Queue; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -10,8 +9,8 @@ namespace Discord.Net.Rest void SetHeader(string key, string value); void SetCancelToken(CancellationToken cancelToken); - Task SendAsync(string method, string endpoint, RequestOptions options); - Task SendAsync(string method, string endpoint, string json, RequestOptions options); - Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options); + Task SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly = false); + Task SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly = false); + Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, CancellationToken cancelToken, bool headerOnly = false); } } diff --git a/src/Discord.Net.Core/RequestOptions.cs b/src/Discord.Net.Core/RequestOptions.cs index 1d362fad1..3af6c929d 100644 --- a/src/Discord.Net.Core/RequestOptions.cs +++ b/src/Discord.Net.Core/RequestOptions.cs @@ -9,7 +9,9 @@ public bool HeaderOnly { get; internal set; } internal bool IgnoreState { get; set; } - + internal string BucketId { get; set; } + internal string ClientBucketId { get; set; } + internal static RequestOptions CreateOrClone(RequestOptions options) { if (options == null)