@@ -1,5 +1,4 @@
using Discord.API;
using Discord.API.Gateway;
using Discord.API.Gateway;
using Discord.Data;
using Discord.Extensions;
using Discord.Logging;
@@ -11,19 +10,23 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Discord
{
//TODO: Remove unnecessary `as` casts
//TODO: Add docstrings
//TODO: Add event docstrings
//TODO: Add reconnect logic (+ensure the heartbeat task shuts down)
//TODO: Add resume logic
public class DiscordSocketClient : DiscordClient, IDiscordClient
{
public event Func<Task> Connected, Disconnected;
public event Func<Task> Ready;
//public event Func<Channel> VoiceConnected, VoiceDisconnected;
/* public event Func<IChannel, Task> ChannelCreated, ChannelDestroyed;
public event Func<IChannel, Task> ChannelCreated, ChannelDestroyed;
public event Func<IChannel, IChannel, Task> ChannelUpdated;
public event Func<IMessage, Task> MessageReceived, MessageDeleted;
public event Func<IMessage, IMessage, Task> MessageUpdated;
@@ -34,7 +37,8 @@ namespace Discord
public event Func<IUser, Task> UserJoined, UserLeft, UserBanned, UserUnbanned;
public event Func<IUser, IUser, Task> UserUpdated;
public event Func<ISelfUser, ISelfUser, Task> CurrentUserUpdated;
public event Func<IChannel, IUser, Task> UserIsTyping;*/
public event Func<IChannel, IUser, Task> UserIsTyping;
public event Func<int, Task> LatencyUpdated;
private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly Logger _gatewayLogger;
@@ -44,13 +48,21 @@ namespace Discord
private readonly bool _enablePreUpdateEvents;
private readonly int _largeThreshold;
private readonly int _totalShards;
private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private string _sessionId;
private int _lastSeq;
private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _heartbeatCancelToken;
private Task _heartbeatTask;
private long _heartbeatTime;
/// <summary> Gets the shard if of this client. </summary>
public int ShardId { get; }
/// <summary> Gets the current connection state of this client. </summary>
public ConnectionState ConnectionState { get; private set; }
public IWebSocketClient GatewaySocket { get; private set; }
/// <summary> Gets the estimated round-trip latency to the gateway server. </summary>
public int Latency { get; private set; }
internal IWebSocketClient GatewaySocket { get; private set; }
internal int MessageCacheSize { get; private set; }
//internal bool UsePermissionCache { get; private set; }
internal DataStore DataStore { get; private set; }
@@ -61,7 +73,7 @@ namespace Discord
get
{
var guilds = DataStore.Guilds;
return guilds.Select(x => x as CachedGuild). ToReadOnlyCollection(guilds);
return guilds.ToReadOnlyCollection(guilds);
}
}
internal IReadOnlyCollection<CachedDMChannel> DMChannels
@@ -69,13 +81,15 @@ namespace Discord
get
{
var users = DataStore.Users;
return users.Select(x => ( x as CachedPublicUser) .DMChannel).Where(x => x != null).ToReadOnlyCollection(users);
return users.Select(x => x.DMChannel).Where(x => x != null).ToReadOnlyCollection(users);
}
}
internal IReadOnlyCollection<VoiceRegion> VoiceRegions => _voiceRegions.ToReadOnlyCollection();
/// <summary> Creates a new discord client using the REST and WebSocket APIs. </summary>
public DiscordSocketClient()
: this(new DiscordSocketConfig()) { }
/// <summary> Creates a new discord client using the REST and WebSocket APIs. </summary>
public DiscordSocketClient(DiscordSocketConfig config)
: base(config)
{
@@ -117,6 +131,7 @@ namespace Discord
_voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>();
}
/// <inheritdoc />
public async Task Connect()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
@@ -135,6 +150,7 @@ namespace Discord
try
{
_connectTask = new TaskCompletionSource<bool>();
_heartbeatCancelToken = new CancellationTokenSource();
await ApiClient.Connect().ConfigureAwait(false);
await _connectTask.Task.ConfigureAwait(false);
@@ -148,6 +164,7 @@ namespace Discord
await Connected.Raise().ConfigureAwait(false);
}
/// <inheritdoc />
public async Task Disconnect()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
@@ -165,13 +182,15 @@ namespace Discord
ConnectionState = ConnectionState.Disconnecting;
await ApiClient.Disconnect().ConfigureAwait(false);
await _heartbeatTask.ConfigureAwait(false);
while (_largeGuilds.TryDequeue(out guildId)) { }
ConnectionState = ConnectionState.Disconnected;
await Disconnected.Raise().ConfigureAwait(false);
}
/// <inheritdoc />
public override Task<IVoiceRegion> GetVoiceRegion(string id)
{
VoiceRegion region;
@@ -180,6 +199,7 @@ namespace Discord
return Task.FromResult<IVoiceRegion>(null);
}
/// <inheritdoc />
public override Task<IGuild> GetGuild(ulong id)
{
return Task.FromResult<IGuild>(DataStore.GetGuild(id));
@@ -192,7 +212,7 @@ namespace Discord
if (model.Unavailable != true)
{
for (int i = 0; i < model.Channels.Length; i++)
AddCachedChannel(model.Channels[i], dataStore);
AddCachedChannel(guild, model.Channels[i], dataStore);
}
dataStore.AddGuild(guild);
if (model.Large)
@@ -203,7 +223,7 @@ namespace Discord
{
dataStore = dataStore ?? DataStore;
var guild = dataStore.RemoveGuild(id) as CachedGuild ;
var guild = dataStore.RemoveGuild(id);
foreach (var channel in guild.Channels)
guild.RemoveCachedChannel(channel.Id);
foreach (var user in guild.Members)
@@ -211,25 +231,25 @@ namespace Discord
return guild;
}
/// <inheritdoc />
public override Task<IChannel> GetChannel(ulong id)
{
return Task.FromResult<IChannel>(DataStore.GetChannel(id));
}
internal ICachedChannel AddCachedChannel(API.Channel model, DataStore dataStore = null)
internal ICachedGuild Channel AddCachedChannel(CachedGuild guild, API.Channel model, DataStore dataStore = null)
{
dataStore = dataStore ?? DataStore;
ICachedChannel channel;
if (model.IsPrivate)
{
var recipient = AddCachedUser(model.Recipient, dataStore);
channel = recipient.SetDMChannel(model);
}
else
{
var guild = dataStore.GetGuild(model.GuildId.Value);
channel = guild.AddCachedChannel(model);
}
var channel = guild.AddCachedChannel(model);
dataStore.AddChannel(channel);
return channel;
}
internal CachedDMChannel AddCachedDMChannel(API.Channel model, DataStore dataStore = null)
{
dataStore = dataStore ?? DataStore;
var recipient = AddCachedUser(model.Recipient, dataStore);
var channel = recipient.AddDMChannel(model);
dataStore.AddChannel(channel);
return channel;
}
@@ -237,8 +257,8 @@ namespace Discord
{
dataStore = dataStore ?? DataStore;
//TODO: C#7
var channel = DataStore.RemoveChannel(id) as ICachedChannel ;
//TODO: C#7 Typeswitch Candidate
var channel = DataStore.RemoveChannel(id);
var guildChannel = channel as ICachedGuildChannel;
if (guildChannel != null)
@@ -258,10 +278,12 @@ namespace Discord
return null;
}
/// <inheritdoc />
public override Task<IUser> GetUser(ulong id)
{
return Task.FromResult<IUser>(DataStore.GetUser(id));
}
/// <inheritdoc />
public override Task<IUser> GetUser(string username, string discriminator)
{
return Task.FromResult<IUser>(DataStore.Users.Where(x => x.Discriminator == discriminator && x.Username == username).FirstOrDefault());
@@ -270,7 +292,7 @@ namespace Discord
{
dataStore = dataStore ?? DataStore;
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedPublicUser(this, model)) as CachedPublicUser ;
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedPublicUser(this, model));
user.AddRef();
return user;
}
@@ -278,22 +300,34 @@ namespace Discord
{
dataStore = dataStore ?? DataStore;
var user = dataStore.GetUser(id) as CachedPublicUser ;
var user = dataStore.GetUser(id);
user.RemoveRef();
return user;
}
private async Task ProcessMessage(GatewayOpCode opCode, string type, JToken payload)
private async Task ProcessMessage(GatewayOpCode opCode, int? seq, string type, object payload)
{
if (seq != null)
_lastSeq = seq.Value;
try
{
switch (opCode)
{
case GatewayOpCode.Hello:
{
var data = payload.ToObject<HelloEvent>(_serializer);
var data = ( payload as JToken) .ToObject<HelloEvent>(_serializer);
await ApiClient.SendIdentify().ConfigureAwait(false);
_heartbeatTask = RunHeartbeat(data.HeartbeatInterval, _heartbeatCancelToken.Token);
}
break;
case GatewayOpCode.HeartbeatAck:
{
var latency = (int)(Environment.TickCount - _heartbeatTime);
await _gatewayLogger.Debug($"Latency: {latency} ms").ConfigureAwait(false);
Latency = latency;
await LatencyUpdated.Raise(latency).ConfigureAwait(false);
}
break;
case GatewayOpCode.Dispatch:
@@ -303,15 +337,15 @@ namespace Discord
case "READY":
{
//TODO: Make downloading large guilds optional
var data = payload.ToObject<ReadyEvent>(_serializer);
var data = ( payload as JToken) .ToObject<ReadyEvent>(_serializer);
var dataStore = _dataStoreProvider(ShardId, _totalShards, data.Guilds.Length, data.PrivateChannels.Length);
_currentUser = new CachedSelfUser(this,data.User);
_currentUser = new CachedSelfUser(this, data.User);
for (int i = 0; i < data.Guilds.Length; i++)
AddCachedGuild(data.Guilds[i], dataStore);
for (int i = 0; i < data.PrivateChannels.Length; i++)
AddCachedChannel(data.PrivateChannels[i], dataStore);
AddCachedDM Channel(data.PrivateChannels[i], dataStore);
_sessionId = data.SessionId;
DataStore = dataStore;
@@ -323,9 +357,9 @@ namespace Discord
break;
//Guilds
/* case "GUILD_CREATE":
case "GUILD_CREATE":
{
var data = payload.ToObject<ExtendedGuild>(_serializer);
var data = ( payload as JToken) .ToObject<ExtendedGuild>(_serializer);
var guild = new CachedGuild(this, data);
DataStore.AddGuild(guild);
@@ -342,12 +376,12 @@ namespace Discord
break;
case "GUILD_UPDATE":
{
var data = payload.ToObject<API.Guild>(_serializer);
var data = ( payload as JToken) .ToObject<API.Guild>(_serializer);
var guild = DataStore.GetGuild(data.Id);
if (guild != null)
{
var before = _enablePreUpdateEvents ? guild.Clone() : null;
guild.Update(data);
guild.Update(data, UpdateSource.WebSocket );
await GuildUpdated.Raise(before, guild);
}
else
@@ -356,7 +390,7 @@ namespace Discord
break;
case "GUILD_DELETE":
{
var data = payload.ToObject<ExtendedGuild>(_serializer);
var data = ( payload as JToken) .ToObject<ExtendedGuild>(_serializer);
var guild = DataStore.RemoveGuild(data.Id);
if (guild != null)
{
@@ -375,34 +409,34 @@ namespace Discord
//Channels
case "CHANNEL_CREATE":
{
var data = payload.ToObject<API.Channel>(_serializer);
var data = ( payload as JToken) .ToObject<API.Channel>(_serializer);
IChannel channel = null;
ICachedC hannel channel = null;
if (data.GuildId != null)
{
var guild = GetCached Guild(data.GuildId.Value);
var guild = DataStore.Get Guild(data.GuildId.Value);
if (guild != null)
channel = guild.AddCachedChannel(data.Id, true);
{
channel = guild.AddCachedChannel(data);
DataStore.AddChannel(channel);
}
else
await _gatewayLogger.Warning("CHANNEL_CREATE referenced an unknown guild.");
}
else
channel = AddCachedPrivateChannel(data.Id, data.Recipient.Id );
channel = AddCachedDMChannel(data );
if (channel != null)
{
channel.Update(data);
await ChannelCreated.Raise(channel);
}
}
break;
case "CHANNEL_UPDATE":
{
var data = payload.ToObject<API.Channel>(_serializer);
var channel = DataStore.GetChannel(data.Id) as Channel ;
var data = ( payload as JToken) .ToObject<API.Channel>(_serializer);
var channel = DataStore.GetChannel(data.Id);
if (channel != null)
{
var before = _enablePreUpdateEvents ? channel.Clone() : null;
channel.Update(data);
channel.Update(data, UpdateSource.WebSocket );
await ChannelUpdated.Raise(before, channel);
}
else
@@ -411,7 +445,7 @@ namespace Discord
break;
case "CHANNEL_DELETE":
{
var data = payload.ToObject<API.Channel>(_serializer);
var data = ( payload as JToken) .ToObject<API.Channel>(_serializer);
var channel = RemoveCachedChannel(data.Id);
if (channel != null)
await ChannelDestroyed.Raise(channel);
@@ -421,9 +455,9 @@ namespace Discord
break;
//Members
case "GUILD_MEMBER_ADD":
/* case "GUILD_MEMBER_ADD":
{
var data = payload.ToObject<API.GuildMember>(_serializer);
var data = ( payload as JToken) .ToObject<API.GuildMember>(_serializer);
var guild = GetGuild(data.GuildId.Value);
if (guild != null)
{
@@ -438,7 +472,7 @@ namespace Discord
break;
case "GUILD_MEMBER_UPDATE":
{
var data = payload.ToObject<API.GuildMember>(_serializer);
var data = ( payload as JToken) .ToObject<API.GuildMember>(_serializer);
var guild = GetGuild(data.GuildId.Value);
if (guild != null)
{
@@ -458,7 +492,7 @@ namespace Discord
break;
case "GUILD_MEMBER_REMOVE":
{
var data = payload.ToObject<API.GuildMember>(_serializer);
var data = ( payload as JToken) .ToObject<API.GuildMember>(_serializer);
var guild = GetGuild(data.GuildId.Value);
if (guild != null)
{
@@ -479,7 +513,7 @@ namespace Discord
break;
case "GUILD_MEMBERS_CHUNK":
{
var data = payload.ToObject<GuildMembersChunkEvent>(_serializer);
var data = ( payload as JToken) .ToObject<GuildMembersChunkEvent>(_serializer);
var guild = GetCachedGuild(data.GuildId);
if (guild != null)
{
@@ -498,9 +532,9 @@ namespace Discord
break;
//Roles
case "GUILD_ROLE_CREATE":
/* case "GUILD_ROLE_CREATE":
{
var data = payload.ToObject<GuildRoleCreateEvent>(_serializer);
var data = ( payload as JToken) .ToObject<GuildRoleCreateEvent>(_serializer);
var guild = GetCachedGuild(data.GuildId);
if (guild != null)
{
@@ -514,7 +548,7 @@ namespace Discord
break;
case "GUILD_ROLE_UPDATE":
{
var data = payload.ToObject<GuildRoleUpdateEvent>(_serializer);
var data = ( payload as JToken) .ToObject<GuildRoleUpdateEvent>(_serializer);
var guild = GetCachedGuild(data.GuildId);
if (guild != null)
{
@@ -534,8 +568,8 @@ namespace Discord
break;
case "GUILD_ROLE_DELETE":
{
var data = payload.ToObject<GuildRoleDeleteEvent>(_serializer);
var guild = DataStore.GetGuild(data.GuildId) as CachedGuild ;
var data = ( payload as JToken) .ToObject<GuildRoleDeleteEvent>(_serializer);
var guild = DataStore.GetGuild(data.GuildId);
if (guild != null)
{
var role = guild.RemoveRole(data.RoleId);
@@ -552,7 +586,7 @@ namespace Discord
//Bans
case "GUILD_BAN_ADD":
{
var data = payload.ToObject<GuildBanEvent>(_serializer);
var data = ( payload as JToken) .ToObject<GuildBanEvent>(_serializer);
var guild = GetCachedGuild(data.GuildId);
if (guild != null)
await UserBanned.Raise(new User(this, data));
@@ -574,8 +608,7 @@ namespace Discord
//Messages
case "MESSAGE_CREATE":
{
var data = payload.ToObject<API.Message>(_serializer);
var data = (payload as JToken).ToObject<API.Message>(_serializer);
var channel = DataStore.GetChannel(data.ChannelId);
if (channel != null)
{
@@ -599,7 +632,7 @@ namespace Discord
break;
case "MESSAGE_UPDATE":
{
var data = payload.ToObject<API.Message>(_serializer);
var data = ( payload as JToken) .ToObject<API.Message>(_serializer);
var channel = GetCachedChannel(data.ChannelId);
if (channel != null)
{
@@ -614,7 +647,7 @@ namespace Discord
break;
case "MESSAGE_DELETE":
{
var data = payload.ToObject<API.Message>(_serializer);
var data = ( payload as JToken) .ToObject<API.Message>(_serializer);
var channel = GetCachedChannel(data.ChannelId);
if (channel != null)
{
@@ -629,7 +662,7 @@ namespace Discord
//Statuses
case "PRESENCE_UPDATE":
{
var data = payload.ToObject<API.Presence>(_serializer);
var data = ( payload as JToken) .ToObject<API.Presence>(_serializer);
User user;
Guild guild;
if (data.GuildId == null)
@@ -664,7 +697,7 @@ namespace Discord
break;
case "TYPING_START":
{
var data = payload.ToObject<TypingStartEvent>(_serializer);
var data = ( payload as JToken) .ToObject<TypingStartEvent>(_serializer);
var channel = GetCachedChannel(data.ChannelId);
if (channel != null)
{
@@ -683,7 +716,7 @@ namespace Discord
//Voice
case "VOICE_STATE_UPDATE":
{
var data = payload.ToObject<API.VoiceState>(_serializer);
var data = ( payload as JToken) .ToObject<API.VoiceState>(_serializer);
var guild = GetGuild(data.GuildId);
if (guild != null)
{
@@ -708,7 +741,7 @@ namespace Discord
//Settings
case "USER_UPDATE":
{
var data = payload.ToObject<SelfUser>(_serializer);
var data = ( payload as JToken) .ToObject<SelfUser>(_serializer);
if (data.Id == CurrentUser.Id)
{
var before = _enablePreUpdateEvents ? CurrentUser.Clone() : null;
@@ -746,5 +779,17 @@ namespace Discord
}
await _gatewayLogger.Debug($"Received {opCode}{(type != null ? $" ({type})" : "")}").ConfigureAwait(false);
}
private async Task RunHeartbeat(int intervalMillis, CancellationToken cancelToken)
{
var state = ConnectionState;
while (state == ConnectionState.Connecting || state == ConnectionState.Connected)
{
//if (_heartbeatTime != 0) //TODO: Connection lost, reconnect
_heartbeatTime = Environment.TickCount;
await ApiClient.SendHeartbeat(_lastSeq).ConfigureAwait(false);
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
}
}
}
}