diff --git a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs index aa90b2eee..b424cbe32 100644 --- a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs +++ b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs @@ -109,12 +109,19 @@ namespace Discord.Rest public static IAsyncEnumerable> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, ulong? fromMessageId, Direction dir, int limit, RequestOptions options) { - if (dir == Direction.Around) - throw new NotImplementedException(); //TODO: Impl - var guildId = (channel as IGuildChannel)?.GuildId; var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; + if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) + { + int around = limit / 2; + if (fromMessageId.HasValue) + return GetMessagesAsync(channel, client, fromMessageId.Value + 1, Direction.Before, around + 1, options) //Need to include the message itself + .Concat(GetMessagesAsync(channel, client, fromMessageId, Direction.After, around, options)); + else //Shouldn't happen since there's no public overload for ulong? and Direction + return GetMessagesAsync(channel, client, null, Direction.Before, around + 1, options); + } + return new PagedAsyncEnumerable( DiscordConfig.MaxMessagesPerBatch, async (info, ct) => diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs index e6339b6d9..5cfbcc1a8 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs @@ -11,23 +11,11 @@ namespace Discord.WebSocket public static IAsyncEnumerable> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) { - if (dir == Direction.Around) - throw new NotImplementedException(); //TODO: Impl - - IReadOnlyCollection cachedMessages = null; - IAsyncEnumerable> result = null; - if (dir == Direction.After && fromMessageId == null) return AsyncEnumerable.Empty>(); - if (dir == Direction.Before || mode == CacheMode.CacheOnly) - { - if (messages != null) //Cache enabled - cachedMessages = messages.GetMany(fromMessageId, dir, limit); - else - cachedMessages = ImmutableArray.Create(); - result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable>(); - } + var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); + var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable>(); if (dir == Direction.Before) { @@ -38,18 +26,35 @@ namespace Discord.WebSocket //Download remaining messages ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); - return result.Concat(downloadedMessages); + if (cachedMessages.Count != 0) + return result.Concat(downloadedMessages); + else + return downloadedMessages; } - else + else if (dir == Direction.After) + { + limit -= cachedMessages.Count; + if (mode == CacheMode.CacheOnly || limit <= 0) + return result; + + //Download remaining messages + ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; + var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); + if (cachedMessages.Count != 0) + return result.Concat(downloadedMessages); + else + return downloadedMessages; + } + else //Direction.Around { - if (mode == CacheMode.CacheOnly) + if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) return result; - //Dont use cache in this case + //Cache isn't useful here since Discord will send them anyways return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); } } - public static IReadOnlyCollection GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, + public static IReadOnlyCollection GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, ulong? fromMessageId, Direction dir, int limit) { if (messages != null) //Cache enabled diff --git a/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs b/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs index 24e46df46..6baf56879 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs @@ -56,11 +56,23 @@ namespace Discord.WebSocket cachedMessageIds = _orderedMessages; else if (dir == Direction.Before) cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); - else + else if (dir == Direction.After) cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); + else //Direction.Around + { + if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) + return ImmutableArray.Empty; + int around = limit / 2; + var before = GetMany(fromMessageId, Direction.Before, around); + var after = GetMany(fromMessageId, Direction.After, around).Reverse(); + + return after.Concat(new SocketMessage[] { msg }).Concat(before).ToImmutableArray(); + } if (dir == Direction.Before) cachedMessageIds = cachedMessageIds.Reverse(); + if (dir == Direction.Around) //Only happens if fromMessageId is null, should only get "around" and itself (+1) + limit = limit / 2 + 1; return cachedMessageIds .Select(x =>