diff --git a/src/Discord.Net/DiscordClient.Events.cs b/src/Discord.Net/DiscordClient.Events.cs index 8704ec856..f8b1da5eb 100644 --- a/src/Discord.Net/DiscordClient.Events.cs +++ b/src/Discord.Net/DiscordClient.Events.cs @@ -22,6 +22,12 @@ namespace Discord VoiceWebSocket, } + public class DisconnectedEventArgs : EventArgs + { + public readonly bool WasUnexpected; + public readonly Exception Error; + internal DisconnectedEventArgs(bool wasUnexpected, Exception error) { WasUnexpected = wasUnexpected; Error = error; } + } public sealed class LogMessageEventArgs : EventArgs { public LogMessageSeverity Severity { get; } @@ -30,6 +36,7 @@ namespace Discord internal LogMessageEventArgs(LogMessageSeverity severity, LogMessageSource source, string msg) { Severity = severity; Source = source; Message = msg; } } + public sealed class ServerEventArgs : EventArgs { public Server Server { get; } @@ -136,11 +143,11 @@ namespace Discord if (Connected != null) Connected(this, EventArgs.Empty); } - public event EventHandler Disconnected; - private void RaiseDisconnected() + public event EventHandler Disconnected; + private void RaiseDisconnected(DisconnectedEventArgs e) { if (Disconnected != null) - Disconnected(this, EventArgs.Empty); + Disconnected(this, e); } public event EventHandler LogMessage; internal void RaiseOnLog(LogMessageSeverity severity, LogMessageSource source, string message) @@ -308,11 +315,11 @@ namespace Discord if (VoiceConnected != null) VoiceConnected(this, EventArgs.Empty); } - public event EventHandler VoiceDisconnected; - private void RaiseVoiceDisconnected() + public event EventHandler VoiceDisconnected; + private void RaiseVoiceDisconnected(DisconnectedEventArgs e) { if (VoiceDisconnected != null) - VoiceDisconnected(this, EventArgs.Empty); + VoiceDisconnected(this, e); } /*public event EventHandler VoiceServerChanged; private void RaiseVoiceServerUpdated(Server server, string endpoint) diff --git a/src/Discord.Net/DiscordClient.Voice.cs b/src/Discord.Net/DiscordClient.Voice.cs index a0f1d84ff..e64f9972a 100644 --- a/src/Discord.Net/DiscordClient.Voice.cs +++ b/src/Discord.Net/DiscordClient.Voice.cs @@ -38,7 +38,10 @@ namespace Discord public void SendVoicePCM(byte[] data, int count) { CheckReady(checkVoice: true); + if (data == null) throw new ArgumentException(nameof(data)); + if (count < 0) throw new ArgumentOutOfRangeException(nameof(count)); if (count == 0) return; + _voiceSocket.SendPCMFrames(data, count); } @@ -46,6 +49,7 @@ namespace Discord public void ClearVoicePCM() { CheckReady(checkVoice: true); + _voiceSocket.ClearPCMFrames(); } @@ -53,6 +57,7 @@ namespace Discord public async Task WaitVoice() { CheckReady(checkVoice: true); + _voiceSocket.Wait(); await TaskHelper.CompletedTask.ConfigureAwait(false); } diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index ca781778c..416ac6759 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -35,6 +35,7 @@ namespace Discord private Task _runTask; protected ExceptionDispatchInfo _disconnectReason; private bool _wasDisconnectUnexpected; + private string _token; /// Returns the id of the current logged-in user. public string CurrentUserId => _currentUserId; @@ -86,6 +87,7 @@ namespace Discord _config.Lock(); _state = (int)DiscordClientState.Disconnected; + _cancelToken = new CancellationToken(true); _disconnectedEvent = new ManualResetEvent(true); _connectedEvent = new ManualResetEventSlim(false); _rand = new Random(); @@ -93,8 +95,13 @@ namespace Discord _api = new DiscordAPIClient(_config.LogLevel); _dataSocket = new DataWebSocket(this); _dataSocket.Connected += (s, e) => { if (_state == (int)DiscordClientState.Connecting) CompleteConnect(); }; + _dataSocket.Disconnected += async (s, e) => { RaiseDisconnected(e); if (e.WasUnexpected) await Connect(_token); /*await _dataSocket.Reconnect(_cancelToken);*/ }; if (_config.EnableVoice) + { _voiceSocket = new VoiceWebSocket(this); + _voiceSocket.Connected += (s, e) => RaiseVoiceConnected(); + _voiceSocket.Disconnected += async (s, e) => { RaiseVoiceDisconnected(e); if (e.WasUnexpected) await _voiceSocket.Reconnect(_cancelToken); }; + } _channels = new Channels(this); _members = new Members(this); @@ -108,10 +115,6 @@ namespace Discord _voiceSocket.LogMessage += (s, e) => RaiseOnLog(e.Severity, LogMessageSource.VoiceWebSocket, e.Message); if (_config.LogLevel >= LogMessageSeverity.Info) { - Connected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.Client, "Connected"); - Disconnected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.Client, "Disconnected"); - VoiceConnected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.Client, $"Voice Connected"); - VoiceDisconnected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.Client, $"Voice Disconnected"); _dataSocket.Connected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.DataWebSocket, "Connected"); _dataSocket.Disconnected += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.DataWebSocket, "Disconnected"); //_dataSocket.ReceivedEvent += (s, e) => RaiseOnLog(LogMessageSeverity.Info, LogMessageSource.DataWebSocket, $"Received {e.Type}"); @@ -604,6 +607,7 @@ namespace Discord } //_state = (int)DiscordClientState.Connected; + _token = token; return token; } catch @@ -616,23 +620,24 @@ namespace Discord { _state = (int)WebSocketState.Connected; _connectedEvent.Set(); + RaiseConnected(); } /// Disconnects from the Discord server, canceling any pending requests. public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); - protected async Task DisconnectInternal(Exception ex, bool isUnexpected = true, bool skipAwait = false) + protected Task DisconnectInternal(Exception ex, bool isUnexpected = true, bool skipAwait = false) { int oldState; bool hasWriterLock; //If in either connecting or connected state, get a lock by being the first to switch to disconnecting oldState = Interlocked.CompareExchange(ref _state, (int)DiscordClientState.Disconnecting, (int)DiscordClientState.Connecting); - if (oldState == (int)DiscordClientState.Disconnected) return; //Already disconnected + if (oldState == (int)DiscordClientState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected hasWriterLock = oldState == (int)DiscordClientState.Connecting; //Caused state change if (!hasWriterLock) { oldState = Interlocked.CompareExchange(ref _state, (int)DiscordClientState.Disconnecting, (int)DiscordClientState.Connected); - if (oldState == (int)DiscordClientState.Disconnected) return; //Already disconnected + if (oldState == (int)DiscordClientState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected hasWriterLock = oldState == (int)DiscordClientState.Connected; //Caused state change } @@ -644,18 +649,9 @@ namespace Discord } if (!skipAwait) - { - Task task = _runTask; - if (task != null) - await task.ConfigureAwait(false); - } - - if (hasWriterLock) - { - _state = (int)DiscordClientState.Disconnected; - _disconnectedEvent.Set(); - _connectedEvent.Reset(); - } + return _runTask ?? TaskHelper.CompletedTask; + else + return TaskHelper.CompletedTask; } private async Task RunTasks() @@ -672,13 +668,14 @@ namespace Discord } catch (Exception ex) { await DisconnectInternal(ex, skipAwait: true).ConfigureAwait(false); } - await Cleanup().ConfigureAwait(false); + bool wasUnexpected = _wasDisconnectUnexpected; + _wasDisconnectUnexpected = false; + + await Cleanup(wasUnexpected).ConfigureAwait(false); _runTask = null; } - private async Task Cleanup() + private async Task Cleanup(bool wasUnexpected) { - _disconnectedEvent.Set(); - await _dataSocket.Disconnect().ConfigureAwait(false); if (_config.EnableVoice) await _voiceSocket.Disconnect().ConfigureAwait(false); @@ -695,6 +692,14 @@ namespace Discord _currentUser = null; _currentUserId = null; + _token = null; + + if (!wasUnexpected) + { + _state = (int)DiscordClientState.Disconnected; + _disconnectedEvent.Set(); + } + _connectedEvent.Reset(); } //Helpers diff --git a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs index 8c13f2bde..38b64f8f5 100644 --- a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs @@ -16,7 +16,10 @@ namespace Discord.Net.WebSockets { internal partial class VoiceWebSocket : WebSocket { - private readonly int _targetAudioBufferLength; + private const string EncryptedMode = "xsalsa20_poly1305"; + private const string UnencryptedMode = "plain"; + + private readonly int _targetAudioBufferLength; private ManualResetEventSlim _connectWaitOnLogin; private uint _ssrc; private readonly Random _rand = new Random(); @@ -26,11 +29,9 @@ namespace Discord.Net.WebSockets private ManualResetEventSlim _sendQueueWait, _sendQueueEmptyWait; private UdpClient _udp; private IPEndPoint _endpoint; - private bool _isReady, _isClearing; + private bool _isClearing, _isEncrypted; private byte[] _secretKey; - private string _myIp; private ushort _sequence; - private string _mode; private byte[] _encodingBuffer; private string _serverId, _userId, _sessionId, _token; @@ -52,24 +53,32 @@ namespace Discord.Net.WebSockets _targetAudioBufferLength = client.Config.VoiceBufferLength / 20; //20 ms frames } - public Task Login(string host, string serverId, string userId, string sessionId, string token, CancellationToken cancelToken) + public async Task Login(string host, string serverId, string userId, string sessionId, string token, CancellationToken cancelToken) { + if (_serverId == serverId && _userId == userId && _sessionId == sessionId && _token == token) + { + //Adjust the host and tell the system to reconnect + _host = host; + await DisconnectInternal(new Exception("Server transfer occurred.")); + return; + } + _serverId = serverId; _userId = userId; _sessionId = sessionId; _token = token; - return base.Connect(host, cancelToken); + await Connect(host, cancelToken); } protected override Task[] Run() { + _isClearing = false; + _udp = new UdpClient(new IPEndPoint(IPAddress.Any, 0)); #if !DNX451 _udp.AllowNatTraversal(true); #endif - _isReady = false; - _isClearing = false; VoiceCommands.Login msg = new VoiceCommands.Login(); msg.Payload.ServerId = _serverId; @@ -93,19 +102,24 @@ namespace Discord.Net.WebSockets #endif }.Concat(base.Run()).ToArray(); } - protected override Task Cleanup() + protected override Task Cleanup(bool wasUnexpected) { - ClearPCMFrames(); - _udp = null; - _serverId = null; - _userId = null; - _sessionId = null; - _token = null; #if USE_THREAD _sendThread.Join(); _sendThread = null; #endif - return base.Cleanup(); + + ClearPCMFrames(); + if (!wasUnexpected) + { + _serverId = null; + _userId = null; + _sessionId = null; + _token = null; + } + _udp = null; + + return base.Cleanup(wasUnexpected); } private async Task ReceiveVoiceAsync() @@ -153,14 +167,16 @@ namespace Discord.Net.WebSockets byte[] packet; try { - while (!cancelToken.IsCancellationRequested && !_isReady) + while (!cancelToken.IsCancellationRequested && _state != (int)WebSocketState.Connected) + { #if USE_THREAD Thread.Sleep(1); #else await Task.Delay(1); #endif + } - if (cancelToken.IsCancellationRequested) + if (cancelToken.IsCancellationRequested) return; uint timestamp = 0; @@ -251,15 +267,15 @@ namespace Discord.Net.WebSockets { case 2: //READY { - if (!_isReady) + if (_state != (int)WebSocketState.Connected) { var payload = (msg.Payload as JToken).ToObject(); _heartbeatInterval = payload.HeartbeatInterval; _ssrc = payload.SSRC; _endpoint = new IPEndPoint((await Dns.GetHostAddressesAsync(_host.Replace("wss://", "")).ConfigureAwait(false)).FirstOrDefault(), payload.Port); //_mode = payload.Modes.LastOrDefault(); - _mode = "plain"; - _udp.Connect(_endpoint); + _isEncrypted = !payload.Modes.Contains("plain"); + _udp.Connect(_endpoint); _sequence = (ushort)_rand.Next(0, ushort.MaxValue); //No thread issue here because SendAsync doesn't start until _isReady is true @@ -297,9 +313,8 @@ namespace Discord.Net.WebSockets { byte[] buffer = msg.Buffer; int length = msg.Buffer.Length; - if (!_isReady) + if (_state != (int)WebSocketState.Connected) { - _isReady = true; if (length != 70) { if (_logLevel >= LogMessageSeverity.Warning) @@ -308,15 +323,15 @@ namespace Discord.Net.WebSockets } int port = buffer[68] | buffer[69] << 8; + string ip = Encoding.ASCII.GetString(buffer, 4, 70 - 6).TrimEnd('\0'); - _myIp = Encoding.ASCII.GetString(buffer, 4, 70 - 6).TrimEnd('\0'); + CompleteConnect(); - _isReady = true; var login2 = new VoiceCommands.Login2(); login2.Payload.Protocol = "udp"; - login2.Payload.SocketData.Address = _myIp; - login2.Payload.SocketData.Mode = _mode; - login2.Payload.SocketData.Port = port; + login2.Payload.SocketData.Address = ip; + login2.Payload.SocketData.Mode = _isEncrypted ? EncryptedMode : UnencryptedMode; + login2.Payload.SocketData.Port = port; QueueMessage(login2); } else @@ -377,8 +392,8 @@ namespace Discord.Net.WebSockets buffer = newBuffer; }*/ - if (_logLevel >= LogMessageSeverity.Verbose) - RaiseOnLog(LogMessageSeverity.Verbose, $"Received {buffer.Length - 12} bytes."); + if (_logLevel >= LogMessageSeverity.Debug) + RaiseOnLog(LogMessageSeverity.Debug, $"Received {buffer.Length - 12} bytes."); //TODO: Use Voice Data } } @@ -386,12 +401,6 @@ namespace Discord.Net.WebSockets public void SendPCMFrames(byte[] data, int bytes) { - var cancelToken = _cancelToken; - if (!_isReady || cancelToken == null) - throw new InvalidOperationException("Not connected to a voice server."); - if (bytes == 0) - return; - int frameSize = _encoder.FrameSize; int frames = bytes / frameSize; int expectedBytes = frames * frameSize; @@ -431,7 +440,7 @@ namespace Discord.Net.WebSockets int encodedLength = _encoder.EncodeFrame(data, pos, _encodingBuffer); //TODO: Handle Encryption - if (_mode == "xsalsa20_poly1305") + if (_isEncrypted) { } @@ -448,16 +457,16 @@ namespace Discord.Net.WebSockets } } - if (_logLevel >= LogMessageSeverity.Verbose) - RaiseOnLog(LogMessageSeverity.Verbose, $"Queued {bytes} bytes for voice output."); + if (_logLevel >= LogMessageSeverity.Debug) + RaiseOnLog(LogMessageSeverity.Debug, $"Queued {bytes} bytes for voice output."); } public void ClearPCMFrames() { _isClearing = true; byte[] ignored; while (_sendQueue.TryDequeue(out ignored)) { } - if (_logLevel >= LogMessageSeverity.Verbose) - RaiseOnLog(LogMessageSeverity.Verbose, "Cleared the voice buffer."); + if (_logLevel >= LogMessageSeverity.Debug) + RaiseOnLog(LogMessageSeverity.Debug, "Cleared the voice buffer."); _isClearing = false; } diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.Events.cs b/src/Discord.Net/Net/WebSockets/WebSocket.Events.cs index dd91f81ce..70fab7836 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.Events.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.Events.cs @@ -2,13 +2,6 @@ namespace Discord.Net.WebSockets { - public class DisconnectedEventArgs : EventArgs - { - public readonly bool WasUnexpected; - public readonly Exception Error; - internal DisconnectedEventArgs(bool wasUnexpected, Exception error) { WasUnexpected = wasUnexpected; Error = error; } - } - internal partial class WebSocket { public event EventHandler Connected; diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index 446705a37..4644c396a 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -38,12 +38,14 @@ namespace Discord.Net.WebSockets protected readonly DiscordClient _client; protected readonly LogMessageSeverity _logLevel; - protected int _state; protected string _host; protected int _loginTimeout, _heartbeatInterval; private DateTime _lastHeartbeat; private Task _runTask; + public WebSocketState State => (WebSocketState)_state; + protected int _state; + protected ExceptionDispatchInfo _disconnectReason; private bool _wasDisconnectUnexpected; @@ -56,17 +58,45 @@ namespace Discord.Net.WebSockets _client = client; _logLevel = client.Config.LogLevel; _loginTimeout = client.Config.ConnectionTimeout; + _cancelToken = new CancellationToken(true); + _engine = new BuiltInWebSocketEngine(client.Config.WebSocketInterval); - _engine.ProcessMessage += (s, e) => + _engine.ProcessMessage += async (s, e) => { if (_logLevel >= LogMessageSeverity.Debug) RaiseOnLog(LogMessageSeverity.Debug, $"In: " + e.Message); - ProcessMessage(e.Message); + await ProcessMessage(e.Message); }; } + public async Task Reconnect(CancellationToken cancelToken) + { + try + { + await Task.Delay(_client.Config.ReconnectDelay, cancelToken).ConfigureAwait(false); + while (!cancelToken.IsCancellationRequested) + { + try + { + await Connect(_host, cancelToken).ConfigureAwait(false); + break; + } + catch (OperationCanceledException) { throw; } + catch (Exception ex) + { + RaiseOnLog(LogMessageSeverity.Error, $"DataSocket reconnect failed: {ex.GetBaseException().Message}"); + //Net is down? We can keep trying to reconnect until the user runs Disconnect() + await Task.Delay(_client.Config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); + } + } + } + catch (OperationCanceledException) { } + } protected virtual async Task Connect(string host, CancellationToken cancelToken) { + if (_state != (int)WebSocketState.Disconnected) + throw new InvalidOperationException("Client is already connected or connecting to the server."); + try { await Disconnect().ConfigureAwait(false); @@ -97,19 +127,19 @@ namespace Discord.Net.WebSockets => Connect(_host, _cancelToken);*/ public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); - protected async Task DisconnectInternal(Exception ex, bool isUnexpected = true, bool skipAwait = false) + protected Task DisconnectInternal(Exception ex, bool isUnexpected = true, bool skipAwait = false) { int oldState; bool hasWriterLock; //If in either connecting or connected state, get a lock by being the first to switch to disconnecting oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); - if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected + if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change if (!hasWriterLock) { oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected); - if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected + if (oldState == (int)WebSocketState.Disconnected) return TaskHelper.CompletedTask; //Already disconnected hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change } @@ -121,17 +151,9 @@ namespace Discord.Net.WebSockets } if (!skipAwait) - { - Task task = _runTask; - if (task != null) - await task.ConfigureAwait(false); - } - - if (hasWriterLock) - { - _state = (int)WebSocketState.Disconnected; - RaiseDisconnected(isUnexpected, ex); - } + return _runTask ?? TaskHelper.CompletedTask; + else + return TaskHelper.CompletedTask; } protected virtual async Task RunTasks() @@ -146,9 +168,8 @@ namespace Discord.Net.WebSockets bool wasUnexpected = _wasDisconnectUnexpected; _wasDisconnectUnexpected = false; - - await _engine.Disconnect().ConfigureAwait(false); - await Cleanup().ConfigureAwait(false); + + await Cleanup(wasUnexpected).ConfigureAwait(false); _runTask = null; } protected virtual Task[] Run() @@ -158,10 +179,12 @@ namespace Discord.Net.WebSockets .Concat(new Task[] { HeartbeatAsync(cancelToken) }) .ToArray(); } - protected virtual Task Cleanup() + protected virtual Task Cleanup(bool wasUnexpected) { _cancelTokenSource = null; - return TaskHelper.CompletedTask; + _state = (int)WebSocketState.Disconnected; + RaiseDisconnected(wasUnexpected, _disconnectReason?.SourceException); + return _engine.Disconnect(); } protected abstract Task ProcessMessage(string json);