Browse Source

Several performance/memory improvements. Renamed CachedPublicUser -> CachedGlobalUser.

tags/1.0-rc
RogueException 9 years ago
parent
commit
b38455f427
15 changed files with 348 additions and 249 deletions
  1. +8
    -8
      src/Discord.Net/Data/DefaultDataStore.cs
  2. +4
    -4
      src/Discord.Net/Data/IDataStore.cs
  3. +111
    -66
      src/Discord.Net/DiscordSocketClient.cs
  4. +6
    -5
      src/Discord.Net/Entities/Channels/DMChannel.cs
  5. +32
    -18
      src/Discord.Net/Entities/Channels/GuildChannel.cs
  6. +3
    -2
      src/Discord.Net/Entities/Users/GuildUser.cs
  7. +1
    -3
      src/Discord.Net/Entities/Users/User.cs
  8. +7
    -4
      src/Discord.Net/Entities/WebSocket/CachedDMChannel.cs
  9. +38
    -0
      src/Discord.Net/Entities/WebSocket/CachedDMUser.cs
  10. +39
    -0
      src/Discord.Net/Entities/WebSocket/CachedGlobalUser.cs
  11. +3
    -3
      src/Discord.Net/Entities/WebSocket/CachedGuildUser.cs
  12. +0
    -75
      src/Discord.Net/Entities/WebSocket/CachedPublicUser.cs
  13. +5
    -2
      src/Discord.Net/Entities/WebSocket/CachedTextChannel.cs
  14. +10
    -59
      src/Discord.Net/Entities/WebSocket/MessageCache.cs
  15. +81
    -0
      src/Discord.Net/Entities/WebSocket/MessageManager.cs

+ 8
- 8
src/Discord.Net/Data/DefaultDataStore.cs View File

@@ -15,12 +15,12 @@ namespace Discord.Data
private readonly ConcurrentDictionary<ulong, ICachedChannel> _channels; private readonly ConcurrentDictionary<ulong, ICachedChannel> _channels;
private readonly ConcurrentDictionary<ulong, CachedDMChannel> _dmChannels; private readonly ConcurrentDictionary<ulong, CachedDMChannel> _dmChannels;
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds; private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds;
private readonly ConcurrentDictionary<ulong, CachedPublicUser> _users;
private readonly ConcurrentDictionary<ulong, CachedGlobalUser> _users;


internal override IReadOnlyCollection<ICachedChannel> Channels => _channels.ToReadOnlyCollection(); internal override IReadOnlyCollection<ICachedChannel> Channels => _channels.ToReadOnlyCollection();
internal override IReadOnlyCollection<CachedDMChannel> DMChannels => _dmChannels.ToReadOnlyCollection(); internal override IReadOnlyCollection<CachedDMChannel> DMChannels => _dmChannels.ToReadOnlyCollection();
internal override IReadOnlyCollection<CachedGuild> Guilds => _guilds.ToReadOnlyCollection(); internal override IReadOnlyCollection<CachedGuild> Guilds => _guilds.ToReadOnlyCollection();
internal override IReadOnlyCollection<CachedPublicUser> Users => _users.ToReadOnlyCollection();
internal override IReadOnlyCollection<CachedGlobalUser> Users => _users.ToReadOnlyCollection();


public DefaultDataStore(int guildCount, int dmChannelCount) public DefaultDataStore(int guildCount, int dmChannelCount)
{ {
@@ -29,7 +29,7 @@ namespace Discord.Data
_channels = new ConcurrentDictionary<ulong, ICachedChannel>(CollectionConcurrencyLevel, (int)(estimatedChannelCount * CollectionMultiplier)); _channels = new ConcurrentDictionary<ulong, ICachedChannel>(CollectionConcurrencyLevel, (int)(estimatedChannelCount * CollectionMultiplier));
_dmChannels = new ConcurrentDictionary<ulong, CachedDMChannel>(CollectionConcurrencyLevel, (int)(dmChannelCount * CollectionMultiplier)); _dmChannels = new ConcurrentDictionary<ulong, CachedDMChannel>(CollectionConcurrencyLevel, (int)(dmChannelCount * CollectionMultiplier));
_guilds = new ConcurrentDictionary<ulong, CachedGuild>(CollectionConcurrencyLevel, (int)(guildCount * CollectionMultiplier)); _guilds = new ConcurrentDictionary<ulong, CachedGuild>(CollectionConcurrencyLevel, (int)(guildCount * CollectionMultiplier));
_users = new ConcurrentDictionary<ulong, CachedPublicUser>(CollectionConcurrencyLevel, (int)(estimatedUsersCount * CollectionMultiplier));
_users = new ConcurrentDictionary<ulong, CachedGlobalUser>(CollectionConcurrencyLevel, (int)(estimatedUsersCount * CollectionMultiplier));
} }


internal override ICachedChannel GetChannel(ulong id) internal override ICachedChannel GetChannel(ulong id)
@@ -94,20 +94,20 @@ namespace Discord.Data
return null; return null;
} }


internal override CachedPublicUser GetUser(ulong id)
internal override CachedGlobalUser GetUser(ulong id)
{ {
CachedPublicUser user;
CachedGlobalUser user;
if (_users.TryGetValue(id, out user)) if (_users.TryGetValue(id, out user))
return user; return user;
return null; return null;
} }
internal override CachedPublicUser GetOrAddUser(ulong id, Func<ulong, CachedPublicUser> userFactory)
internal override CachedGlobalUser GetOrAddUser(ulong id, Func<ulong, CachedGlobalUser> userFactory)
{ {
return _users.GetOrAdd(id, userFactory); return _users.GetOrAdd(id, userFactory);
} }
internal override CachedPublicUser RemoveUser(ulong id)
internal override CachedGlobalUser RemoveUser(ulong id)
{ {
CachedPublicUser user;
CachedGlobalUser user;
if (_users.TryRemove(id, out user)) if (_users.TryRemove(id, out user))
return user; return user;
return null; return null;


+ 4
- 4
src/Discord.Net/Data/IDataStore.cs View File

@@ -8,7 +8,7 @@ namespace Discord.Data
internal abstract IReadOnlyCollection<ICachedChannel> Channels { get; } internal abstract IReadOnlyCollection<ICachedChannel> Channels { get; }
internal abstract IReadOnlyCollection<CachedDMChannel> DMChannels { get; } internal abstract IReadOnlyCollection<CachedDMChannel> DMChannels { get; }
internal abstract IReadOnlyCollection<CachedGuild> Guilds { get; } internal abstract IReadOnlyCollection<CachedGuild> Guilds { get; }
internal abstract IReadOnlyCollection<CachedPublicUser> Users { get; }
internal abstract IReadOnlyCollection<CachedGlobalUser> Users { get; }


internal abstract ICachedChannel GetChannel(ulong id); internal abstract ICachedChannel GetChannel(ulong id);
internal abstract void AddChannel(ICachedChannel channel); internal abstract void AddChannel(ICachedChannel channel);
@@ -22,8 +22,8 @@ namespace Discord.Data
internal abstract void AddGuild(CachedGuild guild); internal abstract void AddGuild(CachedGuild guild);
internal abstract CachedGuild RemoveGuild(ulong id); internal abstract CachedGuild RemoveGuild(ulong id);


internal abstract CachedPublicUser GetUser(ulong id);
internal abstract CachedPublicUser GetOrAddUser(ulong userId, Func<ulong, CachedPublicUser> userFactory);
internal abstract CachedPublicUser RemoveUser(ulong id);
internal abstract CachedGlobalUser GetUser(ulong id);
internal abstract CachedGlobalUser GetOrAddUser(ulong userId, Func<ulong, CachedGlobalUser> userFactory);
internal abstract CachedGlobalUser RemoveUser(ulong id);
} }
} }

