| @@ -109,12 +109,81 @@ namespace Discord.Rest | |||
| public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | |||
| ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | |||
| { | |||
| if (dir == Direction.Around) | |||
| throw new NotImplementedException(); //TODO: Impl | |||
| var guildId = (channel as IGuildChannel)?.GuildId; | |||
| var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | |||
| if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) | |||
| { | |||
| int around = limit / 2; | |||
| return new PagedAsyncEnumerable<RestMessage>( | |||
| DiscordConfig.MaxMessagesPerBatch, | |||
| async (info, ct) => | |||
| { | |||
| var args = new GetChannelMessagesParams | |||
| { | |||
| RelativeDirection = Direction.Before, | |||
| Limit = info.PageSize | |||
| }; | |||
| if (info.Position != null) | |||
| args.RelativeMessageId = info.Position.Value; | |||
| var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||
| var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||
| foreach (var model in models) | |||
| { | |||
| var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||
| builder.Add(RestMessage.Create(client, channel, author, model)); | |||
| } | |||
| return builder.ToImmutable(); | |||
| }, | |||
| nextPage: (info, lastPage) => | |||
| { | |||
| if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||
| return false; | |||
| if (dir == Direction.Before) | |||
| info.Position = lastPage.Min(x => x.Id); | |||
| else | |||
| info.Position = lastPage.Max(x => x.Id); | |||
| return true; | |||
| }, | |||
| start: fromMessageId + 1, //Needs to include the message itself | |||
| count: around + 1 | |||
| ).Concat(new PagedAsyncEnumerable<RestMessage>( | |||
| DiscordConfig.MaxMessagesPerBatch, | |||
| async (info, ct) => | |||
| { | |||
| var args = new GetChannelMessagesParams | |||
| { | |||
| RelativeDirection = Direction.After, | |||
| Limit = info.PageSize | |||
| }; | |||
| if (info.Position != null) | |||
| args.RelativeMessageId = info.Position.Value; | |||
| var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||
| var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||
| foreach (var model in models) | |||
| { | |||
| var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||
| builder.Add(RestMessage.Create(client, channel, author, model)); | |||
| } | |||
| return builder.ToImmutable(); | |||
| }, | |||
| nextPage: (info, lastPage) => | |||
| { | |||
| if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||
| return false; | |||
| if (dir == Direction.Before) | |||
| info.Position = lastPage.Min(x => x.Id); | |||
| else | |||
| info.Position = lastPage.Max(x => x.Id); | |||
| return true; | |||
| }, | |||
| start: fromMessageId, | |||
| count: around | |||
| )); | |||
| } | |||
| return new PagedAsyncEnumerable<RestMessage>( | |||
| DiscordConfig.MaxMessagesPerBatch, | |||
| async (info, ct) => | |||
| @@ -11,23 +11,11 @@ namespace Discord.WebSocket | |||
| public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
| ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | |||
| { | |||
| if (dir == Direction.Around) | |||
| throw new NotImplementedException(); //TODO: Impl | |||
| IReadOnlyCollection<SocketMessage> cachedMessages = null; | |||
| IAsyncEnumerable<IReadOnlyCollection<IMessage>> result = null; | |||
| if (dir == Direction.After && fromMessageId == null) | |||
| return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | |||
| if (dir == Direction.Before || mode == CacheMode.CacheOnly) | |||
| { | |||
| if (messages != null) //Cache enabled | |||
| cachedMessages = messages.GetMany(fromMessageId, dir, limit); | |||
| else | |||
| cachedMessages = ImmutableArray.Create<SocketMessage>(); | |||
| result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||
| } | |||
| var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); | |||
| var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||
| if (dir == Direction.Before) | |||
| { | |||
| @@ -38,18 +26,35 @@ namespace Discord.WebSocket | |||
| //Download remaining messages | |||
| ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | |||
| var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | |||
| return result.Concat(downloadedMessages); | |||
| if (cachedMessages.Count != 0) | |||
| return result.Concat(downloadedMessages); | |||
| else | |||
| return downloadedMessages; | |||
| } | |||
| else | |||
| else if (dir == Direction.After) | |||
| { | |||
| limit -= cachedMessages.Count; | |||
| if (mode == CacheMode.CacheOnly || limit <= 0) | |||
| return result; | |||
| //Download remaining messages | |||
| ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; | |||
| var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); | |||
| if (cachedMessages.Count != 0) | |||
| return result.Concat(downloadedMessages); | |||
| else | |||
| return downloadedMessages; | |||
| } | |||
| else //Direction.Around | |||
| { | |||
| if (mode == CacheMode.CacheOnly) | |||
| if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) | |||
| return result; | |||
| //Dont use cache in this case | |||
| //Cache isn't useful here since Discord will send them anyways | |||
| return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | |||
| } | |||
| } | |||
| public static IReadOnlyCollection<SocketMessage> GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
| public static IReadOnlyCollection<SocketMessage> GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
| ulong? fromMessageId, Direction dir, int limit) | |||
| { | |||
| if (messages != null) //Cache enabled | |||
| @@ -56,11 +56,41 @@ namespace Discord.WebSocket | |||
| cachedMessageIds = _orderedMessages; | |||
| else if (dir == Direction.Before) | |||
| cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | |||
| else | |||
| else if (dir == Direction.After) | |||
| cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | |||
| else //Direction.Around | |||
| { | |||
| if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) | |||
| return ImmutableArray<SocketMessage>.Empty; | |||
| int around = limit / 2; | |||
| var before = _orderedMessages | |||
| .Where(x => x < fromMessageId.Value) | |||
| .Select(x => | |||
| { | |||
| if (_messages.TryGetValue(x, out SocketMessage msg)) | |||
| return msg; | |||
| return null; | |||
| }) | |||
| .Where(x => x != null) | |||
| .Take(around); | |||
| var after = _orderedMessages | |||
| .Where(x => x > fromMessageId.Value) | |||
| .Select(x => | |||
| { | |||
| if (_messages.TryGetValue(x, out SocketMessage msg)) | |||
| return msg; | |||
| return null; | |||
| }) | |||
| .Where(x => x != null) | |||
| .Take(around); | |||
| return before.Concat(new SocketMessage[] { msg }).Concat(after).ToImmutableArray(); | |||
| } | |||
| if (dir == Direction.Before) | |||
| cachedMessageIds = cachedMessageIds.Reverse(); | |||
| if (dir == Direction.Around) | |||
| limit /= 2; | |||
| return cachedMessageIds | |||
| .Select(x => | |||