Browse Source

Moved (re)connection handling to ConnectionManager

tags/1.0-rc
RogueException 8 years ago
parent
commit
3190d7e26d
15 changed files with 440 additions and 581 deletions
  1. +1
    -1
      src/Discord.Net.Core/Audio/IAudioClient.cs
  2. +2
    -2
      src/Discord.Net.Core/IDiscordClient.cs
  3. +6
    -3
      src/Discord.Net.Core/Logging/LogMessage.cs
  4. +1
    -1
      src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj
  5. +16
    -17
      src/Discord.Net.Rest/BaseDiscordClient.cs
  6. +199
    -0
      src/Discord.Net.Rest/ConnectionManager.cs
  7. +7
    -2
      src/Discord.Net.Rest/DiscordRestClient.cs
  8. +2
    -2
      src/Discord.Net.Rpc/DiscordRpcClient.Events.cs
  9. +40
    -161
      src/Discord.Net.Rpc/DiscordRpcClient.cs
  10. +72
    -106
      src/Discord.Net.WebSocket/Audio/AudioClient.cs
  11. +6
    -37
      src/Discord.Net.WebSocket/DiscordShardedClient.cs
  12. +10
    -4
      src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
  13. +70
    -210
      src/Discord.Net.WebSocket/DiscordSocketClient.cs
  14. +1
    -1
      src/Discord.Net.WebSocket/DiscordSocketConfig.cs
  15. +7
    -34
      src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs

+ 1
- 1
src/Discord.Net.Core/Audio/IAudioClient.cs View File

@@ -14,7 +14,7 @@ namespace Discord.Audio
/// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary> /// <summary> Gets the estimated round-trip latency, in milliseconds, to the gateway server. </summary>
int Latency { get; } int Latency { get; }


Task DisconnectAsync();
Task StopAsync();


/// <summary> /// <summary>
/// Creates a new outgoing stream accepting Opus-encoded data. /// Creates a new outgoing stream accepting Opus-encoded data.


+ 2
- 2
src/Discord.Net.Core/IDiscordClient.cs View File

@@ -10,8 +10,8 @@ namespace Discord
ConnectionState ConnectionState { get; } ConnectionState ConnectionState { get; }
ISelfUser CurrentUser { get; } ISelfUser CurrentUser { get; }


Task ConnectAsync();
Task DisconnectAsync();
Task StartAsync();
Task StopAsync();


Task<IApplication> GetApplicationInfoAsync(); Task<IApplication> GetApplicationInfoAsync();




+ 6
- 3
src/Discord.Net.Core/Logging/LogMessage.cs View File

@@ -19,7 +19,7 @@ namespace Discord
} }


public override string ToString() => ToString(null); public override string ToString() => ToString(null);
public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 9)
public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 11)
{ {
string sourceName = Source; string sourceName = Source;
string message = Message; string message = Message;
@@ -87,8 +87,11 @@ namespace Discord
} }
if (exMessage != null) if (exMessage != null)
{ {
builder.Append(':');
builder.AppendLine();
if (!string.IsNullOrEmpty(Message))
{
builder.Append(':');
builder.AppendLine();
}
builder.Append(exMessage); builder.Append(exMessage);
} }




+ 1
- 1
src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj View File

@@ -6,7 +6,7 @@
<TargetFrameworks>netstandard1.6</TargetFrameworks> <TargetFrameworks>netstandard1.6</TargetFrameworks>
<AssemblyName>Discord.Net.DebugTools</AssemblyName> <AssemblyName>Discord.Net.DebugTools</AssemblyName>
<Authors>RogueException</Authors> <Authors>RogueException</Authors>
<Description>A Discord.Net extension adding random helper classes for diagnosing issues.</Description>
<Description>A Discord.Net extension adding some helper classes for diagnosing issues.</Description>
<PackageTags>discord;discordapp</PackageTags> <PackageTags>discord;discordapp</PackageTags>
<PackageProjectUrl>https://github.com/RogueException/Discord.Net</PackageProjectUrl> <PackageProjectUrl>https://github.com/RogueException/Discord.Net</PackageProjectUrl>
<PackageLicenseUrl>http://opensource.org/licenses/MIT</PackageLicenseUrl> <PackageLicenseUrl>http://opensource.org/licenses/MIT</PackageLicenseUrl>


+ 16
- 17
src/Discord.Net.Rest/BaseDiscordClient.cs View File

@@ -18,10 +18,9 @@ namespace Discord.Rest
public event Func<Task> LoggedOut { add { _loggedOutEvent.Add(value); } remove { _loggedOutEvent.Remove(value); } } public event Func<Task> LoggedOut { add { _loggedOutEvent.Add(value); } remove { _loggedOutEvent.Remove(value); } }
private readonly AsyncEvent<Func<Task>> _loggedOutEvent = new AsyncEvent<Func<Task>>(); private readonly AsyncEvent<Func<Task>> _loggedOutEvent = new AsyncEvent<Func<Task>>();


internal readonly Logger _restLogger, _queueLogger;
internal readonly SemaphoreSlim _connectionLock;
private bool _isFirstLogin;
private bool _isDisposed;
internal readonly Logger _restLogger;
private readonly SemaphoreSlim _stateLock;
private bool _isFirstLogin, _isDisposed;


internal API.DiscordRestApiClient ApiClient { get; } internal API.DiscordRestApiClient ApiClient { get; }
internal LogManager LogManager { get; } internal LogManager LogManager { get; }
@@ -35,17 +34,16 @@ namespace Discord.Rest
LogManager = new LogManager(config.LogLevel); LogManager = new LogManager(config.LogLevel);
LogManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false); LogManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);


_connectionLock = new SemaphoreSlim(1, 1);
_stateLock = new SemaphoreSlim(1, 1);
_restLogger = LogManager.CreateLogger("Rest"); _restLogger = LogManager.CreateLogger("Rest");
_queueLogger = LogManager.CreateLogger("Queue");
_isFirstLogin = config.DisplayInitialLog; _isFirstLogin = config.DisplayInitialLog;


ApiClient.RequestQueue.RateLimitTriggered += async (id, info) => ApiClient.RequestQueue.RateLimitTriggered += async (id, info) =>
{ {
if (info == null) if (info == null)
await _queueLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
await _restLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
else else
await _queueLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
await _restLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
}; };
ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false); ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false);
} }
@@ -53,12 +51,12 @@ namespace Discord.Rest
/// <inheritdoc /> /// <inheritdoc />
public async Task LoginAsync(TokenType tokenType, string token, bool validateToken = true) public async Task LoginAsync(TokenType tokenType, string token, bool validateToken = true)
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
await _stateLock.WaitAsync().ConfigureAwait(false);
try try
{ {
await LoginInternalAsync(tokenType, token).ConfigureAwait(false); await LoginInternalAsync(tokenType, token).ConfigureAwait(false);
} }
finally { _connectionLock.Release(); }
finally { _stateLock.Release(); }
} }
private async Task LoginInternalAsync(TokenType tokenType, string token) private async Task LoginInternalAsync(TokenType tokenType, string token)
{ {
@@ -86,17 +84,17 @@ namespace Discord.Rest


await _loggedInEvent.InvokeAsync().ConfigureAwait(false); await _loggedInEvent.InvokeAsync().ConfigureAwait(false);
} }
protected virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); }
internal virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); }


