diff --git a/src/Discord.Net/Rpc/DiscordRpcClient.cs b/src/Discord.Net/Rpc/DiscordRpcClient.cs index b54a6f235..c062537ec 100644 --- a/src/Discord.Net/Rpc/DiscordRpcClient.cs +++ b/src/Discord.Net/Rpc/DiscordRpcClient.cs @@ -6,7 +6,6 @@ using Discord.Rest; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -21,9 +20,13 @@ namespace Discord.Rpc private CancellationTokenSource _cancelToken, _reconnectCancelToken; private Task _reconnectTask; private bool _canReconnect; + private int _connectionTimeout; public ConnectionState ConnectionState { get; private set; } + //From DiscordRpcConfig + internal int ConnectionTimeout { get; private set; } + public new API.DiscordRpcApiClient ApiClient => base.ApiClient as API.DiscordRpcApiClient; /// Creates a new RPC discord client. @@ -32,6 +35,7 @@ namespace Discord.Rpc public DiscordRpcClient(DiscordRpcConfig config) : base(config, CreateApiClient(config)) { + ConnectionTimeout = config.ConnectionTimeout; _rpcLogger = LogManager.CreateLogger("RPC"); _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; @@ -95,8 +99,17 @@ namespace Discord.Rpc await _rpcLogger.InfoAsync("Connecting").ConfigureAwait(false); try { - _connectTask = new TaskCompletionSource(); + var connectTask = new TaskCompletionSource(); + _connectTask = connectTask; _cancelToken = new CancellationTokenSource(); + + //Abort connection on timeout + Task.Run(async () => + { + await Task.Delay(_connectionTimeout); + connectTask.TrySetException(new TimeoutException()); + }); + await ApiClient.ConnectAsync().ConfigureAwait(false); await _connectedEvent.InvokeAsync().ConfigureAwait(false); diff --git a/src/Discord.Net/Rpc/DiscordRpcConfig.cs b/src/Discord.Net/Rpc/DiscordRpcConfig.cs index 32c63fc01..ac54551ed 100644 --- a/src/Discord.Net/Rpc/DiscordRpcConfig.cs +++ b/src/Discord.Net/Rpc/DiscordRpcConfig.cs @@ -17,9 +17,12 @@ namespace Discord.Rpc } /// Gets or sets the Discord client/application id used for this RPC connection. - public string ClientId { get; set; } + public string ClientId { get; } /// Gets or sets the origin used for this RPC connection. - public string Origin { get; set; } + public string Origin { get; } + + /// Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting. + public int ConnectionTimeout { get; set; } = 30000; /// Gets or sets the provider used to generate new websocket connections. public WebSocketProvider WebSocketProvider { get; set; } = () => new DefaultWebSocketClient(); diff --git a/src/Discord.Net/WebSocket/DiscordSocketClient.cs b/src/Discord.Net/WebSocket/DiscordSocketClient.cs index d0bea7faf..d9dbb1af5 100644 --- a/src/Discord.Net/WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net/WebSocket/DiscordSocketClient.cs @@ -51,6 +51,7 @@ namespace Discord.WebSocket internal int LargeThreshold { get; private set; } internal AudioMode AudioMode { get; private set; } internal DataStore DataStore { get; private set; } + internal int ConnectionTimeout { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; } public new API.DiscordSocketApiClient ApiClient => base.ApiClient as API.DiscordSocketApiClient; @@ -70,6 +71,7 @@ namespace Discord.WebSocket LargeThreshold = config.LargeThreshold; AudioMode = config.AudioMode; WebSocketProvider = config.WebSocketProvider; + ConnectionTimeout = config.ConnectionTimeout; DataStore = new DataStore(0, 0); _nextAudioId = 1; @@ -158,8 +160,17 @@ namespace Discord.WebSocket try { - _connectTask = new TaskCompletionSource(); + var connectTask = new TaskCompletionSource(); + _connectTask = connectTask; _cancelToken = new CancellationTokenSource(); + + //Abort connection on timeout + Task.Run(async () => + { + await Task.Delay(ConnectionTimeout); + connectTask.TrySetException(new TimeoutException()); + }); + await ApiClient.ConnectAsync().ConfigureAwait(false); await _connectedEvent.InvokeAsync().ConfigureAwait(false); @@ -249,6 +260,15 @@ namespace Discord.WebSocket private async Task StartReconnectAsync(Exception ex) { + if (ex == null) + { + if (_connectTask?.TrySetCanceled() ?? false) return; + } + else + { + if (_connectTask?.TrySetException(ex) ?? false) return; + } + await _connectionLock.WaitAsync().ConfigureAwait(false); try { @@ -260,15 +280,6 @@ namespace Discord.WebSocket } private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken) { - if (ex == null) - { - if (_connectTask?.TrySetCanceled() ?? false) return; - } - else - { - if (_connectTask?.TrySetException(ex) ?? false) return; - } - try { Random jitter = new Random(); diff --git a/src/Discord.Net/WebSocket/DiscordSocketConfig.cs b/src/Discord.Net/WebSocket/DiscordSocketConfig.cs index f1f29b22b..dc0347a1c 100644 --- a/src/Discord.Net/WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net/WebSocket/DiscordSocketConfig.cs @@ -8,6 +8,9 @@ namespace Discord.WebSocket { public const string GatewayEncoding = "json"; + /// Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting. + public int ConnectionTimeout { get; set; } = 30000; + /// Gets or sets the id for this shard. Must be less than TotalShards. public int ShardId { get; set; } = 0; /// Gets or sets the total number of shards for this application. @@ -16,8 +19,7 @@ namespace Discord.WebSocket /// Gets or sets the number of messages per channel that should be kept in cache. Setting this to zero disables the message cache entirely. public int MessageCacheSize { get; set; } = 0; /// - /// Gets or sets the max number of users a guild may have for offline users to be included in the READY packet. Max is 250. - /// Decreasing this may reduce CPU usage while increasing login time and network usage. + /// Gets or sets the max number of users a guild may have for offline users to be included in the READY packet. Max is 250. /// public int LargeThreshold { get; set; } = 250;