@@ -13,7 +13,6 @@ using System.Collections.Immutable;
using System.Linq;
using System.Linq;
using System.Threading;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks;
using System.Diagnostics;
namespace Discord
namespace Discord
{
{
@@ -51,15 +50,17 @@ namespace Discord
private readonly bool _enablePreUpdateEvents;
private readonly bool _enablePreUpdateEvents;
private readonly int _largeThreshold;
private readonly int _largeThreshold;
private readonly int _totalShards;
private readonly int _totalShards;
private ConcurrentHashSet<ulong> _dmChannels;
private string _sessionId;
private string _sessionId;
private int _lastSeq;
private int _lastSeq;
private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask;
private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _heartbeatC ancelToken;
private Task _heartbeatTask, _reconnectTask;
private CancellationTokenSource _c ancelToken;
private Task _heartbeatTask, _guildDownloadTask, _ reconnectTask;
private long _heartbeatTime;
private long _heartbeatTime;
private bool _isReconnecting;
private bool _isReconnecting;
private int _unavailableGuilds;
private long _lastGuildAvailableTime;
/// <summary> Gets the shard if of this client. </summary>
/// <summary> Gets the shard if of this client. </summary>
public int ShardId { get; }
public int ShardId { get; }
@@ -74,15 +75,7 @@ namespace Discord
internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser;
internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser;
internal IReadOnlyCollection<CachedGuild> Guilds => DataStore.Guilds;
internal IReadOnlyCollection<CachedGuild> Guilds => DataStore.Guilds;
internal IReadOnlyCollection<CachedDMChannel> DMChannels
{
get
{
var dmChannels = _dmChannels;
var store = DataStore;
return dmChannels.Select(x => store.GetChannel(x) as CachedDMChannel).Where(x => x != null).ToReadOnlyCollection(dmChannels);
}
}
internal IReadOnlyCollection<CachedDMChannel> DMChannels => DataStore.DMChannels;
internal IReadOnlyCollection<VoiceRegion> VoiceRegions => _voiceRegions.ToReadOnlyCollection();
internal IReadOnlyCollection<VoiceRegion> VoiceRegions => _voiceRegions.ToReadOnlyCollection();
/// <summary> Creates a new REST/WebSocket discord client. </summary>
/// <summary> Creates a new REST/WebSocket discord client. </summary>
@@ -132,7 +125,6 @@ namespace Discord
_voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>();
_voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>();
_largeGuilds = new ConcurrentQueue<ulong>();
_largeGuilds = new ConcurrentQueue<ulong>();
_dmChannels = new ConcurrentHashSet<ulong>();
}
}
protected override async Task OnLoginAsync()
protected override async Task OnLoginAsync()
@@ -169,10 +161,11 @@ namespace Discord
try
try
{
{
_connectTask = new TaskCompletionSource<bool>();
_connectTask = new TaskCompletionSource<bool>();
_heartbeatC ancelToken = new CancellationTokenSource();
_c ancelToken = new CancellationTokenSource();
await ApiClient.ConnectAsync().ConfigureAwait(false);
await ApiClient.ConnectAsync().ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected");
await _gatewayLogger.InfoAsync("Connected");
}
}
@@ -203,9 +196,24 @@ namespace Discord
ConnectionState = ConnectionState.Disconnecting;
ConnectionState = ConnectionState.Disconnecting;
await _gatewayLogger.InfoAsync("Disconnecting");
await _gatewayLogger.InfoAsync("Disconnecting");
try { _heartbeatCancelToken.Cancel(); } catch { }
//Signal tasks to complete
try { _cancelToken.Cancel(); } catch { }
//Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false);
await ApiClient.DisconnectAsync().ConfigureAwait(false);
await _heartbeatTask.ConfigureAwait(false);
//Wait for tasks to complete
var heartbeatTask = _heartbeatTask;
if (heartbeatTask != null)
await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null;
var guildDownloadTask = _guildDownloadTask;
if (guildDownloadTask != null)
await guildDownloadTask.ConfigureAwait(false);
_guildDownloadTask = null;
//Clear large guild queue
while (_largeGuilds.TryDequeue(out guildId)) { }
while (_largeGuilds.TryDequeue(out guildId)) { }
ConnectionState = ConnectionState.Disconnected;
ConnectionState = ConnectionState.Disconnected;
@@ -216,22 +224,21 @@ namespace Discord
private async Task StartReconnectAsync()
private async Task StartReconnectAsync()
{
{
//TODO: Is this thread-safe?
//TODO: Is this thread-safe?
while (true)
await _log.InfoAsync("Debug", "Trying to reconnect...").ConfigureAwait(false);
if (_reconnectTask != null) return;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
{
if (_reconnectTask != null) return;
if (_reconnectTask != null) return;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (_reconnectTask != null) return;
_isReconnecting = true;
_reconnectTask = ReconnectInternalAsync();
}
finally { _connectionLock.Release(); }
_isReconnecting = true;
_reconnectTask = ReconnectInternalAsync();
}
}
finally { _connectionLock.Release(); }
}
}
private async Task ReconnectInternalAsync()
private async Task ReconnectInternalAsync()
{
{
await _log.InfoAsync("Debug", "Reconnecting...").ConfigureAwait(false);
try
try
{
{
int nextReconnectDelay = 1000;
int nextReconnectDelay = 1000;
@@ -255,13 +262,18 @@ namespace Discord
catch (Exception ex)
catch (Exception ex)
{
{
await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false);
await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false);
} }
}
}
}
}
finally
finally
{
{
_isReconnecting = false;
_reconnectTask = null;
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
_isReconnecting = false;
_reconnectTask = null;
}
finally { _connectionLock.Release(); }
}
}
}
}
@@ -318,20 +330,22 @@ namespace Discord
{
{
return Task.FromResult<IReadOnlyCollection<IDMChannel>>(DMChannels);
return Task.FromResult<IReadOnlyCollection<IDMChannel>>(DMChannels);
}
}
internal CachedDMChannel AddDMChannel(API.Channel model, DataStore dataStore, ConcurrentHashSet<ulong> dmChannels )
internal CachedDMChannel AddDMChannel(API.Channel model, DataStore dataStore)
{
{
var recipient = GetOrAddUser(model.Recipient.Value, dataStore);
var recipient = GetOrAddUser(model.Recipient.Value, dataStore);
var channel = recipient.AddDMChannel(this , model);
dataStore.AddChannel(channel );
dmChannels.TryAdd(model.Id );
var channel = new CachedDMChannel(this, new CachedDMUser(recipient) , model);
recipient.AddRef( );
dataStore.AddDMChannel(channel );
return channel;
return channel;
}
}
internal CachedDMChannel RemoveDMChannel(ulong id)
internal CachedDMChannel RemoveDMChannel(ulong id)
{
{
var dmChannel = DataStore.RemoveChannel(id) as CachedDMChannel;
var recipient = dmChannel.Recipient;
recipient.RemoveDMChannel(id);
_dmChannels.TryRemove(id);
var dmChannel = DataStore.RemoveDMChannel(id);
if (dmChannel != null)
{
var recipient = dmChannel.Recipient;
recipient.User.RemoveRef(this);
}
return dmChannel;
return dmChannel;
}
}
@@ -345,13 +359,13 @@ namespace Discord
{
{
return Task.FromResult<IUser>(DataStore.Users.Where(x => x.Discriminator == discriminator && x.Username == username).FirstOrDefault());
return Task.FromResult<IUser>(DataStore.Users.Where(x => x.Discriminator == discriminator && x.Username == username).FirstOrDefault());
}
}
internal CachedPublic User GetOrAddUser(API.User model, DataStore dataStore)
internal CachedGlobal User GetOrAddUser(API.User model, DataStore dataStore)
{
{
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedPublic User(model));
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedGlobal User(model));
user.AddRef();
user.AddRef();
return user;
return user;
}
}
internal CachedPublic User RemoveUser(ulong id)
internal CachedGlobal User RemoveUser(ulong id)
{
{
return DataStore.RemoveUser(id);
return DataStore.RemoveUser(id);
}
}
@@ -425,7 +439,7 @@ namespace Discord
else
else
await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
_heartbeatTime = 0;
_heartbeatTime = 0;
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _heartbeatC ancelToken.Token);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _c ancelToken.Token);
}
}
break;
break;
case GatewayOpCode.Heartbeat:
case GatewayOpCode.Heartbeat:
@@ -439,12 +453,16 @@ namespace Discord
{
{
await _gatewayLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false);
await _gatewayLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false);
var latency = (int)(Environment.TickCount - _heartbeatTime);
_heartbeatTime = 0;
await _gatewayLogger.DebugAsync($"Latency = {latency} ms").ConfigureAwait(false);
Latency = latency;
var heartbeatTime = _heartbeatTime;
if (heartbeatTime != 0)
{
var latency = (int)(Environment.TickCount - _heartbeatTime);
_heartbeatTime = 0;
await _gatewayLogger.VerboseAsync($"Latency = {latency} ms").ConfigureAwait(false);
Latency = latency;
await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false);
await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false);
}
}
}
break;
break;
case GatewayOpCode.InvalidSession:
case GatewayOpCode.InvalidSession:
@@ -475,21 +493,29 @@ namespace Discord
var data = (payload as JToken).ToObject<ReadyEvent>(_serializer);
var data = (payload as JToken).ToObject<ReadyEvent>(_serializer);
var dataStore = _dataStoreProvider(ShardId, _totalShards, data.Guilds.Length, data.PrivateChannels.Length);
var dataStore = _dataStoreProvider(ShardId, _totalShards, data.Guilds.Length, data.PrivateChannels.Length);
var dmChannels = new ConcurrentHashSet<ulong>();
var currentUser = new CachedSelfUser(this, data.User);
var currentUser = new CachedSelfUser(this, data.User);
int unavailableGuilds = 0;
//dataStore.GetOrAddUser(data.User.Id, _ => currentUser);
//dataStore.GetOrAddUser(data.User.Id, _ => currentUser);
for (int i = 0; i < data.Guilds.Length; i++)
for (int i = 0; i < data.Guilds.Length; i++)
AddGuild(data.Guilds[i], dataStore);
{
var model = data.Guilds[i];
AddGuild(model, dataStore);
if (model.Unavailable == true)
unavailableGuilds++;
}
for (int i = 0; i < data.PrivateChannels.Length; i++)
for (int i = 0; i < data.PrivateChannels.Length; i++)
AddDMChannel(data.PrivateChannels[i], dataStore, dmChannels);
AddDMChannel(data.PrivateChannels[i], dataStore);
_sessionId = data.SessionId;
_sessionId = data.SessionId;
_currentUser = currentUser;
_currentUser = currentUser;
_dmChannels = dmChannels;
_unavailableGuilds = unavailableGuilds;
_lastGuildAvailableTime = Environment.TickCount;
DataStore = dataStore;
DataStore = dataStore;
_guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token);
await Ready.RaiseAsync().ConfigureAwait(false);
await Ready.RaiseAsync().ConfigureAwait(false);
_connectTask.TrySetResult(true); //Signal the .Connect() call to complete
_connectTask.TrySetResult(true); //Signal the .Connect() call to complete
@@ -503,7 +529,10 @@ namespace Discord
var data = (payload as JToken).ToObject<ExtendedGuild>(_serializer);
var data = (payload as JToken).ToObject<ExtendedGuild>(_serializer);
if (data.Unavailable == false)
if (data.Unavailable == false)
{
type = "GUILD_AVAILABLE";
type = "GUILD_AVAILABLE";
_lastGuildAvailableTime = Environment.TickCount;
}
await _gatewayLogger.DebugAsync($"Received Dispatch ({type})").ConfigureAwait(false);
await _gatewayLogger.DebugAsync($"Received Dispatch ({type})").ConfigureAwait(false);
CachedGuild guild;
CachedGuild guild;
@@ -511,6 +540,7 @@ namespace Discord
{
{
guild = AddGuild(data, DataStore);
guild = AddGuild(data, DataStore);
await JoinedGuild.RaiseAsync(guild).ConfigureAwait(false);
await JoinedGuild.RaiseAsync(guild).ConfigureAwait(false);
await _gatewayLogger.InfoAsync($"Joined {data.Name}").ConfigureAwait(false);
}
}
else
else
{
{
@@ -526,7 +556,7 @@ namespace Discord
if (data.Unavailable != true)
if (data.Unavailable != true)
{
{
await _gatewayLogger.Info Async($"Connected to {data.Name}").ConfigureAwait(false);
await _gatewayLogger.Verbose Async($"Connected to {data.Name}").ConfigureAwait(false);
await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false);
await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false);
}
}
}
}
@@ -564,7 +594,7 @@ namespace Discord
member.User.RemoveRef(this);
member.User.RemoveRef(this);
await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false);
await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false);
await _gatewayLogger.Info Async($"Disconnected from {data.Name}").ConfigureAwait(false);
await _gatewayLogger.Verbose Async($"Disconnected from {data.Name}").ConfigureAwait(false);
if (data.Unavailable != true)
if (data.Unavailable != true)
{
{
await LeftGuild.RaiseAsync(guild).ConfigureAwait(false);
await LeftGuild.RaiseAsync(guild).ConfigureAwait(false);
@@ -587,7 +617,7 @@ namespace Discord
var data = (payload as JToken).ToObject<API.Channel>(_serializer);
var data = (payload as JToken).ToObject<API.Channel>(_serializer);
ICachedChannel channel = null;
ICachedChannel channel = null;
if (data.GuildId.IsSpecified )
if (!data.IsPrivate )
{
{
var guild = DataStore.GetGuild(data.GuildId.Value);
var guild = DataStore.GetGuild(data.GuildId.Value);
if (guild != null)
if (guild != null)
@@ -599,7 +629,7 @@ namespace Discord
}
}
}
}
else
else
channel = AddDMChannel(data, DataStore, _dmChannels );
channel = AddDMChannel(data, DataStore);
if (channel != null)
if (channel != null)
await ChannelCreated.RaiseAsync(channel).ConfigureAwait(false);
await ChannelCreated.RaiseAsync(channel).ConfigureAwait(false);
}
}
@@ -629,7 +659,7 @@ namespace Discord
ICachedChannel channel = null;
ICachedChannel channel = null;
var data = (payload as JToken).ToObject<API.Channel>(_serializer);
var data = (payload as JToken).ToObject<API.Channel>(_serializer);
if (data.GuildId.IsSpecified )
if (!data.IsPrivate )
{
{
var guild = DataStore.GetGuild(data.GuildId.Value);
var guild = DataStore.GetGuild(data.GuildId.Value);
if (guild != null)
if (guild != null)
@@ -975,9 +1005,9 @@ namespace Discord
}
}
else
else
{
{
var user = DataStore.GetUser (data.User.Id);
if (user = = null)
user .Update(data, UpdateSource.WebSocket);
var channel = DataStore.GetDMChannel (data.User.Id);
if (channel ! = null)
channel.Recipient .Update(data, UpdateSource.WebSocket);
}
}
}
}
break;
break;
@@ -1095,22 +1125,37 @@ namespace Discord
{
{
try
try
{
{
var state = ConnectionState;
while (state == ConnectionState.Connecting || state == ConnectionState.Connected)
while (!cancelToken.IsCancellationRequested)
{
{
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
if (_heartbeatTime != 0) //Server never responded to our last heartbeat
if (_heartbeatTime != 0) //Server never responded to our last heartbeat
{
{
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
return;
if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? false))
{
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
return;
}
}
}
_heartbeatTime = Environment.TickCount;
else
_heartbeatTime = Environment.TickCount;
await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false);
await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false);
}
}
}
}
catch (OperationCanceledException) { }
catch (OperationCanceledException) { }
}
}
private async Task WaitForGuildsAsync(CancellationToken cancelToken)
{
while ((_unavailableGuilds > 0) || (Environment.TickCount - _lastGuildAvailableTime > 2000))
await Task.Delay(500, cancelToken).ConfigureAwait(false);
}
public async Task WaitForGuildsAsync()
{
var downloadTask = _guildDownloadTask;
if (downloadTask != null)
await _guildDownloadTask.ConfigureAwait(false);
}
}
}
}
}