+ 111
- 66
src/Discord.Net/DiscordSocketClient.cs View File

@@ -13,7 +13,6 @@ using System.Collections.Immutable;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Diagnostics;


namespace Discord namespace Discord
{ {
@@ -51,15 +50,17 @@ namespace Discord
private readonly bool _enablePreUpdateEvents; private readonly bool _enablePreUpdateEvents;
private readonly int _largeThreshold; private readonly int _largeThreshold;
private readonly int _totalShards; private readonly int _totalShards;
private ConcurrentHashSet<ulong> _dmChannels;
private string _sessionId; private string _sessionId;
private int _lastSeq; private int _lastSeq;
private ImmutableDictionary<string, VoiceRegion> _voiceRegions; private ImmutableDictionary<string, VoiceRegion> _voiceRegions;
private TaskCompletionSource<bool> _connectTask; private TaskCompletionSource<bool> _connectTask;
private CancellationTokenSource _heartbeatCancelToken;
private Task _heartbeatTask, _reconnectTask;
private CancellationTokenSource _cancelToken;
private Task _heartbeatTask, _guildDownloadTask, _reconnectTask;
private long _heartbeatTime; private long _heartbeatTime;
private bool _isReconnecting; private bool _isReconnecting;
private int _unavailableGuilds;
private long _lastGuildAvailableTime;


/// <summary> Gets the shard if of this client. </summary> /// <summary> Gets the shard if of this client. </summary>
public int ShardId { get; } public int ShardId { get; }
@@ -74,15 +75,7 @@ namespace Discord


internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser; internal CachedSelfUser CurrentUser => _currentUser as CachedSelfUser;
internal IReadOnlyCollection<CachedGuild> Guilds => DataStore.Guilds; internal IReadOnlyCollection<CachedGuild> Guilds => DataStore.Guilds;
internal IReadOnlyCollection<CachedDMChannel> DMChannels
{
get
{
var dmChannels = _dmChannels;
var store = DataStore;
return dmChannels.Select(x => store.GetChannel(x) as CachedDMChannel).Where(x => x != null).ToReadOnlyCollection(dmChannels);
}
}
internal IReadOnlyCollection<CachedDMChannel> DMChannels => DataStore.DMChannels;
internal IReadOnlyCollection<VoiceRegion> VoiceRegions => _voiceRegions.ToReadOnlyCollection(); internal IReadOnlyCollection<VoiceRegion> VoiceRegions => _voiceRegions.ToReadOnlyCollection();


/// <summary> Creates a new REST/WebSocket discord client. </summary> /// <summary> Creates a new REST/WebSocket discord client. </summary>
@@ -132,7 +125,6 @@ namespace Discord


_voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>(); _voiceRegions = ImmutableDictionary.Create<string, VoiceRegion>();
_largeGuilds = new ConcurrentQueue<ulong>(); _largeGuilds = new ConcurrentQueue<ulong>();
_dmChannels = new ConcurrentHashSet<ulong>();
} }


protected override async Task OnLoginAsync() protected override async Task OnLoginAsync()
@@ -169,10 +161,11 @@ namespace Discord
try try
{ {
_connectTask = new TaskCompletionSource<bool>(); _connectTask = new TaskCompletionSource<bool>();
_heartbeatCancelToken = new CancellationTokenSource();
_cancelToken = new CancellationTokenSource();
await ApiClient.ConnectAsync().ConfigureAwait(false); await ApiClient.ConnectAsync().ConfigureAwait(false);


await _connectTask.Task.ConfigureAwait(false); await _connectTask.Task.ConfigureAwait(false);
ConnectionState = ConnectionState.Connected; ConnectionState = ConnectionState.Connected;
await _gatewayLogger.InfoAsync("Connected"); await _gatewayLogger.InfoAsync("Connected");
} }
@@ -203,9 +196,24 @@ namespace Discord
ConnectionState = ConnectionState.Disconnecting; ConnectionState = ConnectionState.Disconnecting;
await _gatewayLogger.InfoAsync("Disconnecting"); await _gatewayLogger.InfoAsync("Disconnecting");


try { _heartbeatCancelToken.Cancel(); } catch { }
//Signal tasks to complete
try { _cancelToken.Cancel(); } catch { }

//Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false);
await _heartbeatTask.ConfigureAwait(false);

//Wait for tasks to complete
var heartbeatTask = _heartbeatTask;
if (heartbeatTask != null)
await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null;

var guildDownloadTask = _guildDownloadTask;
if (guildDownloadTask != null)
await guildDownloadTask.ConfigureAwait(false);
_guildDownloadTask = null;

//Clear large guild queue
while (_largeGuilds.TryDequeue(out guildId)) { } while (_largeGuilds.TryDequeue(out guildId)) { }


