diff --git a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs index aa90b2eee..1aef77ac4 100644 --- a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs +++ b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs @@ -109,12 +109,81 @@ namespace Discord.Rest public static IAsyncEnumerable> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, ulong? fromMessageId, Direction dir, int limit, RequestOptions options) { - if (dir == Direction.Around) - throw new NotImplementedException(); //TODO: Impl - var guildId = (channel as IGuildChannel)?.GuildId; var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; + if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) + { + int around = limit / 2; + return new PagedAsyncEnumerable( + DiscordConfig.MaxMessagesPerBatch, + async (info, ct) => + { + var args = new GetChannelMessagesParams + { + RelativeDirection = Direction.Before, + Limit = info.PageSize + }; + if (info.Position != null) + args.RelativeMessageId = info.Position.Value; + + var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); + var builder = ImmutableArray.CreateBuilder(); + foreach (var model in models) + { + var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); + builder.Add(RestMessage.Create(client, channel, author, model)); + } + return builder.ToImmutable(); + }, + nextPage: (info, lastPage) => + { + if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) + return false; + if (dir == Direction.Before) + info.Position = lastPage.Min(x => x.Id); + else + info.Position = lastPage.Max(x => x.Id); + return true; + }, + start: fromMessageId + 1, //Needs to include the message itself + count: around + 1 + ).Concat(new PagedAsyncEnumerable( + DiscordConfig.MaxMessagesPerBatch, + async (info, ct) => + { + var args = new GetChannelMessagesParams + { + RelativeDirection = Direction.After, + Limit = info.PageSize + }; + if (info.Position != null) + args.RelativeMessageId = info.Position.Value; + + var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); + var builder = ImmutableArray.CreateBuilder(); + foreach (var model in models) + { + var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); + builder.Add(RestMessage.Create(client, channel, author, model)); + } + return builder.ToImmutable(); + }, + nextPage: (info, lastPage) => + { + if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) + return false; + if (dir == Direction.Before) + info.Position = lastPage.Min(x => x.Id); + else + info.Position = lastPage.Max(x => x.Id); + return true; + }, + start: fromMessageId, + count: around + )); + } + return new PagedAsyncEnumerable( DiscordConfig.MaxMessagesPerBatch, async (info, ct) => diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs index e6339b6d9..5cfbcc1a8 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannelHelper.cs @@ -11,23 +11,11 @@ namespace Discord.WebSocket public static IAsyncEnumerable> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) { - if (dir == Direction.Around) - throw new NotImplementedException(); //TODO: Impl - - IReadOnlyCollection cachedMessages = null; - IAsyncEnumerable> result = null; - if (dir == Direction.After && fromMessageId == null) return AsyncEnumerable.Empty>(); - if (dir == Direction.Before || mode == CacheMode.CacheOnly) - { - if (messages != null) //Cache enabled - cachedMessages = messages.GetMany(fromMessageId, dir, limit); - else - cachedMessages = ImmutableArray.Create(); - result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable>(); - } + var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); + var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable>(); if (dir == Direction.Before) { @@ -38,18 +26,35 @@ namespace Discord.WebSocket //Download remaining messages ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); - return result.Concat(downloadedMessages); + if (cachedMessages.Count != 0) + return result.Concat(downloadedMessages); + else + return downloadedMessages; } - else + else if (dir == Direction.After) + { + limit -= cachedMessages.Count; + if (mode == CacheMode.CacheOnly || limit <= 0) + return result; + + //Download remaining messages + ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; + var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); + if (cachedMessages.Count != 0) + return result.Concat(downloadedMessages); + else + return downloadedMessages; + } + else //Direction.Around { - if (mode == CacheMode.CacheOnly) + if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) return result; - //Dont use cache in this case + //Cache isn't useful here since Discord will send them anyways return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); } } - public static IReadOnlyCollection GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, + public static IReadOnlyCollection GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, ulong? fromMessageId, Direction dir, int limit) { if (messages != null) //Cache enabled diff --git a/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs b/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs index 24e46df46..8d0bde94d 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/MessageCache.cs @@ -56,11 +56,41 @@ namespace Discord.WebSocket cachedMessageIds = _orderedMessages; else if (dir == Direction.Before) cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); - else + else if (dir == Direction.After) cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); + else //Direction.Around + { + if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) + return ImmutableArray.Empty; + int around = limit / 2; + var before = _orderedMessages + .Where(x => x < fromMessageId.Value) + .Select(x => + { + if (_messages.TryGetValue(x, out SocketMessage msg)) + return msg; + return null; + }) + .Where(x => x != null) + .Take(around); + var after = _orderedMessages + .Where(x => x > fromMessageId.Value) + .Select(x => + { + if (_messages.TryGetValue(x, out SocketMessage msg)) + return msg; + return null; + }) + .Where(x => x != null) + .Take(around); + + return before.Concat(new SocketMessage[] { msg }).Concat(after).ToImmutableArray(); + } if (dir == Direction.Before) cachedMessageIds = cachedMessageIds.Reverse(); + if (dir == Direction.Around) + limit /= 2; return cachedMessageIds .Select(x =>