/// <inheritdoc /> /// <inheritdoc />
public async Task LogoutAsync() public async Task LogoutAsync()
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
await _stateLock.WaitAsync().ConfigureAwait(false);
try try
{ {
await LogoutInternalAsync().ConfigureAwait(false); await LogoutInternalAsync().ConfigureAwait(false);
} }
finally { _connectionLock.Release(); }
finally { _stateLock.Release(); }
} }
private async Task LogoutInternalAsync() private async Task LogoutInternalAsync()
{ {
@@ -111,7 +109,7 @@ namespace Discord.Rest


await _loggedOutEvent.InvokeAsync().ConfigureAwait(false); await _loggedOutEvent.InvokeAsync().ConfigureAwait(false);
} }
protected virtual Task OnLogoutAsync() { return Task.Delay(0); }
internal virtual Task OnLogoutAsync() { return Task.Delay(0); }


internal virtual void Dispose(bool disposing) internal virtual void Dispose(bool disposing)
{ {
@@ -161,8 +159,9 @@ namespace Discord.Rest
Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id) Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id)
=> Task.FromResult<IVoiceRegion>(null); => Task.FromResult<IVoiceRegion>(null);


Task IDiscordClient.ConnectAsync() { throw new NotSupportedException(); }
Task IDiscordClient.DisconnectAsync() { throw new NotSupportedException(); }

Task IDiscordClient.StartAsync()
=> Task.Delay(0);
Task IDiscordClient.StopAsync()
=> Task.Delay(0);
} }
} }

+ 199
- 0
src/Discord.Net.Rest/ConnectionManager.cs View File

@@ -0,0 +1,199 @@
using Discord.Logging;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Discord
{
internal class ConnectionManager
{
public event Func<Task> Connected { add { _connectedEvent.Add(value); } remove { _connectedEvent.Remove(value); } }
private readonly AsyncEvent<Func<Task>> _connectedEvent = new AsyncEvent<Func<Task>>();
public event Func<Exception, bool, Task> Disconnected { add { _disconnectedEvent.Add(value); } remove { _disconnectedEvent.Remove(value); } }
private readonly AsyncEvent<Func<Exception, bool, Task>> _disconnectedEvent = new AsyncEvent<Func<Exception, bool, Task>>();

private readonly SemaphoreSlim _stateLock;
private readonly Logger _logger;
private readonly int _connectionTimeout;
private readonly Func<Task> _onConnecting;
private readonly Func<Exception, Task> _onDisconnecting;

private TaskCompletionSource<bool> _connectionPromise, _readyPromise;
private CancellationTokenSource _combinedCancelToken, _reconnectCancelToken, _connectionCancelToken;
private Task _task;

public ConnectionState State { get; private set; }
public CancellationToken CancelToken { get; private set; }

public bool IsCompleted => _readyPromise.Task.IsCompleted;

internal ConnectionManager(SemaphoreSlim stateLock, Logger logger, int connectionTimeout,
Func<Task> onConnecting, Func<Exception, Task> onDisconnecting, Action<Func<Exception, Task>> clientDisconnectHandler)
{
_stateLock = stateLock;
_logger = logger;
_connectionTimeout = connectionTimeout;
_onConnecting = onConnecting;
_onDisconnecting = onDisconnecting;

clientDisconnectHandler(ex =>
{
if (ex != null)
Error(new Exception("WebSocket connection was closed", ex));
else
Error(new Exception("WebSocket connection was closed"));
return Task.Delay(0);
});
}

public virtual async Task StartAsync()
{
await AcquireConnectionLock().ConfigureAwait(false);
var reconnectCancelToken = new CancellationTokenSource();
_reconnectCancelToken = new CancellationTokenSource();
_task = Task.Run(async () =>
{
try
{
Random jitter = new Random();
int nextReconnectDelay = 1000;
while (!reconnectCancelToken.IsCancellationRequested)
{
try
{
await ConnectAsync(reconnectCancelToken).ConfigureAwait(false);
nextReconnectDelay = 1000; //Reset delay
await _connectionPromise.Task.ConfigureAwait(false);
}
catch (OperationCanceledException ex)
{
Cancel(); //In case this exception didn't come from another Error call
await DisconnectAsync(ex, !reconnectCancelToken.IsCancellationRequested).ConfigureAwait(false);
}
catch (Exception ex)
{
Error(ex); //In case this exception didn't come from another Error call
if (!reconnectCancelToken.IsCancellationRequested)
{
await _logger.WarningAsync(ex).ConfigureAwait(false);
await DisconnectAsync(ex, true).ConfigureAwait(false);
}
else
{
await _logger.ErrorAsync(ex).ConfigureAwait(false);
await DisconnectAsync(ex, false).ConfigureAwait(false);
}
}

if (!reconnectCancelToken.IsCancellationRequested)
{
//Wait before reconnecting
await Task.Delay(nextReconnectDelay, reconnectCancelToken.Token).ConfigureAwait(false);
nextReconnectDelay = (nextReconnectDelay * 2) + jitter.Next(-250, 250);
if (nextReconnectDelay > 60000)
nextReconnectDelay = 60000;
}
}
}
finally { _stateLock.Release(); }
});
}
public virtual async Task StopAsync()
{
Cancel();
var task = _task;
if (task != null)
await task.ConfigureAwait(false);
}

private async Task ConnectAsync(CancellationTokenSource reconnectCancelToken)
{
_connectionCancelToken = new CancellationTokenSource();
_combinedCancelToken = CancellationTokenSource.CreateLinkedTokenSource(_connectionCancelToken.Token, reconnectCancelToken.Token);
CancelToken = _combinedCancelToken.Token;

_connectionPromise = new TaskCompletionSource<bool>();
State = ConnectionState.Connecting;
await _logger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
var readyPromise = new TaskCompletionSource<bool>();
_readyPromise = readyPromise;

//Abort connection on timeout
var cancelToken = CancelToken;
var _ = Task.Run(async () =>
{
try
{
await Task.Delay(_connectionTimeout, cancelToken).ConfigureAwait(false);
readyPromise.TrySetException(new TimeoutException());
}
catch (OperationCanceledException) { }
});

await _onConnecting().ConfigureAwait(false);

await _logger.InfoAsync("Connected").ConfigureAwait(false);
State = ConnectionState.Connected;
await _logger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
Error(ex);
throw;
}
}
private async Task DisconnectAsync(Exception ex, bool isReconnecting)
{
if (State == ConnectionState.Disconnected) return;
State = ConnectionState.Disconnecting;
await _logger.InfoAsync("Disconnecting").ConfigureAwait(false);

await _onDisconnecting(ex).ConfigureAwait(false);

await _logger.InfoAsync("Disconnected").ConfigureAwait(false);
State = ConnectionState.Disconnected;
await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
}

public async Task CompleteAsync()
{
await _readyPromise.TrySetResultAsync(true).ConfigureAwait(false);
}
public async Task WaitAsync()
{
await _readyPromise.Task.ConfigureAwait(false);
}

public void Cancel()
{
_readyPromise?.TrySetCanceled();
_connectionPromise?.TrySetCanceled();
_reconnectCancelToken?.Cancel();
_connectionCancelToken?.Cancel();
}
public void Error(Exception ex)
{
_readyPromise.TrySetException(ex);
_connectionPromise.TrySetException(ex);
_connectionCancelToken?.Cancel();
}
public void CriticalError(Exception ex)
{
_reconnectCancelToken?.Cancel();
Error(ex);
}
private async Task AcquireConnectionLock()
{
while (true)
{
await StopAsync().ConfigureAwait(false);
if (await _stateLock.WaitAsync(0).ConfigureAwait(false))
break;
}
}
}
}

+ 7
- 2
src/Discord.Net.Rest/DiscordRestClient.cs View File

@@ -16,14 +16,19 @@ namespace Discord.Rest