ConnectionState = ConnectionState.Disconnected; ConnectionState = ConnectionState.Disconnected;
@@ -216,22 +224,21 @@ namespace Discord
private async Task StartReconnectAsync() private async Task StartReconnectAsync()
{ {
//TODO: Is this thread-safe? //TODO: Is this thread-safe?
while (true)
await _log.InfoAsync("Debug", "Trying to reconnect...").ConfigureAwait(false);
if (_reconnectTask != null) return;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{ {
if (_reconnectTask != null) return; if (_reconnectTask != null) return;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
if (_reconnectTask != null) return;
_isReconnecting = true;
_reconnectTask = ReconnectInternalAsync();
}
finally { _connectionLock.Release(); }
_isReconnecting = true;
_reconnectTask = ReconnectInternalAsync();
} }
finally { _connectionLock.Release(); }
} }
private async Task ReconnectInternalAsync() private async Task ReconnectInternalAsync()
{ {
await _log.InfoAsync("Debug", "Reconnecting...").ConfigureAwait(false);
try try
{ {
int nextReconnectDelay = 1000; int nextReconnectDelay = 1000;
@@ -255,13 +262,18 @@ namespace Discord
catch (Exception ex) catch (Exception ex)
{ {
await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false); await _gatewayLogger.WarningAsync("Reconnect failed", ex).ConfigureAwait(false);
} }
}
}
} }
finally finally
{ {
_isReconnecting = false;
_reconnectTask = null;

await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
_isReconnecting = false;
_reconnectTask = null;
}
finally { _connectionLock.Release(); }
} }
} }


@@ -318,20 +330,22 @@ namespace Discord
{ {
return Task.FromResult<IReadOnlyCollection<IDMChannel>>(DMChannels); return Task.FromResult<IReadOnlyCollection<IDMChannel>>(DMChannels);
} }
internal CachedDMChannel AddDMChannel(API.Channel model, DataStore dataStore, ConcurrentHashSet<ulong> dmChannels)
internal CachedDMChannel AddDMChannel(API.Channel model, DataStore dataStore)
{ {
var recipient = GetOrAddUser(model.Recipient.Value, dataStore); var recipient = GetOrAddUser(model.Recipient.Value, dataStore);
var channel = recipient.AddDMChannel(this, model);
dataStore.AddChannel(channel);
dmChannels.TryAdd(model.Id);
var channel = new CachedDMChannel(this, new CachedDMUser(recipient), model);
recipient.AddRef();
dataStore.AddDMChannel(channel);
return channel; return channel;
} }
internal CachedDMChannel RemoveDMChannel(ulong id) internal CachedDMChannel RemoveDMChannel(ulong id)
{ {
var dmChannel = DataStore.RemoveChannel(id) as CachedDMChannel;
var recipient = dmChannel.Recipient;
recipient.RemoveDMChannel(id);
_dmChannels.TryRemove(id);
var dmChannel = DataStore.RemoveDMChannel(id);
if (dmChannel != null)
{
var recipient = dmChannel.Recipient;
recipient.User.RemoveRef(this);
}
return dmChannel; return dmChannel;
} }


@@ -345,13 +359,13 @@ namespace Discord
{ {
return Task.FromResult<IUser>(DataStore.Users.Where(x => x.Discriminator == discriminator && x.Username == username).FirstOrDefault()); return Task.FromResult<IUser>(DataStore.Users.Where(x => x.Discriminator == discriminator && x.Username == username).FirstOrDefault());
} }
internal CachedPublicUser GetOrAddUser(API.User model, DataStore dataStore)
internal CachedGlobalUser GetOrAddUser(API.User model, DataStore dataStore)
{ {
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedPublicUser(model));
var user = dataStore.GetOrAddUser(model.Id, _ => new CachedGlobalUser(model));
user.AddRef(); user.AddRef();
return user; return user;
} }
internal CachedPublicUser RemoveUser(ulong id)
internal CachedGlobalUser RemoveUser(ulong id)
{ {
return DataStore.RemoveUser(id); return DataStore.RemoveUser(id);
} }
@@ -425,7 +439,7 @@ namespace Discord
else else
await ApiClient.SendIdentifyAsync().ConfigureAwait(false); await ApiClient.SendIdentifyAsync().ConfigureAwait(false);
_heartbeatTime = 0; _heartbeatTime = 0;
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _heartbeatCancelToken.Token);
_heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelToken.Token);
} }
break; break;
case GatewayOpCode.Heartbeat: case GatewayOpCode.Heartbeat:
@@ -439,12 +453,16 @@ namespace Discord
{ {
await _gatewayLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false); await _gatewayLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false);


var latency = (int)(Environment.TickCount - _heartbeatTime);
_heartbeatTime = 0;
await _gatewayLogger.DebugAsync($"Latency = {latency} ms").ConfigureAwait(false);
Latency = latency;
var heartbeatTime = _heartbeatTime;
if (heartbeatTime != 0)
{
var latency = (int)(Environment.TickCount - _heartbeatTime);
_heartbeatTime = 0;
await _gatewayLogger.VerboseAsync($"Latency = {latency} ms").ConfigureAwait(false);
Latency = latency;


await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false);
await LatencyUpdated.RaiseAsync(latency).ConfigureAwait(false);
}
} }
break; break;
case GatewayOpCode.InvalidSession: case GatewayOpCode.InvalidSession:
@@ -475,21 +493,29 @@ namespace Discord
var data = (payload as JToken).ToObject<ReadyEvent>(_serializer); var data = (payload as JToken).ToObject<ReadyEvent>(_serializer);
var dataStore = _dataStoreProvider(ShardId, _totalShards, data.Guilds.Length, data.PrivateChannels.Length); var dataStore = _dataStoreProvider(ShardId, _totalShards, data.Guilds.Length, data.PrivateChannels.Length);
var dmChannels = new ConcurrentHashSet<ulong>();


var currentUser = new CachedSelfUser(this, data.User); var currentUser = new CachedSelfUser(this, data.User);
int unavailableGuilds = 0;
//dataStore.GetOrAddUser(data.User.Id, _ => currentUser); //dataStore.GetOrAddUser(data.User.Id, _ => currentUser);


for (int i = 0; i < data.Guilds.Length; i++) for (int i = 0; i < data.Guilds.Length; i++)
AddGuild(data.Guilds[i], dataStore);
{
var model = data.Guilds[i];
AddGuild(model, dataStore);
if (model.Unavailable == true)
unavailableGuilds++;
}
for (int i = 0; i < data.PrivateChannels.Length; i++) for (int i = 0; i < data.PrivateChannels.Length; i++)
AddDMChannel(data.PrivateChannels[i], dataStore, dmChannels);
AddDMChannel(data.PrivateChannels[i], dataStore);


