Browse Source

General WebSocket improvements

tags/docs-0.9
Brandon Smith 9 years ago
parent
commit
02e718a96e
1 changed files with 72 additions and 47 deletions
  1. +72
    -47
      src/Discord.Net/DiscordWebSocket.cs

+ 72
- 47
src/Discord.Net/DiscordWebSocket.cs View File

@@ -15,51 +15,59 @@ namespace Discord
{
private const int ReceiveChunkSize = 4096;
private const int SendChunkSize = 4096;
private const int ReadyTimeout = 5000; //Max time in milliseconds between connecting to Discord and receiving a READY event
private const int ReadyTimeout = 2500; //Max time in milliseconds between connecting to Discord and receiving a READY event

private volatile ClientWebSocket _webSocket;
private volatile CancellationTokenSource _cancelToken;
private volatile CancellationTokenSource _disconnectToken;
private volatile Task _tasks;
private ConcurrentQueue<byte[]> _sendQueue;
private int _heartbeatInterval;
private DateTime _lastHeartbeat;
private AutoResetEvent _connectWaitOnLogin, _connectWaitOnLogin2;
private ManualResetEventSlim _connectWaitOnLogin, _connectWaitOnLogin2;
private bool _isConnected;

public DiscordWebSocket()
{
_connectWaitOnLogin = new ManualResetEventSlim(false);
_connectWaitOnLogin2 = new ManualResetEventSlim(false);
_sendQueue = new ConcurrentQueue<byte[]>();
}

public async Task ConnectAsync(string url, bool autoLogin)
{
await DisconnectAsync();

_connectWaitOnLogin = new AutoResetEvent(false);
_connectWaitOnLogin2 = new AutoResetEvent(false);
_sendQueue = new ConcurrentQueue<byte[]>();
_disconnectToken = new CancellationTokenSource();
var cancelToken = _disconnectToken.Token;

_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.Zero;

_cancelToken = new CancellationTokenSource();
var cancelToken = _cancelToken.Token;

await _webSocket.ConnectAsync(new Uri(url), cancelToken);

_tasks = Task.WhenAll(
await Task.Factory.StartNew(ReceiveAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default),
await Task.Factory.StartNew(SendAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default)
).ContinueWith(x =>
await Task.Factory.StartNew(SendAsync, cancelToken, TaskCreationOptions.LongRunning, TaskScheduler.Default))
.ContinueWith(x =>
{
//Do not clean up until both tasks have ended
_heartbeatInterval = 0;
_lastHeartbeat = DateTime.MinValue;
_webSocket.Dispose();
_webSocket = null;
_cancelToken.Dispose();
_cancelToken = null;
_disconnectToken.Dispose();
_disconnectToken = null;
_tasks = null;

//Clear send queue
byte[] ignored;
while (_sendQueue.TryDequeue(out ignored)) { }

if (_isConnected)
{
_isConnected = false;
RaiseDisconnected();
}
}
});

if (autoLogin)
@@ -67,6 +75,11 @@ namespace Discord
}
public void Login()
{
var cancelToken = _disconnectToken.Token;

_connectWaitOnLogin.Reset();
_connectWaitOnLogin2.Reset();

WebSocketCommands.Login msg = new WebSocketCommands.Login();
msg.Payload.Token = Http.Token;
msg.Payload.Properties["$os"] = "";
@@ -74,11 +87,18 @@ namespace Discord
msg.Payload.Properties["$device"] = "Discord.Net";
msg.Payload.Properties["$referrer"] = "";
msg.Payload.Properties["$referring_domain"] = "";
SendMessage(msg, _cancelToken.Token);
SendMessage(msg, _disconnectToken.Token);

if (!_connectWaitOnLogin.WaitOne(ReadyTimeout)) //Pre-Event
throw new Exception("No reply from Discord server");
_connectWaitOnLogin2.WaitOne(); //Post-Event
try
{
if (!_connectWaitOnLogin.Wait(ReadyTimeout, cancelToken)) //Waiting on READY message
throw new Exception("No reply from Discord server");
}
catch (OperationCanceledException)
{
throw new InvalidOperationException("Bad Token");
}
_connectWaitOnLogin2.Wait(cancelToken); //Waiting on READY handler

_isConnected = true;
RaiseConnected();
@@ -87,14 +107,14 @@ namespace Discord
{
if (_tasks != null)
{
_cancelToken.Cancel();
_disconnectToken.Cancel();
await _tasks;
}
}

private async Task ReceiveAsync()
{
var cancelToken = _cancelToken.Token;
var cancelToken = _disconnectToken.Token;
var buffer = new byte[ReceiveChunkSize];
var builder = new StringBuilder();

@@ -105,7 +125,7 @@ namespace Discord
WebSocketReceiveResult result;
do
{
result = await _webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), _cancelToken.Token);
result = await _webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), _disconnectToken.Token);