private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config)
=> new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent); => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent);
internal override void Dispose(bool disposing)
{
if (disposing)
ApiClient.Dispose();
}


protected override async Task OnLoginAsync(TokenType tokenType, string token)
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{ {
var user = await ApiClient.GetMyUserAsync(new RequestOptions { RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); var user = await ApiClient.GetMyUserAsync(new RequestOptions { RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false);
ApiClient.CurrentUserId = user.Id; ApiClient.CurrentUserId = user.Id;
base.CurrentUser = RestSelfUser.Create(this, user); base.CurrentUser = RestSelfUser.Create(this, user);
} }
protected override Task OnLogoutAsync()
internal override Task OnLogoutAsync()
{ {
_applicationInfo = null; _applicationInfo = null;
return Task.Delay(0); return Task.Delay(0);


+ 2
- 2
src/Discord.Net.Rpc/DiscordRpcClient.Events.cs View File

@@ -12,12 +12,12 @@ namespace Discord.Rpc
remove { _connectedEvent.Remove(value); } remove { _connectedEvent.Remove(value); }
} }
private readonly AsyncEvent<Func<Task>> _connectedEvent = new AsyncEvent<Func<Task>>(); private readonly AsyncEvent<Func<Task>> _connectedEvent = new AsyncEvent<Func<Task>>();
public event Func<Exception, Task> Disconnected
public event Func<Exception, bool, Task> Disconnected
{ {
add { _disconnectedEvent.Add(value); } add { _disconnectedEvent.Add(value); }
remove { _disconnectedEvent.Remove(value); } remove { _disconnectedEvent.Remove(value); }
} }
private readonly AsyncEvent<Func<Exception, Task>> _disconnectedEvent = new AsyncEvent<Func<Exception, Task>>();
private readonly AsyncEvent<Func<Exception, bool, Task>> _disconnectedEvent = new AsyncEvent<Func<Exception, bool, Task>>();
public event Func<Task> Ready public event Func<Task> Ready
{ {
add { _readyEvent.Add(value); } add { _readyEvent.Add(value); }


+ 40
- 161
src/Discord.Net.Rpc/DiscordRpcClient.cs View File

@@ -8,28 +8,22 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Threading;


namespace Discord.Rpc namespace Discord.Rpc
{ {
public partial class DiscordRpcClient : BaseDiscordClient, IDiscordClient public partial class DiscordRpcClient : BaseDiscordClient, IDiscordClient
{ {
private readonly Logger _rpcLogger;
private readonly JsonSerializer _serializer; private readonly JsonSerializer _serializer;

private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _cancelToken, _reconnectCancelToken;
private Task _reconnectTask;
private bool _canReconnect;
private readonly ConnectionManager _connection;
private readonly Logger _rpcLogger;
private readonly SemaphoreSlim _stateLock, _authorizeLock;


public ConnectionState ConnectionState { get; private set; } public ConnectionState ConnectionState { get; private set; }
public IReadOnlyCollection<string> Scopes { get; private set; } public IReadOnlyCollection<string> Scopes { get; private set; }
public DateTimeOffset TokenExpiresAt { get; private set; } public DateTimeOffset TokenExpiresAt { get; private set; }


//From DiscordRpcConfig
internal int ConnectionTimeout { get; private set; }

internal new API.DiscordRpcApiClient ApiClient => base.ApiClient as API.DiscordRpcApiClient; internal new API.DiscordRpcApiClient ApiClient => base.ApiClient as API.DiscordRpcApiClient;
public new RestSelfUser CurrentUser { get { return base.CurrentUser as RestSelfUser; } private set { base.CurrentUser = value; } } public new RestSelfUser CurrentUser { get { return base.CurrentUser as RestSelfUser; } private set { base.CurrentUser = value; } }
public RestApplication ApplicationInfo { get; private set; } public RestApplication ApplicationInfo { get; private set; }
@@ -41,8 +35,11 @@ namespace Discord.Rpc
public DiscordRpcClient(string clientId, string origin, DiscordRpcConfig config) public DiscordRpcClient(string clientId, string origin, DiscordRpcConfig config)
: base(config, CreateApiClient(clientId, origin, config)) : base(config, CreateApiClient(clientId, origin, config))
{ {
ConnectionTimeout = config.ConnectionTimeout;
_stateLock = new SemaphoreSlim(1, 1);
_authorizeLock = new SemaphoreSlim(1, 1);
_rpcLogger = LogManager.CreateLogger("RPC"); _rpcLogger = LogManager.CreateLogger("RPC");
_connection = new ConnectionManager(_stateLock, _rpcLogger, config.ConnectionTimeout,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);


_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) => _serializer.Error += (s, e) =>
@@ -53,177 +50,52 @@ namespace Discord.Rpc
ApiClient.SentRpcMessage += async opCode => await _rpcLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.SentRpcMessage += async opCode => await _rpcLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedRpcEvent += ProcessMessageAsync; ApiClient.ReceivedRpcEvent += ProcessMessageAsync;
ApiClient.Disconnected += async ex =>
{
if (ex != null)
{
await _rpcLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
await StartReconnectAsync(ex).ConfigureAwait(false);
}
else
await _rpcLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
};
} }


private static API.DiscordRpcApiClient CreateApiClient(string clientId, string origin, DiscordRpcConfig config) private static API.DiscordRpcApiClient CreateApiClient(string clientId, string origin, DiscordRpcConfig config)
=> new API.DiscordRpcApiClient(clientId, DiscordRestConfig.UserAgent, origin, config.RestClientProvider, config.WebSocketProvider); => new API.DiscordRpcApiClient(clientId, DiscordRestConfig.UserAgent, origin, config.RestClientProvider, config.WebSocketProvider);

/// <inheritdoc />
public async Task ConnectAsync()
internal override void Dispose(bool disposing)
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
if (disposing)
{ {
await ConnectInternalAsync(false).ConfigureAwait(false);
StopAsync().GetAwaiter().GetResult();
ApiClient.Dispose();
} }
finally { _connectionLock.Release(); }
} }
private async Task ConnectInternalAsync(bool isReconnecting)
{
if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();


var state = ConnectionState;
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
public Task StartAsync() => _connection.StartAsync();
public Task StopAsync() => _connection.StopAsync();


ConnectionState = ConnectionState.Connecting;
await _rpcLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();

//Abort connection on timeout
var _ = Task.Run(async () =>
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});

await ApiClient.ConnectAsync().ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);

await _connectTask.Task.ConfigureAwait(false);
if (!isReconnecting)
_canReconnect = true;
ConnectionState = ConnectionState.Connected;
await _rpcLogger.InfoAsync("Connected").ConfigureAwait(false);
}
catch (Exception)
{
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
}
}
/// <inheritdoc />
public async Task DisconnectAsync()
private async Task OnConnectingAsync()
{ {
if (_connectTask?.TrySetCanceled() ?? false) return;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync(null, false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
await _rpcLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);

await _connection.WaitAsync().ConfigureAwait(false);
} }
private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting)
private async Task OnDisconnectingAsync(Exception ex)
{ {
if (!isReconnecting)
{
_canReconnect = false;

if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();
}

if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting;
await _rpcLogger.InfoAsync("Disconnecting").ConfigureAwait(false);

await _rpcLogger.DebugAsync("Disconnecting - CancelToken").ConfigureAwait(false);
//Signal tasks to complete
try { _cancelToken.Cancel(); } catch { }

await _rpcLogger.DebugAsync("Disconnecting - ApiClient").ConfigureAwait(false);
//Disconnect from server
await _rpcLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync().ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false);
ConnectionState = ConnectionState.Disconnected;
await _rpcLogger.InfoAsync("Disconnected").ConfigureAwait(false);

await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
} }