_sessionId = data.SessionId; _sessionId = data.SessionId;
_currentUser = currentUser; _currentUser = currentUser;
_dmChannels = dmChannels;
_unavailableGuilds = unavailableGuilds;
_lastGuildAvailableTime = Environment.TickCount;
DataStore = dataStore; DataStore = dataStore;


_guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token);

await Ready.RaiseAsync().ConfigureAwait(false); await Ready.RaiseAsync().ConfigureAwait(false);


_connectTask.TrySetResult(true); //Signal the .Connect() call to complete _connectTask.TrySetResult(true); //Signal the .Connect() call to complete
@@ -503,7 +529,10 @@ namespace Discord
var data = (payload as JToken).ToObject<ExtendedGuild>(_serializer); var data = (payload as JToken).ToObject<ExtendedGuild>(_serializer);


if (data.Unavailable == false) if (data.Unavailable == false)
{
type = "GUILD_AVAILABLE"; type = "GUILD_AVAILABLE";
_lastGuildAvailableTime = Environment.TickCount;
}
await _gatewayLogger.DebugAsync($"Received Dispatch ({type})").ConfigureAwait(false); await _gatewayLogger.DebugAsync($"Received Dispatch ({type})").ConfigureAwait(false);


CachedGuild guild; CachedGuild guild;
@@ -511,6 +540,7 @@ namespace Discord
{ {
guild = AddGuild(data, DataStore); guild = AddGuild(data, DataStore);
await JoinedGuild.RaiseAsync(guild).ConfigureAwait(false); await JoinedGuild.RaiseAsync(guild).ConfigureAwait(false);
await _gatewayLogger.InfoAsync($"Joined {data.Name}").ConfigureAwait(false);
} }
else else
{ {
@@ -526,7 +556,7 @@ namespace Discord


if (data.Unavailable != true) if (data.Unavailable != true)
{ {
await _gatewayLogger.InfoAsync($"Connected to {data.Name}").ConfigureAwait(false);
await _gatewayLogger.VerboseAsync($"Connected to {data.Name}").ConfigureAwait(false);
await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false); await GuildAvailable.RaiseAsync(guild).ConfigureAwait(false);
} }
} }
@@ -564,7 +594,7 @@ namespace Discord
member.User.RemoveRef(this); member.User.RemoveRef(this);


await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false); await GuildUnavailable.RaiseAsync(guild).ConfigureAwait(false);
await _gatewayLogger.InfoAsync($"Disconnected from {data.Name}").ConfigureAwait(false);
await _gatewayLogger.VerboseAsync($"Disconnected from {data.Name}").ConfigureAwait(false);
if (data.Unavailable != true) if (data.Unavailable != true)
{ {
await LeftGuild.RaiseAsync(guild).ConfigureAwait(false); await LeftGuild.RaiseAsync(guild).ConfigureAwait(false);
@@ -587,7 +617,7 @@ namespace Discord


var data = (payload as JToken).ToObject<API.Channel>(_serializer); var data = (payload as JToken).ToObject<API.Channel>(_serializer);
ICachedChannel channel = null; ICachedChannel channel = null;
if (data.GuildId.IsSpecified)
if (!data.IsPrivate)
{ {
var guild = DataStore.GetGuild(data.GuildId.Value); var guild = DataStore.GetGuild(data.GuildId.Value);
if (guild != null) if (guild != null)
@@ -599,7 +629,7 @@ namespace Discord
} }
} }
else else
channel = AddDMChannel(data, DataStore, _dmChannels);
channel = AddDMChannel(data, DataStore);
if (channel != null) if (channel != null)
await ChannelCreated.RaiseAsync(channel).ConfigureAwait(false); await ChannelCreated.RaiseAsync(channel).ConfigureAwait(false);
} }
@@ -629,7 +659,7 @@ namespace Discord


ICachedChannel channel = null; ICachedChannel channel = null;
var data = (payload as JToken).ToObject<API.Channel>(_serializer); var data = (payload as JToken).ToObject<API.Channel>(_serializer);
if (data.GuildId.IsSpecified)
if (!data.IsPrivate)
{ {
var guild = DataStore.GetGuild(data.GuildId.Value); var guild = DataStore.GetGuild(data.GuildId.Value);
if (guild != null) if (guild != null)
@@ -975,9 +1005,9 @@ namespace Discord
} }
else else
{ {
var user = DataStore.GetUser(data.User.Id);
if (user == null)
user.Update(data, UpdateSource.WebSocket);
var channel = DataStore.GetDMChannel(data.User.Id);
if (channel != null)
channel.Recipient.Update(data, UpdateSource.WebSocket);
} }
} }
break; break;
@@ -1095,22 +1125,37 @@ namespace Discord
{ {
try try
{ {
var state = ConnectionState;
while (state == ConnectionState.Connecting || state == ConnectionState.Connected)
while (!cancelToken.IsCancellationRequested)
{ {
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);


if (_heartbeatTime != 0) //Server never responded to our last heartbeat if (_heartbeatTime != 0) //Server never responded to our last heartbeat
{ {
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
return;
if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? false))
{
await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
await StartReconnectAsync().ConfigureAwait(false);
return;
}
} }
_heartbeatTime = Environment.TickCount;
else
_heartbeatTime = Environment.TickCount;
await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false); await ApiClient.SendHeartbeatAsync(_lastSeq).ConfigureAwait(false);
} }
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }
} }

private async Task WaitForGuildsAsync(CancellationToken cancelToken)
{
while ((_unavailableGuilds > 0) || (Environment.TickCount - _lastGuildAvailableTime > 2000))
await Task.Delay(500, cancelToken).ConfigureAwait(false);
}
public async Task WaitForGuildsAsync()
{
var downloadTask = _guildDownloadTask;
if (downloadTask != null)
await _guildDownloadTask.ConfigureAwait(false);
}
} }
} }

+ 6
- 5
src/Discord.Net/Entities/Channels/DMChannel.cs View File

