Browse Source

add support for per-request headers

pull/2261/head
Cenngo 3 years ago
parent
commit
ec36f9d861
4 changed files with 35 additions and 12 deletions
  1. +7
    -3
      src/Discord.Net.Core/Net/Rest/IRestClient.cs
  2. +8
    -4
      src/Discord.Net.Core/RequestOptions.cs
  3. +16
    -4
      src/Discord.Net.Rest/Net/DefaultRestClient.cs
  4. +4
    -1
      src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs

+ 7
- 3
src/Discord.Net.Core/Net/Rest/IRestClient.cs View File

@@ -30,9 +30,13 @@ namespace Discord.Net.Rest
/// <param name="cancelToken">The cancellation token used to cancel the task.</param>
/// <param name="headerOnly">Indicates whether to send the header only.</param>
/// <param name="reason">The audit log reason.</param>
/// <param name="requestHeaders">Additional headers to be sent with the request.</param>
/// <returns></returns>
Task<RestResponse> SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly = false, string reason = null);
Task<RestResponse> SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly = false, string reason = null);
Task<RestResponse> SendAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, CancellationToken cancelToken, bool headerOnly = false, string reason = null);
Task<RestResponse> SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly = false, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null);
Task<RestResponse> SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly = false, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null);
Task<RestResponse> SendAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, CancellationToken cancelToken, bool headerOnly = false, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null);
}
}

+ 8
- 4
src/Discord.Net.Core/RequestOptions.cs View File

@@ -1,5 +1,6 @@
using Discord.Net;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

@@ -19,7 +20,7 @@ namespace Discord
/// Gets or sets the maximum time to wait for this request to complete.
/// </summary>
/// <remarks>
/// Gets or set the max time, in milliseconds, to wait for this request to complete. If
/// Gets or set the max time, in milliseconds, to wait for this request to complete. If
/// <c>null</c>, a request will not time out. If a rate limit has been triggered for this request's bucket
/// and will not be unpaused in time, this request will fail immediately.
/// </remarks>
@@ -53,7 +54,7 @@ namespace Discord
/// </summary>
/// <remarks>
/// This property can also be set in <see cref="DiscordConfig"/>.
/// On a per-request basis, the system clock should only be disabled
/// On a per-request basis, the system clock should only be disabled
/// when millisecond precision is especially important, and the
/// hosting system is known to have a desynced clock.
/// </remarks>
@@ -70,8 +71,10 @@ namespace Discord
internal bool IsReactionBucket { get; set; }
internal bool IsGatewayBucket { get; set; }

internal IDictionary<string, IEnumerable<string?>> RequestHeaders { get; }

internal static RequestOptions CreateOrClone(RequestOptions options)
{
{
if (options == null)
return new RequestOptions();
else
@@ -96,8 +99,9 @@ namespace Discord
public RequestOptions()
{
Timeout = DiscordConfig.DefaultRequestTimeout;
RequestHeaders = new Dictionary<string, IEnumerable<string?>>();
}
public RequestOptions Clone() => MemberwiseClone() as RequestOptions;
}
}

+ 16
- 4
src/Discord.Net.Rest/Net/DefaultRestClient.cs View File

@@ -66,33 +66,45 @@ namespace Discord.Net.Rest
_cancelToken = cancelToken;
}

public async Task<RestResponse> SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly, string reason = null)
public async Task<RestResponse> SendAsync(string method, string endpoint, CancellationToken cancelToken, bool headerOnly, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
{
if (reason != null) restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
}
}
public async Task<RestResponse> SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly, string reason = null)
public async Task<RestResponse> SendAsync(string method, string endpoint, string json, CancellationToken cancelToken, bool headerOnly, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
{
if (reason != null) restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
restRequest.Content = new StringContent(json, Encoding.UTF8, "application/json");
return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
}
}

/// <exception cref="InvalidOperationException">Unsupported param type.</exception>
public async Task<RestResponse> SendAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, CancellationToken cancelToken, bool headerOnly, string reason = null)
public async Task<RestResponse> SendAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, CancellationToken cancelToken, bool headerOnly, string reason = null,
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
{
if (reason != null) restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture));
MemoryStream memoryStream = null;
if (multipartParams != null)
@@ -126,7 +138,7 @@ namespace Discord.Net.Rest

content.Add(streamContent, p.Key, fileValue.Filename);
#pragma warning restore IDISP004
continue;
}
default:


+ 4
- 1
src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs View File

@@ -1,5 +1,8 @@
using Discord.Net.Rest;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Threading.Tasks;

@@ -28,7 +31,7 @@ namespace Discord.Net.Queue

public virtual async Task<RestResponse> SendAsync()
{
return await Client.SendAsync(Method, Endpoint, Options.CancelToken, Options.HeaderOnly, Options.AuditLogReason).ConfigureAwait(false);
return await Client.SendAsync(Method, Endpoint, Options.CancelToken, Options.HeaderOnly, Options.AuditLogReason, Options.RequestHeaders).ConfigureAwait(false);
}
}
}

Loading…
Cancel
Save