private async Task StartReconnectAsync(Exception ex)
public async Task<string> AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
await _authorizeLock.WaitAsync().ConfigureAwait(false);
try try
{ {
if (!_canReconnect || _reconnectTask != null) return;
_reconnectCancelToken = new CancellationTokenSource();
_reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token);
}
finally { _connectionLock.Release(); }
}
private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken)
{
if (ex == null)
{
if (_connectTask?.TrySetCanceled() ?? false) return;
await _connection.StartAsync().ConfigureAwait(false);
await _connection.WaitAsync().ConfigureAwait(false);
var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false);
await _connection.StopAsync().ConfigureAwait(false);
return result.Code;
} }
else
finally
{ {
if (_connectTask?.TrySetException(ex) ?? false) return;
}

try
{
Random jitter = new Random();
int nextReconnectDelay = 1000;
while (true)
{
await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false);
nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250);
if (nextReconnectDelay > 60000)
nextReconnectDelay = 60000;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (cancelToken.IsCancellationRequested) return;
await ConnectInternalAsync(true).ConfigureAwait(false);
_reconnectTask = null;
return;
}
catch (Exception ex2)
{
await _rpcLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
}
catch (OperationCanceledException)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await _rpcLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false);
_reconnectTask = null;
}
finally { _connectionLock.Release(); }
_authorizeLock.Release();
} }
} }


public async Task<string> AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
{
await ConnectAsync().ConfigureAwait(false);
var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false);
await DisconnectAsync().ConfigureAwait(false);
return result.Code;
}

public async Task SubscribeGlobal(RpcGlobalEvent evnt, RequestOptions options = null) public async Task SubscribeGlobal(RpcGlobalEvent evnt, RequestOptions options = null)
{ {
await ApiClient.SendGlobalSubscribeAsync(GetEventName(evnt), options).ConfigureAwait(false); await ApiClient.SendGlobalSubscribeAsync(GetEventName(evnt), options).ConfigureAwait(false);
@@ -439,8 +311,8 @@ namespace Discord.Rpc
ApplicationInfo = RestApplication.Create(this, response.Application); ApplicationInfo = RestApplication.Create(this, response.Application);
Scopes = response.Scopes; Scopes = response.Scopes;
TokenExpiresAt = response.Expires; TokenExpiresAt = response.Expires;
var __ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var __ = _connection.CompleteAsync();
await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false); await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false);
} }
catch (Exception ex) catch (Exception ex)
@@ -452,7 +324,7 @@ namespace Discord.Rpc
} }
else else
{ {
var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var _ = _connection.CompleteAsync();
await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false); await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false);
} }
} }
@@ -592,6 +464,13 @@ namespace Discord.Rpc
} }


//IDiscordClient //IDiscordClient
ConnectionState IDiscordClient.ConnectionState => _connection.State;

Task<IApplication> IDiscordClient.GetApplicationInfoAsync() => Task.FromResult<IApplication>(ApplicationInfo); Task<IApplication> IDiscordClient.GetApplicationInfoAsync() => Task.FromResult<IApplication>(ApplicationInfo);

async Task IDiscordClient.StartAsync()
=> await StartAsync().ConfigureAwait(false);
async Task IDiscordClient.StopAsync()
=> await StopAsync().ConfigureAwait(false);
} }
} }

+ 72
- 106
src/Discord.Net.WebSocket/Audio/AudioClient.cs View File

@@ -5,6 +5,7 @@ using Discord.WebSocket;
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Concurrent;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
@@ -12,6 +13,7 @@ using System.Threading.Tasks;


namespace Discord.Audio namespace Discord.Audio
{ {
//TODO: Add audio reconnecting
internal class AudioClient : IAudioClient, IDisposable internal class AudioClient : IAudioClient, IDisposable
{ {
public event Func<Task> Connected public event Func<Task> Connected
@@ -34,34 +36,37 @@ namespace Discord.Audio
private readonly AsyncEvent<Func<int, int, Task>> _latencyUpdatedEvent = new AsyncEvent<Func<int, int, Task>>(); private readonly AsyncEvent<Func<int, int, Task>> _latencyUpdatedEvent = new AsyncEvent<Func<int, int, Task>>();


private readonly Logger _audioLogger; private readonly Logger _audioLogger;
internal readonly SemaphoreSlim _connectionLock;
private readonly JsonSerializer _serializer; private readonly JsonSerializer _serializer;
private readonly ConnectionManager _connection;
private readonly SemaphoreSlim _stateLock;
private readonly ConcurrentQueue<long> _heartbeatTimes;


private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _cancelTokenSource;
private Task _heartbeatTask; private Task _heartbeatTask;
private long _heartbeatTime;
private string _url;
private long _lastMessageTime;
private string _url, _sessionId, _token;
private ulong _userId;
private uint _ssrc; private uint _ssrc;
private byte[] _secretKey; private byte[] _secretKey;
private bool _isDisposed;


public SocketGuild Guild { get; } public SocketGuild Guild { get; }
public DiscordVoiceAPIClient ApiClient { get; private set; } public DiscordVoiceAPIClient ApiClient { get; private set; }
public ConnectionState ConnectionState { get; private set; }
public int Latency { get; private set; } public int Latency { get; private set; }


private DiscordSocketClient Discord => Guild.Discord; private DiscordSocketClient Discord => Guild.Discord;
public ConnectionState ConnectionState => _connection.State;


/// <summary> Creates a new REST/WebSocket discord client. </summary> /// <summary> Creates a new REST/WebSocket discord client. </summary>
internal AudioClient(SocketGuild guild, int id) internal AudioClient(SocketGuild guild, int id)
{ {
Guild = guild; Guild = guild;


_audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}");

_connectionLock = new SemaphoreSlim(1, 1);
_stateLock = new SemaphoreSlim(1, 1);
_connection = new ConnectionManager(_stateLock, _audioLogger, 30000,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_heartbeatTimes = new ConcurrentQueue<long>();


_audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}");
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) => _serializer.Error += (s, e) =>
{ {
@@ -76,83 +81,28 @@ namespace Discord.Audio
//ApiClient.SentData += async bytes => await _audioLogger.DebugAsync($"Sent {bytes} Bytes").ConfigureAwait(false); //ApiClient.SentData += async bytes => await _audioLogger.DebugAsync($"Sent {bytes} Bytes").ConfigureAwait(false);
ApiClient.ReceivedEvent += ProcessMessageAsync; ApiClient.ReceivedEvent += ProcessMessageAsync;
ApiClient.ReceivedPacket += ProcessPacketAsync; ApiClient.ReceivedPacket += ProcessPacketAsync;
ApiClient.Disconnected += async ex =>
{
if (ex != null)
await _audioLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
else
await _audioLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
};


LatencyUpdated += async (old, val) => await _audioLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false); LatencyUpdated += async (old, val) => await _audioLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false);
} }


/// <inheritdoc />
internal async Task ConnectAsync(string url, ulong userId, string sessionId, string token)
internal async Task StartAsync(string url, ulong userId, string sessionId, string token)
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternalAsync(url, userId, sessionId, token).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
_url = url;
_userId = userId;
_sessionId = sessionId;
_token = token;
await _connection.StartAsync().ConfigureAwait(false);
} }
private async Task ConnectInternalAsync(string url, ulong userId, string sessionId, string token)
{
var state = ConnectionState;
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null).ConfigureAwait(false);