@@ -14,11 +14,11 @@ namespace Discord
internal class DMChannel : SnowflakeEntity, IDMChannel internal class DMChannel : SnowflakeEntity, IDMChannel
{ {
public override DiscordClient Discord { get; } public override DiscordClient Discord { get; }
public User Recipient { get; private set; }
public IUser Recipient { get; private set; }


public virtual IReadOnlyCollection<IMessage> CachedMessages => ImmutableArray.Create<IMessage>(); public virtual IReadOnlyCollection<IMessage> CachedMessages => ImmutableArray.Create<IMessage>();


public DMChannel(DiscordClient discord, User recipient, Model model)
public DMChannel(DiscordClient discord, IUser recipient, Model model)
: base(model.Id) : base(model.Id)
{ {
Discord = discord; Discord = discord;
@@ -30,7 +30,9 @@ namespace Discord
{ {
if (source == UpdateSource.Rest && IsAttached) return; if (source == UpdateSource.Rest && IsAttached) return;
Recipient.Update(model.Recipient.Value, UpdateSource.Rest);
//TODO: Is this cast okay?
if (Recipient is User)
(Recipient as User).Update(model.Recipient.Value, source);
} }


public async Task UpdateAsync() public async Task UpdateAsync()
@@ -119,8 +121,7 @@ namespace Discord
public override string ToString() => '@' + Recipient.ToString(); public override string ToString() => '@' + Recipient.ToString();
private string DebuggerDisplay => $"@{Recipient} ({Id}, DM)"; private string DebuggerDisplay => $"@{Recipient} ({Id}, DM)";

IUser IDMChannel.Recipient => Recipient;
IMessage IMessageChannel.GetCachedMessage(ulong id) => null; IMessage IMessageChannel.GetCachedMessage(ulong id) => null;
} }
} }

+ 32
- 18
src/Discord.Net/Entities/Channels/GuildChannel.cs View File

@@ -1,7 +1,5 @@
using Discord.API.Rest; using Discord.API.Rest;
using Discord.Extensions;
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
@@ -14,7 +12,7 @@ namespace Discord
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
internal abstract class GuildChannel : SnowflakeEntity, IGuildChannel internal abstract class GuildChannel : SnowflakeEntity, IGuildChannel
{ {
private ConcurrentDictionary<ulong, Overwrite> _overwrites;
private List<Overwrite> _overwrites; //TODO: Is maintaining a list here too expensive? Is this threadsafe?


public string Name { get; private set; } public string Name { get; private set; }
public int Position { get; private set; } public int Position { get; private set; }
@@ -38,9 +36,9 @@ namespace Discord
Position = model.Position.Value; Position = model.Position.Value;


var overwrites = model.PermissionOverwrites.Value; var overwrites = model.PermissionOverwrites.Value;
var newOverwrites = new ConcurrentDictionary<ulong, Overwrite>();
var newOverwrites = new List<Overwrite>(overwrites.Length);
for (int i = 0; i < overwrites.Length; i++) for (int i = 0; i < overwrites.Length; i++)
newOverwrites[overwrites[i].TargetId] = new Overwrite(overwrites[i]);
newOverwrites.Add(new Overwrite(overwrites[i]));
_overwrites = newOverwrites; _overwrites = newOverwrites;
} }


@@ -89,16 +87,20 @@ namespace Discord


public OverwritePermissions? GetPermissionOverwrite(IUser user) public OverwritePermissions? GetPermissionOverwrite(IUser user)
{ {
Overwrite value;
if (_overwrites.TryGetValue(user.Id, out value))
return value.Permissions;
for (int i = 0; i < _overwrites.Count; i++)
{
if (_overwrites[i].TargetId == user.Id)
return _overwrites[i].Permissions;
}
return null; return null;
} }
public OverwritePermissions? GetPermissionOverwrite(IRole role) public OverwritePermissions? GetPermissionOverwrite(IRole role)
{ {
Overwrite value;
if (_overwrites.TryGetValue(role.Id, out value))
return value.Permissions;
for (int i = 0; i < _overwrites.Count; i++)
{
if (_overwrites[i].TargetId == role.Id)
return _overwrites[i].Permissions;
}
return null; return null;
} }
@@ -106,34 +108,46 @@ namespace Discord
{ {
var args = new ModifyChannelPermissionsParams { Allow = perms.AllowValue, Deny = perms.DenyValue }; var args = new ModifyChannelPermissionsParams { Allow = perms.AllowValue, Deny = perms.DenyValue };
await Discord.ApiClient.ModifyChannelPermissionsAsync(Id, user.Id, args).ConfigureAwait(false); await Discord.ApiClient.ModifyChannelPermissionsAsync(Id, user.Id, args).ConfigureAwait(false);
_overwrites[user.Id] = new Overwrite(new API.Overwrite { Allow = perms.AllowValue, Deny = perms.DenyValue, TargetId = user.Id, TargetType = PermissionTarget.User });
_overwrites.Add(new Overwrite(new API.Overwrite { Allow = perms.AllowValue, Deny = perms.DenyValue, TargetId = user.Id, TargetType = PermissionTarget.User }));
} }
public async Task AddPermissionOverwriteAsync(IRole role, OverwritePermissions perms) public async Task AddPermissionOverwriteAsync(IRole role, OverwritePermissions perms)
{ {
var args = new ModifyChannelPermissionsParams { Allow = perms.AllowValue, Deny = perms.DenyValue }; var args = new ModifyChannelPermissionsParams { Allow = perms.AllowValue, Deny = perms.DenyValue };
await Discord.ApiClient.ModifyChannelPermissionsAsync(Id, role.Id, args).ConfigureAwait(false); await Discord.ApiClient.ModifyChannelPermissionsAsync(Id, role.Id, args).ConfigureAwait(false);
_overwrites[role.Id] = new Overwrite(new API.Overwrite { Allow = perms.AllowValue, Deny = perms.DenyValue, TargetId = role.Id, TargetType = PermissionTarget.Role });
_overwrites.Add(new Overwrite(new API.Overwrite { Allow = perms.AllowValue, Deny = perms.DenyValue, TargetId = role.Id, TargetType = PermissionTarget.Role }));
} }
public async Task RemovePermissionOverwriteAsync(IUser user) public async Task RemovePermissionOverwriteAsync(IUser user)
{ {
await Discord.ApiClient.DeleteChannelPermissionAsync(Id, user.Id).ConfigureAwait(false); await Discord.ApiClient.DeleteChannelPermissionAsync(Id, user.Id).ConfigureAwait(false);


Overwrite value;
_overwrites.TryRemove(user.Id, out value);
for (int i = 0; i < _overwrites.Count; i++)
{
if (_overwrites[i].TargetId == user.Id)
{
_overwrites.RemoveAt(i);
return;
}
}
} }
public async Task RemovePermissionOverwriteAsync(IRole role) public async Task RemovePermissionOverwriteAsync(IRole role)
{ {
await Discord.ApiClient.DeleteChannelPermissionAsync(Id, role.Id).ConfigureAwait(false); await Discord.ApiClient.DeleteChannelPermissionAsync(Id, role.Id).ConfigureAwait(false);


Overwrite value;
_overwrites.TryRemove(role.Id, out value);
for (int i = 0; i < _overwrites.Count; i++)
{
if (_overwrites[i].TargetId == role.Id)
{
_overwrites.RemoveAt(i);
return;
}
}
} }
public override string ToString() => Name; public override string ToString() => Name;
private string DebuggerDisplay => $"{Name} ({Id})"; private string DebuggerDisplay => $"{Name} ({Id})";
IGuild IGuildChannel.Guild => Guild; IGuild IGuildChannel.Guild => Guild;
IReadOnlyCollection<Overwrite> IGuildChannel.PermissionOverwrites => _overwrites.ToReadOnlyCollection();
IReadOnlyCollection<Overwrite> IGuildChannel.PermissionOverwrites => _overwrites.AsReadOnly();


