diff --git a/src/Discord.Net/DiscordSimpleClient.Voice.cs b/src/Discord.Net/DiscordSimpleClient.Voice.cs index a62b8b612..1c66b9af8 100644 --- a/src/Discord.Net/DiscordSimpleClient.Voice.cs +++ b/src/Discord.Net/DiscordSimpleClient.Voice.cs @@ -24,22 +24,10 @@ namespace Discord if (channelId == null) throw new ArgumentNullException(nameof(channelId)); await _voiceSocket.Disconnect().ConfigureAwait(false); - _voiceSocket.SetChannel(_voiceServerId, channelId); + + await _voiceSocket.SetChannel(_voiceServerId, channelId).ConfigureAwait(false); _dataSocket.SendJoinVoice(_voiceServerId, channelId); - - CancellationTokenSource tokenSource = new CancellationTokenSource(); - try - { - await Task.Run(() => _voiceSocket.WaitForConnection(tokenSource.Token)) - .Timeout(_config.ConnectionTimeout, tokenSource) - .ConfigureAwait(false); - } - catch (TimeoutException) - { - tokenSource.Cancel(); - await _voiceSocket.Disconnect().ConfigureAwait(false); - throw; - } + await _voiceSocket.WaitForConnection(_config.ConnectionTimeout); } /*async Task IDiscordVoiceClient.Disconnect() diff --git a/src/Discord.Net/WebSockets/Data/DataWebSocket.cs b/src/Discord.Net/WebSockets/Data/DataWebSocket.cs index 5011ed99a..f4d8af7dd 100644 --- a/src/Discord.Net/WebSockets/Data/DataWebSocket.cs +++ b/src/Discord.Net/WebSockets/Data/DataWebSocket.cs @@ -19,7 +19,8 @@ namespace Discord.WebSockets.Data public async Task Login(string token) { - await Connect().ConfigureAwait(false); + await BeginConnect().ConfigureAwait(false); + await Start().ConfigureAwait(false); LoginCommand msg = new LoginCommand(); msg.Payload.Token = token; @@ -29,7 +30,9 @@ namespace Discord.WebSockets.Data private async Task Redirect(string server) { await DisconnectInternal(isUnexpected: false).ConfigureAwait(false); - await Connect().ConfigureAwait(false); + + await BeginConnect().ConfigureAwait(false); + await Start().ConfigureAwait(false); var resumeMsg = new ResumeCommand(); resumeMsg.Payload.SessionId = _sessionId; @@ -87,7 +90,7 @@ namespace Discord.WebSockets.Data } RaiseReceivedEvent(msg.Type, token); if (msg.Type == "READY" || msg.Type == "RESUMED") - CompleteConnect(); + EndConnect(); } break; case 7: //Redirect diff --git a/src/Discord.Net/WebSockets/Voice/VoiceWebSocket.cs b/src/Discord.Net/WebSockets/Voice/VoiceWebSocket.cs index f05b70e40..ffcdfe87b 100644 --- a/src/Discord.Net/WebSockets/Voice/VoiceWebSocket.cs +++ b/src/Discord.Net/WebSockets/Voice/VoiceWebSocket.cs @@ -61,14 +61,16 @@ namespace Discord.WebSockets.Voice _encoder = new OpusEncoder(48000, 1, 20, Opus.Application.Audio); } - public void SetChannel(string serverId, string channelId) + public Task SetChannel(string serverId, string channelId) { _serverId = serverId; _channelId = channelId; + + return base.BeginConnect(); } public async Task Login(string userId, string sessionId, string token, CancellationToken cancelToken) { - if ((WebSocketState)_state != WebSocketState.Disconnected) + if ((WebSocketState)_state == WebSocketState.Connected) { //Adjust the host and tell the system to reconnect await DisconnectInternal(new Exception("Server transfer occurred."), isUnexpected: false); @@ -79,7 +81,7 @@ namespace Discord.WebSockets.Voice _sessionId = sessionId; _token = token; - await Connect().ConfigureAwait(false); + await Start().ConfigureAwait(false); } public async Task Reconnect() { @@ -91,7 +93,7 @@ namespace Discord.WebSockets.Voice { try { - await Connect().ConfigureAwait(false); + await Start().ConfigureAwait(false); break; } catch (OperationCanceledException) { throw; } @@ -245,7 +247,7 @@ namespace Discord.WebSockets.Voice int port = packet[68] | packet[69] << 8; string ip = Encoding.ASCII.GetString(packet, 4, 70 - 6).TrimEnd('\0'); - CompleteConnect(); + EndConnect(); var login2 = new Login2Command(); login2.Payload.Protocol = "udp"; @@ -599,9 +601,20 @@ namespace Discord.WebSockets.Voice { _sendQueueEmptyWait.Wait(_cancelToken); } - public void WaitForConnection(CancellationToken cancelToken) + public Task WaitForConnection(int timeout) { - _connectedEvent.Wait(cancelToken); + return Task.Run(() => + { + try + { + if (!_connectedEvent.Wait(timeout, _cancelToken)) + throw new TimeoutException(); + } + catch (OperationCanceledException ex) + { + ThrowError(); + } + }); } } } diff --git a/src/Discord.Net/WebSockets/WebSocket.WebSocketSharp.cs b/src/Discord.Net/WebSockets/WebSocket.WebSocketSharp.cs index 942239f82..80b5aebbc 100644 --- a/src/Discord.Net/WebSockets/WebSocket.WebSocketSharp.cs +++ b/src/Discord.Net/WebSockets/WebSocket.WebSocketSharp.cs @@ -41,14 +41,14 @@ namespace Discord.WebSockets _webSocket.OnError += async (s, e) => { _parent.RaiseOnLog(LogMessageSeverity.Error, $"Websocket Error: {e.Message}"); - await _parent.DisconnectInternal(e.Exception, isUnexpected: true, skipAwait: true); + await _parent.DisconnectInternal(e.Exception, skipAwait: true); }; _webSocket.OnClose += async (s, e) => { string code = e.WasClean ? e.Code.ToString() : "Unexpected"; string reason = e.Reason != "" ? e.Reason : "No Reason"; Exception ex = new Exception($"Got Close Message ({code}): {reason}"); - await _parent.DisconnectInternal(ex, isUnexpected: !e.WasClean, skipAwait: true); + await _parent.DisconnectInternal(ex, skipAwait: true); }; _webSocket.Log.Output = (e, m) => { }; //Dont let websocket-sharp print to console _webSocket.Connect(); @@ -59,7 +59,12 @@ namespace Discord.WebSockets { string ignored; while (_sendQueue.TryDequeue(out ignored)) { } - _webSocket.Close(); + + var socket = _webSocket; + _webSocket = null; + if (socket != null) + socket.Close(); + return TaskHelper.CompletedTask; } @@ -77,7 +82,7 @@ namespace Discord.WebSockets { try { - while (_webSocket.IsAlive && !cancelToken.IsCancellationRequested) + while (!cancelToken.IsCancellationRequested) { string json; while (_sendQueue.TryDequeue(out json)) diff --git a/src/Discord.Net/WebSockets/WebSocket.cs b/src/Discord.Net/WebSockets/WebSocket.cs index 9ee9a6141..ac38a36e6 100644 --- a/src/Discord.Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/WebSockets/WebSocket.cs @@ -79,12 +79,9 @@ namespace Discord.WebSockets }; } - protected virtual async Task Connect() + protected async Task BeginConnect() { - if (_state != (int)WebSocketState.Disconnected) - throw new InvalidOperationException("Client is already connected or connecting to the server."); - - try + try { await Disconnect().ConfigureAwait(false); @@ -93,19 +90,34 @@ namespace Discord.WebSockets _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; + _state = (int)WebSocketState.Connecting; + } + catch (Exception ex) + { + await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); + throw; + } + } + + protected virtual async Task Start() + { + try + { + if (_state != (int)WebSocketState.Connecting) + throw new InvalidOperationException("Socket is in the wrong state."); + _lastHeartbeat = DateTime.UtcNow; await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); - _state = (int)WebSocketState.Connecting; _runTask = RunTasks(); } catch (Exception ex) { await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); - throw; //Dont handle this exception internally, send up it upwards + throw; } } - protected void CompleteConnect() + protected void EndConnect() { _state = (int)WebSocketState.Connected; _connectedEvent.Set();