ConnectionState = ConnectionState.Connecting;
await _audioLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
_url = url;
_connectTask = new TaskCompletionSource<bool>();
_cancelTokenSource = new CancellationTokenSource();

await ApiClient.ConnectAsync("wss://" + url).ConfigureAwait(false);
await ApiClient.SendIdentityAsync(userId, sessionId, token).ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
public async Task StopAsync()
=> await _connection.StopAsync().ConfigureAwait(false);


await _connectedEvent.InvokeAsync().ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _audioLogger.InfoAsync("Connected").ConfigureAwait(false);
}
catch (Exception)
{
await DisconnectInternalAsync(null).ConfigureAwait(false);
throw;
}
}
/// <inheritdoc />
public async Task DisconnectAsync()
private async Task OnConnectingAsync()
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync(null).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
await ApiClient.ConnectAsync("wss://" + _url).ConfigureAwait(false);
await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);
} }
private async Task DisconnectAsync(Exception ex)
private async Task OnDisconnectingAsync(Exception ex)
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync(ex).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task DisconnectInternalAsync(Exception ex)
{
if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting;
await _audioLogger.InfoAsync("Disconnecting").ConfigureAwait(false);

//Signal tasks to complete
try { _cancelTokenSource.Cancel(); } catch { }

//Disconnect from server //Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false);


@@ -162,17 +112,17 @@ namespace Discord.Audio
await heartbeatTask.ConfigureAwait(false); await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null; _heartbeatTask = null;


ConnectionState = ConnectionState.Disconnected;
await _audioLogger.InfoAsync("Disconnected").ConfigureAwait(false);
await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);

await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false); await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);

long time;
while (_heartbeatTimes.TryDequeue(out time)) { }
_lastMessageTime = 0;
} }


public AudioOutStream CreateOpusStream(int samplesPerFrame, int bufferMillis) public AudioOutStream CreateOpusStream(int samplesPerFrame, int bufferMillis)
{ {
CheckSamplesPerFrame(samplesPerFrame); CheckSamplesPerFrame(samplesPerFrame);
var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token);
var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken);
return new RTPWriteStream(target, _secretKey, samplesPerFrame, _ssrc); return new RTPWriteStream(target, _secretKey, samplesPerFrame, _ssrc);
} }
public AudioOutStream CreateDirectOpusStream(int samplesPerFrame) public AudioOutStream CreateDirectOpusStream(int samplesPerFrame)
@@ -184,7 +134,7 @@ namespace Discord.Audio
public AudioOutStream CreatePCMStream(int samplesPerFrame, int channels, int? bitrate, int bufferMillis) public AudioOutStream CreatePCMStream(int samplesPerFrame, int channels, int? bitrate, int bufferMillis)
{ {
CheckSamplesPerFrame(samplesPerFrame); CheckSamplesPerFrame(samplesPerFrame);
var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token);
var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken);
return new OpusEncodeStream(target, _secretKey, channels, samplesPerFrame, _ssrc, bitrate); return new OpusEncodeStream(target, _secretKey, channels, samplesPerFrame, _ssrc, bitrate);
} }
public AudioOutStream CreateDirectPCMStream(int samplesPerFrame, int channels, int? bitrate) public AudioOutStream CreateDirectPCMStream(int samplesPerFrame, int channels, int? bitrate)
@@ -202,6 +152,8 @@ namespace Discord.Audio


private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload) private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
{ {
_lastMessageTime = Environment.TickCount;

try try
{ {
switch (opCode) switch (opCode)
@@ -216,8 +168,7 @@ namespace Discord.Audio
if (!data.Modes.Contains(DiscordVoiceAPIClient.Mode)) if (!data.Modes.Contains(DiscordVoiceAPIClient.Mode))
throw new InvalidOperationException($"Discord does not support {DiscordVoiceAPIClient.Mode}"); throw new InvalidOperationException($"Discord does not support {DiscordVoiceAPIClient.Mode}");


_heartbeatTime = 0;
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelTokenSource.Token);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken);
ApiClient.SetUdpEndpoint(_url, data.Port); ApiClient.SetUdpEndpoint(_url, data.Port);
await ApiClient.SendDiscoveryAsync(_ssrc).ConfigureAwait(false); await ApiClient.SendDiscoveryAsync(_ssrc).ConfigureAwait(false);
@@ -234,19 +185,17 @@ namespace Discord.Audio
_secretKey = data.SecretKey; _secretKey = data.SecretKey;
await ApiClient.SendSetSpeaking(true).ConfigureAwait(false); await ApiClient.SendSetSpeaking(true).ConfigureAwait(false);


var _ = _connectTask.TrySetResultAsync(true);
var _ = _connection.CompleteAsync();
} }
break; break;
case VoiceOpCode.HeartbeatAck: case VoiceOpCode.HeartbeatAck:
{ {
await _audioLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false); await _audioLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false);


var heartbeatTime = _heartbeatTime;
if (heartbeatTime != 0)
long time;
if (_heartbeatTimes.TryDequeue(out time))
{ {
int latency = (int)(Environment.TickCount - _heartbeatTime);
_heartbeatTime = 0;

int latency = (int)(Environment.TickCount - time);
int before = Latency; int before = Latency;
Latency = latency; Latency = latency;


@@ -267,7 +216,7 @@ namespace Discord.Audio
} }
private async Task ProcessPacketAsync(byte[] packet) private async Task ProcessPacketAsync(byte[] packet)
{ {
if (!_connectTask.Task.IsCompleted)
if (!_connection.IsCompleted)
{ {
if (packet.Length == 70) if (packet.Length == 70)
{ {
@@ -291,33 +240,50 @@ namespace Discord.Audio
//Clean this up when Discord's session patch is live //Clean this up when Discord's session patch is live
try try
{ {
await _audioLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
while (!cancelToken.IsCancellationRequested) while (!cancelToken.IsCancellationRequested)
{ {
var now = Environment.TickCount;

//Did server respond to our last heartbeat, or are we still receiving messages (long load?)
if (_heartbeatTimes.Count != 0 && (now - _lastMessageTime) > intervalMillis &&
ConnectionState == ConnectionState.Connected)
{
_connection.Error(new Exception("Server missed last heartbeat"));
return;
}
_heartbeatTimes.Enqueue(now);

await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);


if (_heartbeatTime != 0) //Server never responded to our last heartbeat
try
{ {
if (ConnectionState == ConnectionState.Connected)
{
await _audioLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await DisconnectInternalAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false);
return;
}
await ApiClient.SendHeartbeatAsync().ConfigureAwait(false);
} }
else
_heartbeatTime = Environment.TickCount;
await ApiClient.SendHeartbeatAsync().ConfigureAwait(false);
catch (Exception ex)
{
await _audioLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}

await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
} }
await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (OperationCanceledException)
{
await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (Exception ex)
{
await _audioLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
} }
catch (OperationCanceledException) { }
} }


internal void Dispose(bool disposing) internal void Dispose(bool disposing)
{ {
if (disposing && !_isDisposed)
if (disposing)
{ {
_isDisposed = true;
DisconnectInternalAsync(null).GetAwaiter().GetResult();
StopAsync().GetAwaiter().GetResult();
ApiClient.Dispose(); ApiClient.Dispose();
} }
} }


+ 6
- 37
src/Discord.Net.WebSocket/DiscordShardedClient.cs View File

@@ -72,7 +72,7 @@ namespace Discord.WebSocket
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent); => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent);


protected override async Task OnLoginAsync(TokenType tokenType, string token)
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{ {
if (_automaticShards) if (_automaticShards)
{ {
@@ -95,7 +95,7 @@ namespace Discord.WebSocket
for (int i = 0; i < _shards.Length; i++) for (int i = 0; i < _shards.Length; i++)
await _shards[i].LoginAsync(tokenType, token, false); await _shards[i].LoginAsync(tokenType, token, false);
} }
protected override async Task OnLogoutAsync()
internal override async Task OnLogoutAsync()
{ {
//Assume threadsafe: already in a connection lock //Assume threadsafe: already in a connection lock
for (int i = 0; i < _shards.Length; i++) for (int i = 0; i < _shards.Length; i++)
@@ -112,42 +112,14 @@ namespace Discord.WebSocket
} }


/// <inheritdoc /> /// <inheritdoc />
public async Task ConnectAsync()
public async Task StartAsync()
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternalAsync().ConfigureAwait(false);
}
catch
{
await DisconnectInternalAsync().ConfigureAwait(false);
throw;
}
finally { _connectionLock.Release(); }
}
private async Task ConnectInternalAsync()
{
await Task.WhenAll(
_shards.Select(x => x.ConnectAsync())
).ConfigureAwait(false);

CurrentUser = _shards[0].CurrentUser;
await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task DisconnectAsync()
public async Task StopAsync()
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync().ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task DisconnectInternalAsync()
{
for (int i = 0; i < _shards.Length; i++)
await _shards[i].DisconnectAsync();
await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);
} }


public DiscordSocketClient GetShard(int id) public DiscordSocketClient GetShard(int id)
@@ -334,9 +306,6 @@ namespace Discord.WebSocket
} }


