| @@ -109,12 +109,81 @@ namespace Discord.Rest | |||||
| public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | ||||
| ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | ||||
| { | { | ||||
| if (dir == Direction.Around) | |||||
| throw new NotImplementedException(); //TODO: Impl | |||||
| var guildId = (channel as IGuildChannel)?.GuildId; | var guildId = (channel as IGuildChannel)?.GuildId; | ||||
| var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | ||||
| if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) | |||||
| { | |||||
| int around = limit / 2; | |||||
| return new PagedAsyncEnumerable<RestMessage>( | |||||
| DiscordConfig.MaxMessagesPerBatch, | |||||
| async (info, ct) => | |||||
| { | |||||
| var args = new GetChannelMessagesParams | |||||
| { | |||||
| RelativeDirection = Direction.Before, | |||||
| Limit = info.PageSize | |||||
| }; | |||||
| if (info.Position != null) | |||||
| args.RelativeMessageId = info.Position.Value; | |||||
| var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||||
| var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||||
| foreach (var model in models) | |||||
| { | |||||
| var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||||
| builder.Add(RestMessage.Create(client, channel, author, model)); | |||||
| } | |||||
| return builder.ToImmutable(); | |||||
| }, | |||||
| nextPage: (info, lastPage) => | |||||
| { | |||||
| if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||||
| return false; | |||||
| if (dir == Direction.Before) | |||||
| info.Position = lastPage.Min(x => x.Id); | |||||
| else | |||||
| info.Position = lastPage.Max(x => x.Id); | |||||
| return true; | |||||
| }, | |||||
| start: fromMessageId + 1, //Needs to include the message itself | |||||
| count: around + 1 | |||||
| ).Concat(new PagedAsyncEnumerable<RestMessage>( | |||||
| DiscordConfig.MaxMessagesPerBatch, | |||||
| async (info, ct) => | |||||
| { | |||||
| var args = new GetChannelMessagesParams | |||||
| { | |||||
| RelativeDirection = Direction.After, | |||||
| Limit = info.PageSize | |||||
| }; | |||||
| if (info.Position != null) | |||||
| args.RelativeMessageId = info.Position.Value; | |||||
| var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||||
| var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||||
| foreach (var model in models) | |||||
| { | |||||
| var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||||
| builder.Add(RestMessage.Create(client, channel, author, model)); | |||||
| } | |||||
| return builder.ToImmutable(); | |||||
| }, | |||||
| nextPage: (info, lastPage) => | |||||
| { | |||||
| if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||||
| return false; | |||||
| if (dir == Direction.Before) | |||||
| info.Position = lastPage.Min(x => x.Id); | |||||
| else | |||||
| info.Position = lastPage.Max(x => x.Id); | |||||
| return true; | |||||
| }, | |||||
| start: fromMessageId, | |||||
| count: around | |||||
| )); | |||||
| } | |||||
| return new PagedAsyncEnumerable<RestMessage>( | return new PagedAsyncEnumerable<RestMessage>( | ||||
| DiscordConfig.MaxMessagesPerBatch, | DiscordConfig.MaxMessagesPerBatch, | ||||
| async (info, ct) => | async (info, ct) => | ||||
| @@ -11,23 +11,11 @@ namespace Discord.WebSocket | |||||
| public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | ||||
| ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | ||||
| { | { | ||||
| if (dir == Direction.Around) | |||||
| throw new NotImplementedException(); //TODO: Impl | |||||
| IReadOnlyCollection<SocketMessage> cachedMessages = null; | |||||
| IAsyncEnumerable<IReadOnlyCollection<IMessage>> result = null; | |||||
| if (dir == Direction.After && fromMessageId == null) | if (dir == Direction.After && fromMessageId == null) | ||||
| return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | ||||
| if (dir == Direction.Before || mode == CacheMode.CacheOnly) | |||||
| { | |||||
| if (messages != null) //Cache enabled | |||||
| cachedMessages = messages.GetMany(fromMessageId, dir, limit); | |||||
| else | |||||
| cachedMessages = ImmutableArray.Create<SocketMessage>(); | |||||
| result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||||
| } | |||||
| var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); | |||||
| var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||||
| if (dir == Direction.Before) | if (dir == Direction.Before) | ||||
| { | { | ||||
| @@ -38,18 +26,35 @@ namespace Discord.WebSocket | |||||
| //Download remaining messages | //Download remaining messages | ||||
| ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | ||||
| var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | ||||
| return result.Concat(downloadedMessages); | |||||
| if (cachedMessages.Count != 0) | |||||
| return result.Concat(downloadedMessages); | |||||
| else | |||||
| return downloadedMessages; | |||||
| } | } | ||||
| else | |||||
| else if (dir == Direction.After) | |||||
| { | |||||
| limit -= cachedMessages.Count; | |||||
| if (mode == CacheMode.CacheOnly || limit <= 0) | |||||
| return result; | |||||
| //Download remaining messages | |||||
| ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; | |||||
| var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); | |||||
| if (cachedMessages.Count != 0) | |||||
| return result.Concat(downloadedMessages); | |||||
| else | |||||
| return downloadedMessages; | |||||
| } | |||||
| else //Direction.Around | |||||
| { | { | ||||
| if (mode == CacheMode.CacheOnly) | |||||
| if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) | |||||
| return result; | return result; | ||||
| //Dont use cache in this case | |||||
| //Cache isn't useful here since Discord will send them anyways | |||||
| return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | ||||
| } | } | ||||
| } | } | ||||
| public static IReadOnlyCollection<SocketMessage> GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, | |||||
| public static IReadOnlyCollection<SocketMessage> GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||||
| ulong? fromMessageId, Direction dir, int limit) | ulong? fromMessageId, Direction dir, int limit) | ||||
| { | { | ||||
| if (messages != null) //Cache enabled | if (messages != null) //Cache enabled | ||||
| @@ -56,11 +56,41 @@ namespace Discord.WebSocket | |||||
| cachedMessageIds = _orderedMessages; | cachedMessageIds = _orderedMessages; | ||||
| else if (dir == Direction.Before) | else if (dir == Direction.Before) | ||||
| cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | ||||
| else | |||||
| else if (dir == Direction.After) | |||||
| cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | ||||
| else //Direction.Around | |||||
| { | |||||
| if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) | |||||
| return ImmutableArray<SocketMessage>.Empty; | |||||
| int around = limit / 2; | |||||
| var before = _orderedMessages | |||||
| .Where(x => x < fromMessageId.Value) | |||||
| .Select(x => | |||||
| { | |||||
| if (_messages.TryGetValue(x, out SocketMessage msg)) | |||||
| return msg; | |||||
| return null; | |||||
| }) | |||||
| .Where(x => x != null) | |||||
| .Take(around); | |||||
| var after = _orderedMessages | |||||
| .Where(x => x > fromMessageId.Value) | |||||
| .Select(x => | |||||
| { | |||||
| if (_messages.TryGetValue(x, out SocketMessage msg)) | |||||
| return msg; | |||||
| return null; | |||||
| }) | |||||
| .Where(x => x != null) | |||||
| .Take(around); | |||||
| return before.Concat(new SocketMessage[] { msg }).Concat(after).ToImmutableArray(); | |||||
| } | |||||
| if (dir == Direction.Before) | if (dir == Direction.Before) | ||||
| cachedMessageIds = cachedMessageIds.Reverse(); | cachedMessageIds = cachedMessageIds.Reverse(); | ||||
| if (dir == Direction.Around) | |||||
| limit /= 2; | |||||
| return cachedMessageIds | return cachedMessageIds | ||||
| .Select(x => | .Select(x => | ||||