async Task<IUser> IChannel.GetUserAsync(ulong id) => await GetUserAsync(id).ConfigureAwait(false); async Task<IUser> IChannel.GetUserAsync(ulong id) => await GetUserAsync(id).ConfigureAwait(false);
async Task<IReadOnlyCollection<IUser>> IChannel.GetUsersAsync() => await GetUsersAsync().ConfigureAwait(false); async Task<IReadOnlyCollection<IUser>> IChannel.GetUsersAsync() => await GetUsersAsync().ConfigureAwait(false);


+ 3
- 2
src/Discord.Net/Entities/Users/GuildUser.cs View File

@@ -33,8 +33,9 @@ namespace Discord
public bool IsBot => User.IsBot; public bool IsBot => User.IsBot;
public string Mention => User.Mention; public string Mention => User.Mention;
public string Username => User.Username; public string Username => User.Username;
public virtual UserStatus Status => User.Status;
public virtual Game Game => User.Game;

public virtual UserStatus Status => UserStatus.Unknown;
public virtual Game Game => null;


public DiscordClient Discord => Guild.Discord; public DiscordClient Discord => Guild.Discord;
public DateTimeOffset? JoinedAt => DateTimeUtils.FromTicks(_joinedAtTicks); public DateTimeOffset? JoinedAt => DateTimeUtils.FromTicks(_joinedAtTicks);


+ 1
- 3
src/Discord.Net/Entities/Users/User.cs View File

@@ -1,7 +1,5 @@
using Discord.API.Rest;
using System;
using System;
using System.Diagnostics; using System.Diagnostics;
using System.Threading.Tasks;
using Model = Discord.API.User; using Model = Discord.API.User;


namespace Discord namespace Discord


+ 7
- 4
src/Discord.Net/Entities/WebSocket/CachedDMChannel.cs View File