//IDiscordClient //IDiscordClient
Task IDiscordClient.ConnectAsync()
=> ConnectAsync();

async Task<IApplication> IDiscordClient.GetApplicationInfoAsync() async Task<IApplication> IDiscordClient.GetApplicationInfoAsync()
=> await GetApplicationInfoAsync().ConfigureAwait(false); => await GetApplicationInfoAsync().ConfigureAwait(false);




+ 10
- 4
src/Discord.Net.WebSocket/DiscordSocketApiClient.cs View File

@@ -28,6 +28,7 @@ namespace Discord.API


private CancellationTokenSource _connectCancelToken; private CancellationTokenSource _connectCancelToken;
private string _gatewayUrl; private string _gatewayUrl;
private bool _isExplicitUrl;
internal IWebSocketClient WebSocketClient { get; } internal IWebSocketClient WebSocketClient { get; }


@@ -38,6 +39,8 @@ namespace Discord.API
: base(restClientProvider, userAgent, defaultRetryMode, serializer) : base(restClientProvider, userAgent, defaultRetryMode, serializer)
{ {
_gatewayUrl = url; _gatewayUrl = url;
if (url != null)
_isExplicitUrl = true;
WebSocketClient = webSocketProvider(); WebSocketClient = webSocketProvider();
//WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+) //WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+)
WebSocketClient.BinaryMessage += async (data, index, count) => WebSocketClient.BinaryMessage += async (data, index, count) =>
@@ -52,7 +55,8 @@ namespace Discord.API
using (var jsonReader = new JsonTextReader(reader)) using (var jsonReader = new JsonTextReader(reader))
{ {
var msg = _serializer.Deserialize<SocketFrame>(jsonReader); var msg = _serializer.Deserialize<SocketFrame>(jsonReader);
await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
if (msg != null)
await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
} }
} }
}; };
@@ -62,7 +66,8 @@ namespace Discord.API
using (var jsonReader = new JsonTextReader(reader)) using (var jsonReader = new JsonTextReader(reader))
{ {
var msg = _serializer.Deserialize<SocketFrame>(jsonReader); var msg = _serializer.Deserialize<SocketFrame>(jsonReader);
await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
if (msg != null)
await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
} }
}; };
WebSocketClient.Closed += async ex => WebSocketClient.Closed += async ex =>
@@ -107,7 +112,7 @@ namespace Discord.API
if (WebSocketClient != null) if (WebSocketClient != null)
WebSocketClient.SetCancelToken(_connectCancelToken.Token); WebSocketClient.SetCancelToken(_connectCancelToken.Token);


if (_gatewayUrl == null)
if (!_isExplicitUrl)
{ {
var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false);
_gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}"; _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}";
@@ -118,7 +123,8 @@ namespace Discord.API
} }
catch catch
{ {
_gatewayUrl = null; //Uncache in case the gateway url changed
if (!_isExplicitUrl)
_gatewayUrl = null; //Uncache in case the gateway url changed
await DisconnectInternalAsync().ConfigureAwait(false); await DisconnectInternalAsync().ConfigureAwait(false);
throw; throw;
} }


+ 70
- 210
src/Discord.Net.WebSocket/DiscordSocketClient.cs View File

@@ -17,29 +17,27 @@ using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using GameModel = Discord.API.Game; using GameModel = Discord.API.Game;
using Discord.Net;


