diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index 75cdec7cb..ca781778c 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -75,8 +75,9 @@ namespace Discord public Users Users => _users; private readonly Users _users; - public CancellationToken CancelToken => _cancelToken.Token; - private CancellationTokenSource _cancelToken; + public CancellationToken CancelToken => _cancelToken; + private CancellationTokenSource _cancelTokenSource; + private CancellationToken _cancelToken; /// Initializes a new instance of the DiscordClient class. public DiscordClient(DiscordClientConfig config = null) @@ -509,7 +510,7 @@ namespace Discord if (_config.EnableVoice) { string host = "wss://" + data.Endpoint.Split(':')[0]; - await _voiceSocket.Login(host, data.GuildId, _currentUserId, _dataSocket.SessionId, data.Token).ConfigureAwait(false); + await _voiceSocket.Login(host, data.GuildId, _currentUserId, _dataSocket.SessionId, data.Token, _cancelToken).ConfigureAwait(false); } } break; @@ -576,7 +577,8 @@ namespace Discord try { _disconnectedEvent.Reset(); - _cancelToken = new CancellationTokenSource(); + _cancelTokenSource = new CancellationTokenSource(); + _cancelToken = _cancelTokenSource.Token; _state = (int)DiscordClientState.Connecting; _api.Token = token; @@ -584,14 +586,14 @@ namespace Discord if (_config.LogLevel >= LogMessageSeverity.Verbose) RaiseOnLog(LogMessageSeverity.Verbose, LogMessageSource.Authentication, $"Websocket endpoint: {url}"); - await _dataSocket.Login(url, token).ConfigureAwait(false); + await _dataSocket.Login(url, token, _cancelToken).ConfigureAwait(false); _runTask = RunTasks(); try { //Cancel if either Disconnect is called, data socket errors or timeout is reached - var cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken.Token, _dataSocket.CancelToken).Token; + var cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken, _dataSocket.CancelToken).Token; if (!_connectedEvent.Wait(_config.ConnectionTimeout, cancelToken)) throw new Exception("Operation timed out."); } @@ -638,7 +640,7 @@ namespace Discord { _wasDisconnectUnexpected = isUnexpected; _disconnectReason = ExceptionDispatchInfo.Capture(ex); - _cancelToken.Cancel(); + _cancelTokenSource.Cancel(); } if (!skipAwait) @@ -730,7 +732,7 @@ namespace Discord //Experimental private Task MessageQueueLoop() { - var cancelToken = _cancelToken.Token; + var cancelToken = _cancelToken; int interval = _config.MessageQueueInterval; return Task.Run(async () => diff --git a/src/Discord.Net/Net/WebSockets/DataWebSocket.cs b/src/Discord.Net/Net/WebSockets/DataWebSocket.cs index f8cb3db3e..883401dcb 100644 --- a/src/Discord.Net/Net/WebSockets/DataWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/DataWebSocket.cs @@ -1,6 +1,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; +using System.Threading; using System.Threading.Tasks; namespace Discord.Net.WebSockets @@ -18,9 +19,9 @@ namespace Discord.Net.WebSockets { } - public async Task Login(string host, string token) + public async Task Login(string host, string token, CancellationToken cancelToken) { - await base.Connect(host); + await base.Connect(host, cancelToken); Commands.Login msg = new Commands.Login(); msg.Payload.Token = token; diff --git a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs index 9c484b433..8c13f2bde 100644 --- a/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/VoiceWebSocket.cs @@ -52,14 +52,14 @@ namespace Discord.Net.WebSockets _targetAudioBufferLength = client.Config.VoiceBufferLength / 20; //20 ms frames } - public Task Login(string host, string serverId, string userId, string sessionId, string token) + public Task Login(string host, string serverId, string userId, string sessionId, string token, CancellationToken cancelToken) { _serverId = serverId; _userId = userId; _sessionId = sessionId; _token = token; - return base.Connect(host); + return base.Connect(host, cancelToken); } protected override Task[] Run() @@ -110,8 +110,7 @@ namespace Discord.Net.WebSockets private async Task ReceiveVoiceAsync() { - var cancelSource = _cancelToken; - var cancelToken = cancelSource.Token; + var cancelToken = _cancelToken; await Task.Run(async () => { @@ -145,8 +144,8 @@ namespace Discord.Net.WebSockets #else private Task SendVoiceAsync() { - var cancelSource = _cancelToken; - var cancelToken = cancelSource.Token; + var cancelToken = _cancelToken; + return Task.Run(async () => { #endif @@ -239,7 +238,7 @@ namespace Discord.Net.WebSockets //Closes the UDP socket when _disconnectToken is triggered, since UDPClient doesn't allow passing a canceltoken private Task WatcherAsync() { - var cancelToken = _cancelToken.Token; + var cancelToken = _cancelToken; return cancelToken.Wait() .ContinueWith(_ => _udp.Close()); } @@ -387,7 +386,7 @@ namespace Discord.Net.WebSockets public void SendPCMFrames(byte[] data, int bytes) { - var cancelToken = _cancelToken.Token; + var cancelToken = _cancelToken; if (!_isReady || cancelToken == null) throw new InvalidOperationException("Not connected to a voice server."); if (bytes == 0) @@ -441,7 +440,7 @@ namespace Discord.Net.WebSockets Buffer.BlockCopy(_encodingBuffer, 0, payload, 0, encodedLength); //Wait until the queue has a spot open - _sendQueueWait.Wait(_cancelToken.Token); + _sendQueueWait.Wait(_cancelToken); _sendQueue.Enqueue(payload); if (_sendQueue.Count >= _targetAudioBufferLength) _sendQueueWait.Reset(); diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index 793908a7a..446705a37 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -47,8 +47,9 @@ namespace Discord.Net.WebSockets protected ExceptionDispatchInfo _disconnectReason; private bool _wasDisconnectUnexpected; - public CancellationToken CancelToken => _cancelToken.Token; - protected CancellationTokenSource _cancelToken; + public CancellationToken CancelToken => _cancelToken; + private CancellationTokenSource _cancelTokenSource; + protected CancellationToken _cancelToken; public WebSocket(DiscordClient client) { @@ -64,20 +65,18 @@ namespace Discord.Net.WebSockets }; } - protected virtual async Task Connect(string host) + protected virtual async Task Connect(string host, CancellationToken cancelToken) { - if (_state != (int)WebSocketState.Disconnected) - throw new InvalidOperationException("Client is already connected or connecting to the server."); - try { await Disconnect().ConfigureAwait(false); _state = (int)WebSocketState.Connecting; - _cancelToken = new CancellationTokenSource(); + _cancelTokenSource = new CancellationTokenSource(); + _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, cancelToken).Token; - await _engine.Connect(host, _cancelToken.Token).ConfigureAwait(false); + await _engine.Connect(host, _cancelToken).ConfigureAwait(false); _host = host; _lastHeartbeat = DateTime.UtcNow; @@ -94,8 +93,8 @@ namespace Discord.Net.WebSockets _state = (int)WebSocketState.Connected; RaiseConnected(); } - public Task Reconnect() - => Connect(_host); + /*public Task Reconnect(CancellationToken cancelToken) + => Connect(_host, _cancelToken);*/ public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); protected async Task DisconnectInternal(Exception ex, bool isUnexpected = true, bool skipAwait = false) @@ -118,7 +117,7 @@ namespace Discord.Net.WebSockets { _wasDisconnectUnexpected = isUnexpected; _disconnectReason = ExceptionDispatchInfo.Capture(ex); - _cancelToken.Cancel(); + _cancelTokenSource.Cancel(); } if (!skipAwait) @@ -154,12 +153,16 @@ namespace Discord.Net.WebSockets } protected virtual Task[] Run() { - var cancelToken = _cancelToken.Token; + var cancelToken = _cancelToken; return _engine.RunTasks(cancelToken) .Concat(new Task[] { HeartbeatAsync(cancelToken) }) .ToArray(); } - protected virtual Task Cleanup() { return TaskHelper.CompletedTask; } + protected virtual Task Cleanup() + { + _cancelTokenSource = null; + return TaskHelper.CompletedTask; + } protected abstract Task ProcessMessage(string json); protected abstract object GetKeepAlive();