@@ -9,16 +9,19 @@ namespace Discord
{ {
internal class CachedDMChannel : DMChannel, IDMChannel, ICachedChannel, ICachedMessageChannel internal class CachedDMChannel : DMChannel, IDMChannel, ICachedChannel, ICachedMessageChannel
{ {
private readonly MessageCache _messages;
private readonly MessageManager _messages;


public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient; public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient;
public new CachedPublicUser Recipient => base.Recipient as CachedPublicUser;
public new CachedDMUser Recipient => base.Recipient as CachedDMUser;
public IReadOnlyCollection<ICachedUser> Members => ImmutableArray.Create<ICachedUser>(Discord.CurrentUser, Recipient); public IReadOnlyCollection<ICachedUser> Members => ImmutableArray.Create<ICachedUser>(Discord.CurrentUser, Recipient);


public CachedDMChannel(DiscordSocketClient discord, CachedPublicUser recipient, Model model)
public CachedDMChannel(DiscordSocketClient discord, CachedDMUser recipient, Model model)
: base(discord, recipient, model) : base(discord, recipient, model)
{ {
_messages = new MessageCache(Discord, this);
if (Discord.MessageCacheSize > 0)
_messages = new MessageCache(Discord, this);
else
_messages = new MessageManager(Discord, this);
} }


public override Task<IUser> GetUserAsync(ulong id) => Task.FromResult<IUser>(GetUser(id)); public override Task<IUser> GetUserAsync(ulong id) => Task.FromResult<IUser>(GetUser(id));


+ 38
- 0
src/Discord.Net/Entities/WebSocket/CachedDMUser.cs View File

@@ -0,0 +1,38 @@
using System;
using PresenceModel = Discord.API.Presence;

namespace Discord
{
internal class CachedDMUser : ICachedUser
{
public CachedGlobalUser User { get; }

public Game Game { get; private set; }
public UserStatus Status { get; private set; }

public DiscordSocketClient Discord => User.Discord;
public ulong Id => User.Id;
public string AvatarUrl => User.AvatarUrl;
public DateTimeOffset CreatedAt => User.CreatedAt;
public string Discriminator => User.Discriminator;
public bool IsAttached => User.IsAttached;
public bool IsBot => User.IsBot;
public string Mention => User.Mention;
public string Username => User.Username;

public CachedDMUser(CachedGlobalUser user)
{
User = user;
}

public void Update(PresenceModel model, UpdateSource source)
{
Status = model.Status;
Game = model.Game != null ? new Game(model.Game) : null;
}

public CachedDMUser Clone() => MemberwiseClone() as CachedDMUser;
ICachedUser ICachedUser.Clone() => Clone();
}
}

+ 39
- 0
src/Discord.Net/Entities/WebSocket/CachedGlobalUser.cs View File

@@ -0,0 +1,39 @@
using System;
using Model = Discord.API.User;

namespace Discord
{
internal class CachedGlobalUser : User, ICachedUser
{
private ushort _references;

public new DiscordSocketClient Discord { get { throw new NotSupportedException(); } }
public override UserStatus Status => UserStatus.Unknown;// _status;
public override Game Game => null; //_game;

public CachedGlobalUser(Model model)
: base(model)
{
}

public void AddRef()
{
checked
{
lock (this)
_references++;
}
}
public void RemoveRef(DiscordSocketClient discord)
{
lock (this)
{
if (--_references == 0)
discord.RemoveUser(Id);
}
}

public CachedGlobalUser Clone() => MemberwiseClone() as CachedGlobalUser;
ICachedUser ICachedUser.Clone() => Clone();
}
}

+ 3
- 3
src/Discord.Net/Entities/WebSocket/CachedGuildUser.cs View File

@@ -10,7 +10,7 @@ namespace Discord


public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient; public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient;
public new CachedGuild Guild => base.Guild as CachedGuild; public new CachedGuild Guild => base.Guild as CachedGuild;
public new CachedPublicUser User => base.User as CachedPublicUser;
public new CachedGlobalUser User => base.User as CachedGlobalUser;


public override Game Game => _game; public override Game Game => _game;
public override UserStatus Status => _status; public override UserStatus Status => _status;
@@ -21,11 +21,11 @@ namespace Discord
public bool IsSuppressed => VoiceState?.IsSuppressed ?? false; public bool IsSuppressed => VoiceState?.IsSuppressed ?? false;
public CachedVoiceChannel VoiceChannel => VoiceState?.VoiceChannel; public CachedVoiceChannel VoiceChannel => VoiceState?.VoiceChannel;


public CachedGuildUser(CachedGuild guild, CachedPublicUser user, Model model)
public CachedGuildUser(CachedGuild guild, CachedGlobalUser user, Model model)
: base(guild, user, model) : base(guild, user, model)
{ {
} }
public CachedGuildUser(CachedGuild guild, CachedPublicUser user, PresenceModel model)
public CachedGuildUser(CachedGuild guild, CachedGlobalUser user, PresenceModel model)
: base(guild, user, model) : base(guild, user, model)
{ {
} }


+ 0
- 75
src/Discord.Net/Entities/WebSocket/CachedPublicUser.cs View File

@@ -1,75 +0,0 @@
using ChannelModel = Discord.API.Channel;
using Model = Discord.API.User;
using PresenceModel = Discord.API.Presence;

namespace Discord
{
internal class CachedPublicUser : User, ICachedUser
{
//TODO: Fix removed game/status (add CachedDMUser?)
private int _references;
//private Game? _game;
//private UserStatus _status;

public CachedDMChannel DMChannel { get; private set; }

public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient;
public override UserStatus Status => UserStatus.Unknown;// _status;
public override Game Game => null; //_game;

public CachedPublicUser(Model model)
: base(model)
{
}

public CachedDMChannel AddDMChannel(DiscordSocketClient discord, ChannelModel model)
{
lock (this)
{
var channel = new CachedDMChannel(discord, this, model);
DMChannel = channel;
return channel;
}
}
public CachedDMChannel RemoveDMChannel(ulong id)
{
lock (this)
{
var channel = DMChannel;
if (channel.Id == id)
{
DMChannel = null;
return channel;
}
return null;
}
}

public void Update(PresenceModel model, UpdateSource source)
{
if (source == UpdateSource.Rest) return;

//var game = model.Game != null ? new Game(model.Game) : (Game)null;

//_status = model.Status;
//_game = game;
}

public void AddRef()
{
lock (this)
_references++;
}
public void RemoveRef(DiscordSocketClient discord)
{
lock (this)
{
if (--_references == 0 && DMChannel == null)
discord.RemoveUser(Id);
}
}

public CachedPublicUser Clone() => MemberwiseClone() as CachedPublicUser;
ICachedUser ICachedUser.Clone() => Clone();
}
}

+ 5
- 2
src/Discord.Net/Entities/WebSocket/CachedTextChannel.cs View File

@@ -9,7 +9,7 @@ namespace Discord
{ {
internal class CachedTextChannel : TextChannel, ICachedGuildChannel, ICachedMessageChannel internal class CachedTextChannel : TextChannel, ICachedGuildChannel, ICachedMessageChannel
{ {
private readonly MessageCache _messages;
private readonly MessageManager _messages;


public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient; public new DiscordSocketClient Discord => base.Discord as DiscordSocketClient;
public new CachedGuild Guild => base.Guild as CachedGuild; public new CachedGuild Guild => base.Guild as CachedGuild;
@@ -20,7 +20,10 @@ namespace Discord
public CachedTextChannel(CachedGuild guild, Model model) public CachedTextChannel(CachedGuild guild, Model model)
: base(guild, model) : base(guild, model)
{ {
_messages = new MessageCache(Discord, this);
if (Discord.MessageCacheSize > 0)
_messages = new MessageCache(Discord, this);
else
_messages = new MessageManager(Discord, this);
} }


public override Task<IGuildUser> GetUserAsync(ulong id) => Task.FromResult<IGuildUser>(GetUser(id)); public override Task<IGuildUser> GetUserAsync(ulong id) => Task.FromResult<IGuildUser>(GetUser(id));


+ 10
- 59
src/Discord.Net/Entities/WebSocket/MessageCache.cs View File

@@ -1,5 +1,4 @@
using Discord.API.Rest;
using Discord.Extensions;
using Discord.Extensions;
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
@@ -9,26 +8,23 @@ using System.Threading.Tasks;


namespace Discord namespace Discord
{ {
internal class MessageCache
internal class MessageCache : MessageManager
{ {
private readonly DiscordSocketClient _discord;
private readonly ICachedMessageChannel _channel;
private readonly ConcurrentDictionary<ulong, CachedMessage> _messages; private readonly ConcurrentDictionary<ulong, CachedMessage> _messages;
private readonly ConcurrentQueue<ulong> _orderedMessages; private readonly ConcurrentQueue<ulong> _orderedMessages;
private readonly int _size; private readonly int _size;


public IReadOnlyCollection<CachedMessage> Messages => _messages.ToReadOnlyCollection();
public override IReadOnlyCollection<CachedMessage> Messages => _messages.ToReadOnlyCollection();


public MessageCache(DiscordSocketClient discord, ICachedMessageChannel channel) public MessageCache(DiscordSocketClient discord, ICachedMessageChannel channel)
: base(discord, channel)
{ {
_discord = discord;
_channel = channel;
_size = discord.MessageCacheSize; _size = discord.MessageCacheSize;
_messages = new ConcurrentDictionary<ulong, CachedMessage>(1, (int)(_size * 1.05)); _messages = new ConcurrentDictionary<ulong, CachedMessage>(1, (int)(_size * 1.05));
_orderedMessages = new ConcurrentQueue<ulong>(); _orderedMessages = new ConcurrentQueue<ulong>();
} }


public void Add(CachedMessage message)
public override void Add(CachedMessage message)
{ {
if (_messages.TryAdd(message.Id, message)) if (_messages.TryAdd(message.Id, message))
{ {
@@ -41,21 +37,21 @@ namespace Discord
} }
} }


public CachedMessage Remove(ulong id)
public override CachedMessage Remove(ulong id)
{ {
CachedMessage msg; CachedMessage msg;
_messages.TryRemove(id, out msg); _messages.TryRemove(id, out msg);
return msg; return msg;
} }


public CachedMessage Get(ulong id)
public override CachedMessage Get(ulong id)
{ {
CachedMessage result; CachedMessage result;
if (_messages.TryGetValue(id, out result)) if (_messages.TryGetValue(id, out result))
return result; return result;
return null; return null;
} }
public IImmutableList<CachedMessage> GetMany(ulong? fromMessageId, Direction dir, int limit = DiscordConfig.MaxMessagesPerBatch)
public override IImmutableList<CachedMessage> GetMany(ulong? fromMessageId, Direction dir, int limit = DiscordConfig.MaxMessagesPerBatch)
{ {
if (limit < 0) throw new ArgumentOutOfRangeException(nameof(limit)); if (limit < 0) throw new ArgumentOutOfRangeException(nameof(limit));
if (limit == 0) return ImmutableArray<CachedMessage>.Empty; if (limit == 0) return ImmutableArray<CachedMessage>.Empty;
@@ -81,57 +77,12 @@ namespace Discord
.ToImmutableArray(); .ToImmutableArray();
} }


public async Task<CachedMessage> DownloadAsync(ulong id)
public override async Task<CachedMessage> DownloadAsync(ulong id)
{ {
var msg = Get(id); var msg = Get(id);
if (msg != null) if (msg != null)
return msg; return msg;
var model = await _discord.ApiClient.GetChannelMessageAsync(_channel.Id, id).ConfigureAwait(false);
if (model != null)
return new CachedMessage(_channel, new User(model.Author.Value), model);
return null;
}
public async Task<IReadOnlyCollection<CachedMessage>> DownloadAsync(ulong? fromId, Direction dir, int limit)
{
//TODO: Test heavily, especially the ordering of messages
if (limit < 0) throw new ArgumentOutOfRangeException(nameof(limit));
if (limit == 0) return ImmutableArray<CachedMessage>.Empty;

var cachedMessages = GetMany(fromId, dir, limit);
if (cachedMessages.Count == limit)
return cachedMessages;
else if (cachedMessages.Count > limit)
return cachedMessages.Skip(cachedMessages.Count - limit).ToImmutableArray();
else
{
Optional<ulong> relativeId;
if (cachedMessages.Count == 0)
relativeId = fromId ?? new Optional<ulong>();
else
relativeId = dir == Direction.Before ? cachedMessages[0].Id : cachedMessages[cachedMessages.Count - 1].Id;
var args = new GetChannelMessagesParams
{
Limit = limit - cachedMessages.Count,
RelativeDirection = dir,
RelativeMessageId = relativeId
};
var downloadedMessages = await _discord.ApiClient.GetChannelMessagesAsync(_channel.Id, args).ConfigureAwait(false);

var guild = (_channel as ICachedGuildChannel).Guild;
return cachedMessages.Concat(downloadedMessages.Select(x =>
{
IUser user = _channel.GetUser(x.Author.Value.Id, true);
if (user == null)
{
var newUser = new User(x.Author.Value);
if (guild != null)
user = new GuildUser(guild, newUser);
else
user = newUser;
}
return new CachedMessage(_channel, user, x);
})).ToImmutableArray();
}
return await base.DownloadAsync(id).ConfigureAwait(false);
} }
} }
} }

+ 81
- 0
src/Discord.Net/Entities/WebSocket/MessageManager.cs View File

@@ -0,0 +1,81 @@
using Discord.API.Rest;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;

namespace Discord
{
internal class MessageManager
{
private readonly DiscordSocketClient _discord;
private readonly ICachedMessageChannel _channel;

public virtual IReadOnlyCollection<CachedMessage> Messages
=> ImmutableArray.Create<CachedMessage>();

public MessageManager(DiscordSocketClient discord, ICachedMessageChannel channel)
{
_discord = discord;
_channel = channel;
}

public virtual void Add(CachedMessage message) { }
public virtual CachedMessage Remove(ulong id) => null;
public virtual CachedMessage Get(ulong id) => null;

public virtual IImmutableList<CachedMessage> GetMany(ulong? fromMessageId, Direction dir, int limit = DiscordConfig.MaxMessagesPerBatch)
=> ImmutableArray.Create<CachedMessage>();

public virtual async Task<CachedMessage> DownloadAsync(ulong id)
{
var model = await _discord.ApiClient.GetChannelMessageAsync(_channel.Id, id).ConfigureAwait(false);
if (model != null)
return new CachedMessage(_channel, new User(model.Author.Value), model);
return null;
}
public async Task<IReadOnlyCollection<CachedMessage>> DownloadAsync(ulong? fromId, Direction dir, int limit)
{
//TODO: Test heavily, especially the ordering of messages
if (limit < 0) throw new ArgumentOutOfRangeException(nameof(limit));
if (limit == 0) return ImmutableArray<CachedMessage>.Empty;

var cachedMessages = GetMany(fromId, dir, limit);
if (cachedMessages.Count == limit)
return cachedMessages;
else if (cachedMessages.Count > limit)
return cachedMessages.Skip(cachedMessages.Count - limit).ToImmutableArray();
else
{
Optional<ulong> relativeId;
if (cachedMessages.Count == 0)
relativeId = fromId ?? new Optional<ulong>();
else
relativeId = dir == Direction.Before ? cachedMessages[0].Id : cachedMessages[cachedMessages.Count - 1].Id;
var args = new GetChannelMessagesParams
{
Limit = limit - cachedMessages.Count,
RelativeDirection = dir,
RelativeMessageId = relativeId
};
var downloadedMessages = await _discord.ApiClient.GetChannelMessagesAsync(_channel.Id, args).ConfigureAwait(false);

var guild = (_channel as ICachedGuildChannel).Guild;
return cachedMessages.Concat(downloadedMessages.Select(x =>
{
IUser user = _channel.GetUser(x.Author.Value.Id, true);
if (user == null)
{
var newUser = new User(x.Author.Value);
if (guild != null)
user = new GuildUser(guild, newUser);
else
user = newUser;
}
return new CachedMessage(_channel, user, x);
})).ToImmutableArray();
}
}
}
}

Loading…
Cancel
Save