if (result.MessageType == WebSocketMessageType.Close)
{
@@ -126,8 +146,8 @@ namespace Discord
{
var payload = (msg.Payload as JToken).ToObject<WebSocketEvents.Ready>();
_heartbeatInterval = payload.HeartbeatInterval;
SendMessage(new WebSocketCommands.UpdateStatus(), cancelToken);
SendMessage(new WebSocketCommands.KeepAlive(), cancelToken);
QueueMessage(new WebSocketCommands.UpdateStatus());
QueueMessage(new WebSocketCommands.KeepAlive());
_connectWaitOnLogin.Set(); //Pre-Event
}
RaiseGotEvent(msg.Type, msg.Payload as JToken);
@@ -143,12 +163,12 @@ namespace Discord
}
}
catch { }
finally { _cancelToken.Cancel(); }
finally { _disconnectToken.Cancel(); }
}

private async Task SendAsync()
{
var cancelToken = _cancelToken.Token;
var cancelToken = _disconnectToken.Token;
try
{
byte[] bytes;
@@ -159,41 +179,46 @@ namespace Discord
DateTime now = DateTime.UtcNow;
if ((now - _lastHeartbeat).TotalMilliseconds > _heartbeatInterval)
{
SendMessage(new WebSocketCommands.KeepAlive(), cancelToken);
await SendMessage(new WebSocketCommands.KeepAlive(), cancelToken);
_lastHeartbeat = now;
}
}
while (_sendQueue.TryDequeue(out bytes))
{
var frameCount = (int)Math.Ceiling((double)bytes.Length / SendChunkSize);

int offset = 0;
for (var i = 0; i < frameCount; i++, offset += SendChunkSize)
{
bool isLast = i == (frameCount - 1);

int count;
if (isLast)
count = bytes.Length - (i * SendChunkSize);
else
count = SendChunkSize;

await _webSocket.SendAsync(new ArraySegment<byte>(bytes, offset, count), WebSocketMessageType.Text, isLast, cancelToken);
}
}
await SendMessage(bytes, cancelToken);
await Task.Delay(100);
}
}
catch { }
finally { _cancelToken.Cancel(); }
finally { _disconnectToken.Cancel(); }
}

private void SendMessage(object frame, CancellationToken token)
private void QueueMessage(object message)
{
var bytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(frame));
var bytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message));
_sendQueue.Enqueue(bytes);
}

private Task SendMessage(object message, CancellationToken cancelToken)
=> SendMessage(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message)), cancelToken);
private async Task SendMessage(byte[] message, CancellationToken cancelToken)
{
var frameCount = (int)Math.Ceiling((double)message.Length / SendChunkSize);

int offset = 0;
for (var i = 0; i < frameCount; i++, offset += SendChunkSize)
{
bool isLast = i == (frameCount - 1);

int count;
if (isLast)
count = message.Length - (i * SendChunkSize);
else
count = SendChunkSize;

await _webSocket.SendAsync(new ArraySegment<byte>(message, offset, count), WebSocketMessageType.Text, isLast, cancelToken);
}
}

#region IDisposable Support
private bool _isDisposed = false;



Loading…
Cancel
Save