diff --git a/src/Discord.Net/DiscordWebSocket.cs b/src/Discord.Net/DiscordWebSocket.cs index 7d68ea742..0325c659a 100644 --- a/src/Discord.Net/DiscordWebSocket.cs +++ b/src/Discord.Net/DiscordWebSocket.cs @@ -15,51 +15,59 @@ namespace Discord { private const int ReceiveChunkSize = 4096; private const int SendChunkSize = 4096; - private const int ReadyTimeout = 5000; //Max time in milliseconds between connecting to Discord and receiving a READY event + private const int ReadyTimeout = 2500; //Max time in milliseconds between connecting to Discord and receiving a READY event private volatile ClientWebSocket _webSocket; - private volatile CancellationTokenSource _cancelToken; + private volatile CancellationTokenSource _disconnectToken; private volatile Task _tasks; private ConcurrentQueue _sendQueue; private int _heartbeatInterval; private DateTime _lastHeartbeat; - private AutoResetEvent _connectWaitOnLogin, _connectWaitOnLogin2; + private ManualResetEventSlim _connectWaitOnLogin, _connectWaitOnLogin2; private bool _isConnected; + public DiscordWebSocket() + { + _connectWaitOnLogin = new ManualResetEventSlim(false); + _connectWaitOnLogin2 = new ManualResetEventSlim(false); + + _sendQueue = new ConcurrentQueue(); + } + public async Task ConnectAsync(string url, bool autoLogin) { await DisconnectAsync(); - _connectWaitOnLogin = new AutoResetEvent(false); - _connectWaitOnLogin2 = new AutoResetEvent(false); - _sendQueue = new ConcurrentQueue(); + _disconnectToken = new CancellationTokenSource(); + var cancelToken = _disconnectToken.Token; _webSocket = new ClientWebSocket(); _webSocket.Options.KeepAliveInterval = TimeSpan.Zero; - - _cancelToken = new CancellationTokenSource(); - var cancelToken = _cancelToken.Token; - await _webSocket.ConnectAsync(new Uri(url), cancelToken); + _tasks = Task.WhenAll( await Task.Factory.StartNew(ReceiveAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default), - await Task.Factory.StartNew(SendAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default) - ).ContinueWith(x => + await Task.Factory.StartNew(SendAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default)) + .ContinueWith(x => { //Do not clean up until both tasks have ended _heartbeatInterval = 0; _lastHeartbeat = DateTime.MinValue; _webSocket.Dispose(); _webSocket = null; - _cancelToken.Dispose(); - _cancelToken = null; + _disconnectToken.Dispose(); + _disconnectToken = null; _tasks = null; + //Clear send queue + byte[] ignored; + while (_sendQueue.TryDequeue(out ignored)) { } + if (_isConnected) { _isConnected = false; RaiseDisconnected(); - } + } }); if (autoLogin) @@ -67,6 +75,11 @@ namespace Discord } public void Login() { + var cancelToken = _disconnectToken.Token; + + _connectWaitOnLogin.Reset(); + _connectWaitOnLogin2.Reset(); + WebSocketCommands.Login msg = new WebSocketCommands.Login(); msg.Payload.Token = Http.Token; msg.Payload.Properties["$os"] = ""; @@ -74,11 +87,18 @@ namespace Discord msg.Payload.Properties["$device"] = "Discord.Net"; msg.Payload.Properties["$referrer"] = ""; msg.Payload.Properties["$referring_domain"] = ""; - SendMessage(msg, _cancelToken.Token); + SendMessage(msg, _disconnectToken.Token); - if (!_connectWaitOnLogin.WaitOne(ReadyTimeout)) //Pre-Event - throw new Exception("No reply from Discord server"); - _connectWaitOnLogin2.WaitOne(); //Post-Event + try + { + if (!_connectWaitOnLogin.Wait(ReadyTimeout, cancelToken)) //Waiting on READY message + throw new Exception("No reply from Discord server"); + } + catch (OperationCanceledException) + { + throw new InvalidOperationException("Bad Token"); + } + _connectWaitOnLogin2.Wait(cancelToken); //Waiting on READY handler _isConnected = true; RaiseConnected(); @@ -87,14 +107,14 @@ namespace Discord { if (_tasks != null) { - _cancelToken.Cancel(); + _disconnectToken.Cancel(); await _tasks; } } private async Task ReceiveAsync() { - var cancelToken = _cancelToken.Token; + var cancelToken = _disconnectToken.Token; var buffer = new byte[ReceiveChunkSize]; var builder = new StringBuilder(); @@ -105,7 +125,7 @@ namespace Discord WebSocketReceiveResult result; do { - result = await _webSocket.ReceiveAsync(new ArraySegment(buffer), _cancelToken.Token); + result = await _webSocket.ReceiveAsync(new ArraySegment(buffer), _disconnectToken.Token); if (result.MessageType == WebSocketMessageType.Close) { @@ -126,8 +146,8 @@ namespace Discord { var payload = (msg.Payload as JToken).ToObject(); _heartbeatInterval = payload.HeartbeatInterval; - SendMessage(new WebSocketCommands.UpdateStatus(), cancelToken); - SendMessage(new WebSocketCommands.KeepAlive(), cancelToken); + QueueMessage(new WebSocketCommands.UpdateStatus()); + QueueMessage(new WebSocketCommands.KeepAlive()); _connectWaitOnLogin.Set(); //Pre-Event } RaiseGotEvent(msg.Type, msg.Payload as JToken); @@ -143,12 +163,12 @@ namespace Discord } } catch { } - finally { _cancelToken.Cancel(); } + finally { _disconnectToken.Cancel(); } } private async Task SendAsync() { - var cancelToken = _cancelToken.Token; + var cancelToken = _disconnectToken.Token; try { byte[] bytes; @@ -159,41 +179,46 @@ namespace Discord DateTime now = DateTime.UtcNow; if ((now - _lastHeartbeat).TotalMilliseconds > _heartbeatInterval) { - SendMessage(new WebSocketCommands.KeepAlive(), cancelToken); + await SendMessage(new WebSocketCommands.KeepAlive(), cancelToken); _lastHeartbeat = now; } } while (_sendQueue.TryDequeue(out bytes)) - { - var frameCount = (int)Math.Ceiling((double)bytes.Length / SendChunkSize); - - int offset = 0; - for (var i = 0; i < frameCount; i++, offset += SendChunkSize) - { - bool isLast = i == (frameCount - 1); - - int count; - if (isLast) - count = bytes.Length - (i * SendChunkSize); - else - count = SendChunkSize; - - await _webSocket.SendAsync(new ArraySegment(bytes, offset, count), WebSocketMessageType.Text, isLast, cancelToken); - } - } + await SendMessage(bytes, cancelToken); await Task.Delay(100); } } catch { } - finally { _cancelToken.Cancel(); } + finally { _disconnectToken.Cancel(); } } - private void SendMessage(object frame, CancellationToken token) + private void QueueMessage(object message) { - var bytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(frame)); + var bytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message)); _sendQueue.Enqueue(bytes); } + private Task SendMessage(object message, CancellationToken cancelToken) + => SendMessage(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message)), cancelToken); + private async Task SendMessage(byte[] message, CancellationToken cancelToken) + { + var frameCount = (int)Math.Ceiling((double)message.Length / SendChunkSize); + + int offset = 0; + for (var i = 0; i < frameCount; i++, offset += SendChunkSize) + { + bool isLast = i == (frameCount - 1); + + int count; + if (isLast) + count = message.Length - (i * SendChunkSize); + else + count = SendChunkSize; + + await _webSocket.SendAsync(new ArraySegment(message, offset, count), WebSocketMessageType.Text, isLast, cancelToken); + } + } + #region IDisposable Support private bool _isDisposed = false;