namespace Discord.WebSocket namespace Discord.WebSocket
{ {
public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient
{ {
private readonly ConcurrentQueue<ulong> _largeGuilds; private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly Logger _gatewayLogger;
private readonly JsonSerializer _serializer; private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock; private readonly SemaphoreSlim _connectionGroupLock;
private readonly DiscordSocketClient _parentClient; private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue<long> _heartbeatTimes; private readonly ConcurrentQueue<long> _heartbeatTimes;
private readonly ConnectionManager _connection;
private readonly Logger _gatewayLogger;
private readonly SemaphoreSlim _stateLock;


private string _sessionId; private string _sessionId;
private int _lastSeq; private int _lastSeq;
private ImmutableDictionary<string, RestVoiceRegion> _voiceRegions; private ImmutableDictionary<string, RestVoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _cancelToken, _reconnectCancelToken;
private Task _heartbeatTask, _guildDownloadTask, _reconnectTask;
private Task _heartbeatTask, _guildDownloadTask;
private int _unavailableGuilds; private int _unavailableGuilds;
private long _lastGuildAvailableTime, _lastMessageTime; private long _lastGuildAvailableTime, _lastMessageTime;
private int _nextAudioId; private int _nextAudioId;
private bool _canReconnect;
private DateTimeOffset? _statusSince; private DateTimeOffset? _statusSince;
private RestApplication _applicationInfo; private RestApplication _applicationInfo;
private ConcurrentHashSet<ulong> _downloadUsersFor; private ConcurrentHashSet<ulong> _downloadUsersFor;
@@ -59,7 +57,6 @@ namespace Discord.WebSocket
internal int LargeThreshold { get; private set; } internal int LargeThreshold { get; private set; }
internal AudioMode AudioMode { get; private set; } internal AudioMode AudioMode { get; private set; }
internal ClientState State { get; private set; } internal ClientState State { get; private set; }
internal int ConnectionTimeout { get; private set; }
internal UdpSocketProvider UdpSocketProvider { get; private set; } internal UdpSocketProvider UdpSocketProvider { get; private set; }
internal WebSocketProvider WebSocketProvider { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; }
internal bool AlwaysDownloadUsers { get; private set; } internal bool AlwaysDownloadUsers { get; private set; }
@@ -90,35 +87,28 @@ namespace Discord.WebSocket
UdpSocketProvider = config.UdpSocketProvider; UdpSocketProvider = config.UdpSocketProvider;
WebSocketProvider = config.WebSocketProvider; WebSocketProvider = config.WebSocketProvider;
AlwaysDownloadUsers = config.AlwaysDownloadUsers; AlwaysDownloadUsers = config.AlwaysDownloadUsers;
ConnectionTimeout = config.ConnectionTimeout;
State = new ClientState(0, 0); State = new ClientState(0, 0);
_downloadUsersFor = new ConcurrentHashSet<ulong>(); _downloadUsersFor = new ConcurrentHashSet<ulong>();
_heartbeatTimes = new ConcurrentQueue<long>(); _heartbeatTimes = new ConcurrentQueue<long>();

_stateLock = new SemaphoreSlim(1, 1);
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : $"Shard #{ShardId}");
_connection = new ConnectionManager(_stateLock, _gatewayLogger, config.ConnectionTimeout,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_nextAudioId = 1; _nextAudioId = 1;
_gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
_connectionGroupLock = groupLock; _connectionGroupLock = groupLock;
_parentClient = parentClient; _parentClient = parentClient;


_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) => _serializer.Error += (s, e) =>
{ {
_gatewayLogger.WarningAsync(e.ErrorContext.Error).GetAwaiter().GetResult();
_gatewayLogger.WarningAsync("Serializer Error", e.ErrorContext.Error).GetAwaiter().GetResult();
e.ErrorContext.Handled = true; e.ErrorContext.Handled = true;
}; };
ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; ApiClient.ReceivedGatewayEvent += ProcessMessageAsync;
ApiClient.Disconnected += async ex =>
{
if (ex != null)
{
await _gatewayLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
await StartReconnectAsync(ex).ConfigureAwait(false);
}
else
await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
};


LeftGuild += async g => await _gatewayLogger.InfoAsync($"Left {g.Name}").ConfigureAwait(false); LeftGuild += async g => await _gatewayLogger.InfoAsync($"Left {g.Name}").ConfigureAwait(false);
JoinedGuild += async g => await _gatewayLogger.InfoAsync($"Joined {g.Name}").ConfigureAwait(false); JoinedGuild += async g => await _gatewayLogger.InfoAsync($"Joined {g.Name}").ConfigureAwait(false);
@@ -143,8 +133,16 @@ namespace Discord.WebSocket
} }
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost); => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost);
internal override void Dispose(bool disposing)
{
if (disposing)
{
StopAsync().GetAwaiter().GetResult();
ApiClient.Dispose();
}
}
protected override async Task OnLoginAsync(TokenType tokenType, string token)
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{ {
if (_parentClient == null) if (_parentClient == null)
{ {
@@ -154,92 +152,49 @@ namespace Discord.WebSocket
else else
_voiceRegions = _parentClient._voiceRegions; _voiceRegions = _parentClient._voiceRegions;
} }
protected override async Task OnLogoutAsync()
internal override async Task OnLogoutAsync()
{ {
if (ConnectionState != ConnectionState.Disconnected)
await DisconnectInternalAsync(null, false).ConfigureAwait(false);

await StopAsync().ConfigureAwait(false);
_applicationInfo = null; _applicationInfo = null;
_voiceRegions = ImmutableDictionary.Create<string, RestVoiceRegion>(); _voiceRegions = ImmutableDictionary.Create<string, RestVoiceRegion>();
_downloadUsersFor.Clear(); _downloadUsersFor.Clear();
} }

public async Task StartAsync()
=> await _connection.StartAsync().ConfigureAwait(false);
public async Task StopAsync()
=> await _connection.StopAsync().ConfigureAwait(false);
/// <inheritdoc />
public async Task ConnectAsync()
private async Task OnConnectingAsync()
{ {
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternalAsync(false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task ConnectInternalAsync(bool isReconnecting)
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("Client is not logged in.");

if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();

var state = ConnectionState;
if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);

if (_connectionGroupLock != null) if (_connectionGroupLock != null)
await _connectionGroupLock.WaitAsync().ConfigureAwait(false);
await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false);
try try
{ {
_canReconnect = true;
ConnectionState = ConnectionState.Connecting;
await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
try
{
var connectTask = new TaskCompletionSource<bool>();
_connectTask = connectTask;
_cancelToken = new CancellationTokenSource();

//Abort connection on timeout
var _ = Task.Run(async () =>
{
await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
connectTask.TrySetException(new TimeoutException());
});

await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);

if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
else
{
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}

await _connectTask.Task.ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
await _connectedEvent.InvokeAsync().ConfigureAwait(false);


await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);

await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);

await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
.Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
if (_sessionId != null)
{
await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
} }
catch (Exception)
else
{ {
await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
throw;
await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
} }

//Wait for READY
await _connection.WaitAsync().ConfigureAwait(false);

await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
await SendStatusAsync().ConfigureAwait(false);

await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
.Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
} }
finally finally
{ {
@@ -250,41 +205,11 @@ namespace Discord.WebSocket
} }
} }
} }
/// <inheritdoc />
public async Task DisconnectAsync()
{
if (_connectTask?.TrySetCanceled() ?? false) return;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternalAsync(null, false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting)
private async Task OnDisconnectingAsync(Exception ex)
{ {
if (!isReconnecting)
{
_canReconnect = false;
_sessionId = null;
_lastSeq = 0;

if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();
}

ulong guildId; ulong guildId;


if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting;
await _gatewayLogger.InfoAsync("Disconnecting").ConfigureAwait(false);

await _gatewayLogger.DebugAsync("Cancelling current tasks").ConfigureAwait(false);
//Signal tasks to complete
try { _cancelToken.Cancel(); } catch { }

await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
//Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false);


//Wait for tasks to complete //Wait for tasks to complete
@@ -294,8 +219,8 @@ namespace Discord.WebSocket
await heartbeatTask.ConfigureAwait(false); await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null; _heartbeatTask = null;


long times;
while (_heartbeatTimes.TryDequeue(out times)) { }
long time;
while (_heartbeatTimes.TryDequeue(out time)) { }
_lastMessageTime = 0; _lastMessageTime = 0;


await _gatewayLogger.DebugAsync("Waiting for guild downloader").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Waiting for guild downloader").ConfigureAwait(false);
@@ -315,70 +240,6 @@ namespace Discord.WebSocket
if (guild._available) if (guild._available)
await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false); await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false);
} }

ConnectionState = ConnectionState.Disconnected;
await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false);

await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
}

private async Task StartReconnectAsync(Exception ex)
{
if ((ex as WebSocketClosedException)?.CloseCode == 4004) //Bad Token
{
_canReconnect = false;
_connectTask?.TrySetException(ex);
await LogoutAsync().ConfigureAwait(false);
return;
}

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (!_canReconnect || _reconnectTask != null) return;
_reconnectCancelToken = new CancellationTokenSource();
_reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token);
}
finally { _connectionLock.Release(); }
}
private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken)
{
try
{
Random jitter = new Random();
int nextReconnectDelay = 1000;
while (true)
{
await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false);
nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250);
if (nextReconnectDelay > 60000)
nextReconnectDelay = 60000;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (cancelToken.IsCancellationRequested) return;
await ConnectInternalAsync(true).ConfigureAwait(false);
_reconnectTask = null;
return;
}
catch (Exception ex2)
{
await _gatewayLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
}
catch (OperationCanceledException)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await _gatewayLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false);
_reconnectTask = null;
}
finally { _connectionLock.Release(); }
}
} }


/// <inheritdoc /> /// <inheritdoc />
@@ -555,7 +416,7 @@ namespace Discord.WebSocket
await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false);
var data = (payload as JToken).ToObject<HelloEvent>(_serializer); var data = (payload as JToken).ToObject<HelloEvent>(_serializer);


_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelToken.Token, _gatewayLogger);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken);
} }
break; break;
case GatewayOpCode.Heartbeat: case GatewayOpCode.Heartbeat:
@@ -593,9 +454,7 @@ namespace Discord.WebSocket
case GatewayOpCode.Reconnect: case GatewayOpCode.Reconnect:
{ {
await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Server requested a reconnect").ConfigureAwait(false);

await StartReconnectAsync(new Exception("Server requested a reconnect")).ConfigureAwait(false);
_connection.Error(new Exception("Server requested a reconnect"));
} }
break; break;
case GatewayOpCode.Dispatch: case GatewayOpCode.Dispatch:
@@ -633,8 +492,7 @@ namespace Discord.WebSocket
} }
catch (Exception ex) catch (Exception ex)
{ {
_canReconnect = false;
_connectTask.TrySetException(new Exception("Processing READY failed", ex));
_connection.CriticalError(new Exception("Processing READY failed", ex));
return; return;
} }


@@ -642,11 +500,11 @@ namespace Discord.WebSocket
await SyncGuildsAsync().ConfigureAwait(false); await SyncGuildsAsync().ConfigureAwait(false);


_lastGuildAvailableTime = Environment.TickCount; _lastGuildAvailableTime = Environment.TickCount;
_guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token, _gatewayLogger);
_guildDownloadTask = WaitForGuildsAsync(_connection.CancelToken, _gatewayLogger);


await _readyEvent.InvokeAsync().ConfigureAwait(false); await _readyEvent.InvokeAsync().ConfigureAwait(false);


var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var _ = _connection.CompleteAsync();
await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false); await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false);
} }
break; break;
@@ -654,7 +512,7 @@ namespace Discord.WebSocket
{ {
await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false);


var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
var _ = _connection.CompleteAsync();


//Notify the client that these guilds are available again //Notify the client that these guilds are available again
foreach (var guild in State.Guilds) foreach (var guild in State.Guilds)
@@ -1356,7 +1214,6 @@ namespace Discord.WebSocket
SocketUserMessage cachedMsg = channel.GetCachedMessage(data.MessageId) as SocketUserMessage; SocketUserMessage cachedMsg = channel.GetCachedMessage(data.MessageId) as SocketUserMessage;
var user = await channel.GetUserAsync(data.UserId, CacheMode.CacheOnly); var user = await channel.GetUserAsync(data.UserId, CacheMode.CacheOnly);
SocketReaction reaction = SocketReaction.Create(data, channel, cachedMsg, Optional.Create(user)); SocketReaction reaction = SocketReaction.Create(data, channel, cachedMsg, Optional.Create(user));

if (cachedMsg != null) if (cachedMsg != null)
{ {
cachedMsg.AddReaction(reaction); cachedMsg.AddReaction(reaction);
@@ -1691,11 +1548,11 @@ namespace Discord.WebSocket
} }
} }


private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken, Logger logger)
private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken)
{ {
try try
{ {
await logger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
while (!cancelToken.IsCancellationRequested) while (!cancelToken.IsCancellationRequested)
{ {
var now = Environment.TickCount; var now = Environment.TickCount;
@@ -1705,8 +1562,7 @@ namespace Discord.WebSocket
{ {
if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? true)) if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? true))
{ {
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false);
_connection.Error(new Exception("Server missed last heartbeat"));
return; return;
} }
} }
@@ -1718,20 +1574,20 @@ namespace Discord.WebSocket
} }
catch (Exception ex) catch (Exception ex)
{ {
await logger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
} }


await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
} }
await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
} }
catch (Exception ex) catch (Exception ex)
{ {
await logger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
await _gatewayLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
} }
} }
public async Task WaitForGuildsAsync() public async Task WaitForGuildsAsync()
@@ -1805,8 +1661,7 @@ namespace Discord.WebSocket
} }


//IDiscordClient //IDiscordClient
Task IDiscordClient.ConnectAsync()
=> ConnectAsync();
ConnectionState IDiscordClient.ConnectionState => _connection.State;


async Task<IApplication> IDiscordClient.GetApplicationInfoAsync() async Task<IApplication> IDiscordClient.GetApplicationInfoAsync()
=> await GetApplicationInfoAsync().ConfigureAwait(false); => await GetApplicationInfoAsync().ConfigureAwait(false);
@@ -1842,5 +1697,10 @@ namespace Discord.WebSocket
=> Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions); => Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id) Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id)
=> Task.FromResult<IVoiceRegion>(GetVoiceRegion(id)); => Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));

async Task IDiscordClient.StartAsync()
=> await StartAsync().ConfigureAwait(false);
async Task IDiscordClient.StopAsync()
=> await StopAsync().ConfigureAwait(false);
} }
} }

+ 1
- 1
src/Discord.Net.WebSocket/DiscordSocketConfig.cs View File

@@ -9,7 +9,7 @@ namespace Discord.WebSocket
{ {
public const string GatewayEncoding = "json"; public const string GatewayEncoding = "json";


/// <summary> Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint.
/// <summary> Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint. </summary>
public string GatewayHost { get; set; } = null; public string GatewayHost { get; set; } = null;


/// <summary> Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting. </summary> /// <summary> Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting. </summary>


+ 7
- 34
src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs View File

@@ -501,7 +501,7 @@ namespace Discord.WebSocket
_audioConnectPromise?.TrySetCanceledAsync(); //Cancel any previous audio connection _audioConnectPromise?.TrySetCanceledAsync(); //Cancel any previous audio connection
_audioConnectPromise = null; _audioConnectPromise = null;
if (_audioClient != null) if (_audioClient != null)
await _audioClient.DisconnectAsync().ConfigureAwait(false);
await _audioClient.StopAsync().ConfigureAwait(false);
_audioClient = null; _audioClient = null;
} }
internal async Task FinishConnectAudio(int id, string url, string token) internal async Task FinishConnectAudio(int id, string url, string token)
@@ -517,7 +517,6 @@ namespace Discord.WebSocket
var promise = _audioConnectPromise; var promise = _audioConnectPromise;
audioClient.Disconnected += async ex => audioClient.Disconnected += async ex =>
{ {
//If the initial connection hasn't been made yet, reconnecting will lead to deadlocks
if (!promise.Task.IsCompleted) if (!promise.Task.IsCompleted)
{ {
try { audioClient.Dispose(); } catch { } try { audioClient.Dispose(); } catch { }
@@ -528,41 +527,15 @@ namespace Discord.WebSocket
await promise.TrySetCanceledAsync(); await promise.TrySetCanceledAsync();
return; return;
} }

//TODO: Implement reconnect
/*await _audioLock.WaitAsync().ConfigureAwait(false);
try
{
if (AudioClient == audioClient) //Only reconnect if we're still assigned as this guild's audio client
{
if (ex != null)
{
//Reconnect if we still have channel info.
//TODO: Is this threadsafe? Could channel data be deleted before we access it?
var voiceState2 = GetVoiceState(Discord.CurrentUser.Id);
if (voiceState2.HasValue)
{
var voiceChannelId = voiceState2.Value.VoiceChannel?.Id;
if (voiceChannelId != null)
{
await Discord.ApiClient.SendVoiceStateUpdateAsync(Id, voiceChannelId, voiceState2.Value.IsSelfDeafened, voiceState2.Value.IsSelfMuted);
return;
}
}
}
try { audioClient.Dispose(); } catch { }
AudioClient = null;
}
}
finally
{
_audioLock.Release();
}*/
}; };
_audioClient = audioClient; _audioClient = audioClient;
} }
await _audioClient.ConnectAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
await _audioConnectPromise.TrySetResultAsync(_audioClient).ConfigureAwait(false);
_audioClient.Connected += () =>
{
var _ = _audioConnectPromise.TrySetResultAsync(_audioClient);
return Task.Delay(0);
};
await _audioClient.StartAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {


Loading…
Cancel
Save