From 7e1b8c9db06ef6458f487f317f6cd4dcfaa63a24 Mon Sep 17 00:00:00 2001 From: Quin Lynch <49576606+quinchs@users.noreply.github.com> Date: Fri, 11 Feb 2022 21:43:32 -0400 Subject: [PATCH] Fix channel being null in DMs on Interactions (#2098) --- .../Interactions/IDiscordInteraction.cs | 5 ++ .../Preconditions/RequireContextAttribute.cs | 8 +-- .../Entities/Interactions/RestInteraction.cs | 5 ++ .../DiscordSocketClient.cs | 34 ++++++++--- .../MessageCommands/SocketMessageCommand.cs | 8 +-- .../UserCommands/SocketUserCommand.cs | 8 +-- .../SocketMessageComponent.cs | 8 +-- .../Interaction/Modals/SocketModal.cs | 8 +-- .../SocketAutocompleteInteraction.cs | 8 +-- .../SlashCommands/SocketSlashCommand.cs | 8 +-- .../SocketBaseCommand/SocketCommandBase.cs | 8 +-- .../Entities/Interaction/SocketInteraction.cs | 61 +++++++++++++------ 12 files changed, 108 insertions(+), 61 deletions(-) diff --git a/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs b/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs index 66ff6c6d0..0741ce138 100644 --- a/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs +++ b/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs @@ -62,6 +62,11 @@ namespace Discord /// string GuildLocale { get; } + /// + /// Gets whether or not this interaction was executed in a dm channel. + /// + bool IsDMInteraction { get; } + /// /// Responds to an Interaction with type . /// diff --git a/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs b/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs index 2f1b1df0d..9d1cee8d9 100644 --- a/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs +++ b/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs @@ -57,11 +57,9 @@ namespace Discord.Interactions bool isValid = false; if ((Contexts & ContextType.Guild) != 0) - isValid = context.Channel is IGuildChannel; - if ((Contexts & ContextType.DM) != 0) - isValid = isValid || context.Channel is IDMChannel; - if ((Contexts & ContextType.Group) != 0) - isValid = isValid || context.Channel is IGroupChannel; + isValid = !context.Interaction.IsDMInteraction; + if ((Contexts & ContextType.DM) != 0 && (Contexts & ContextType.Group) != 0) + isValid = context.Interaction.IsDMInteraction; if (isValid) return Task.FromResult(PreconditionResult.FromSuccess()); diff --git a/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs b/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs index 5894ee264..6069cf0e2 100644 --- a/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs +++ b/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs @@ -61,6 +61,9 @@ namespace Discord.Rest /// public bool HasResponded { get; protected set; } + /// + public bool IsDMInteraction { get; private set; } + internal RestInteraction(BaseDiscordClient discord, ulong id) : base(discord, id) { @@ -108,6 +111,8 @@ namespace Discord.Rest internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model) { + IsDMInteraction = !model.GuildId.IsSpecified; + Data = model.Data.IsSpecified ? model.Data.Value : null; diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index e7f9b10ee..b0215d9ef 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -2233,24 +2233,42 @@ namespace Discord.WebSocket var data = (payload as JToken).ToObject(_serializer); + var guild = data.GuildId.IsSpecified ? GetGuild(data.GuildId.Value) : null; + + if (guild != null && !guild.IsSynced) + { + await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false); + return; + } + + SocketUser user = data.User.IsSpecified + ? State.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, State, data.User.Value)) + : guild.AddOrUpdateUser(data.Member.Value); + SocketChannel channel = null; if(data.ChannelId.IsSpecified) { channel = State.GetChannel(data.ChannelId.Value); + + if (channel == null) + { + if (!data.GuildId.IsSpecified) // assume it is a DM + { + channel = CreateDMChannel(data.ChannelId.Value, user, State); + } + else + { + await UnknownChannelAsync(type, data.ChannelId.Value).ConfigureAwait(false); + return; + } + } } else if (data.User.IsSpecified) { channel = State.GetDMChannel(data.User.Value.Id); } - var guild = (channel as SocketGuildChannel)?.Guild; - if (guild != null && !guild.IsSynced) - { - await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false); - return; - } - - var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel); + var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel, user); await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false); diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs index fee33f8cb..e4198a183 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs @@ -13,8 +13,8 @@ namespace Discord.WebSocket /// public new SocketMessageCommandData Data { get; } - internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model, channel) + internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -27,9 +27,9 @@ namespace Discord.WebSocket Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId); } - internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketMessageCommand(client, model, channel); + var entity = new SocketMessageCommand(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs index 75e8ebff9..c33c06f83 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs @@ -13,8 +13,8 @@ namespace Discord.WebSocket /// public new SocketUserCommandData Data { get; } - internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model, channel) + internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -27,9 +27,9 @@ namespace Discord.WebSocket Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId); } - internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketUserCommand(client, model, channel); + var entity = new SocketUserCommand(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs index 17a5e0209..b06979381 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs @@ -28,8 +28,8 @@ namespace Discord.WebSocket private object _lock = new object(); public override bool HasResponded { get; internal set; } = false; - internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model.Id, channel) + internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model.Id, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -38,9 +38,9 @@ namespace Discord.WebSocket Data = new SocketMessageComponentData(dataModel); } - internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketMessageComponent(client, model, channel); + var entity = new SocketMessageComponent(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs b/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs index 197882dae..e59bed6ec 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs @@ -22,8 +22,8 @@ namespace Discord.WebSocket /// public new SocketModalData Data { get; set; } - internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel) - : base(client, model.Id, channel) + internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user) + : base(client, model.Id, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -32,9 +32,9 @@ namespace Discord.WebSocket Data = new SocketModalData(dataModel); } - internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel) + internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketModal(client, model, channel); + var entity = new SocketModal(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs index d4cdc9cc1..449f074f5 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs @@ -21,8 +21,8 @@ namespace Discord.WebSocket public override bool HasResponded { get; internal set; } private object _lock = new object(); - internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model.Id, channel) + internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model.Id, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -32,9 +32,9 @@ namespace Discord.WebSocket Data = new SocketAutocompleteInteractionData(dataModel); } - internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketAutocompleteInteraction(client, model, channel); + var entity = new SocketAutocompleteInteraction(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs index 5934a3864..5f7e72ba0 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs @@ -13,8 +13,8 @@ namespace Discord.WebSocket /// public new SocketSlashCommandData Data { get; } - internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model, channel) + internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -27,9 +27,9 @@ namespace Discord.WebSocket Data = SocketSlashCommandData.Create(client, dataModel, guildId); } - internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketSlashCommand(client, model, channel); + var entity = new SocketSlashCommand(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs index bc3ece20c..6cf74ce17 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs @@ -35,8 +35,8 @@ namespace Discord.WebSocket private object _lock = new object(); - internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel) - : base(client, model.Id, channel) + internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) + : base(client, model.Id, channel, user) { var dataModel = model.Data.IsSpecified ? (DataModel)model.Data.Value @@ -49,9 +49,9 @@ namespace Discord.WebSocket Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId); } - internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { - var entity = new SocketCommandBase(client, model, channel); + var entity = new SocketCommandBase(client, model, channel, user); entity.Update(model); return entity; } diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs index 1c3563ab0..83b458e0c 100644 --- a/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs +++ b/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs @@ -5,6 +5,7 @@ using Model = Discord.API.Interaction; using DataModel = Discord.API.ApplicationCommandInteractionData; using System.IO; using System.Collections.Generic; +using Discord.Net; namespace Discord.WebSocket { @@ -72,17 +73,23 @@ namespace Discord.WebSocket public bool IsValidToken => InteractionHelper.CanRespondOrFollowup(this); - internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel) + /// + public bool IsDMInteraction { get; private set; } + + private ulong? _channelId; + + internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel, SocketUser user) : base(client, id) { Channel = channel; + User = user; CreatedAt = client.UseInteractionSnowflakeDate ? SnowflakeUtils.FromSnowflake(Id) : DateTime.UtcNow; } - internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) + internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user) { if (model.Type == InteractionType.ApplicationCommand) { @@ -95,27 +102,31 @@ namespace Discord.WebSocket return dataModel.Type switch { - ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel), - ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel), - ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel), + ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel, user), + ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel, user), + ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel, user), _ => null }; } if (model.Type == InteractionType.MessageComponent) - return SocketMessageComponent.Create(client, model, channel); + return SocketMessageComponent.Create(client, model, channel, user); if (model.Type == InteractionType.ApplicationCommandAutocomplete) - return SocketAutocompleteInteraction.Create(client, model, channel); + return SocketAutocompleteInteraction.Create(client, model, channel, user); if (model.Type == InteractionType.ModalSubmit) - return SocketModal.Create(client, model, channel); + return SocketModal.Create(client, model, channel, user); return null; } internal virtual void Update(Model model) { + IsDMInteraction = !model.GuildId.IsSpecified; + + _channelId = model.ChannelId.ToNullable(); + Data = model.Data.IsSpecified ? model.Data.Value : null; @@ -123,18 +134,6 @@ namespace Discord.WebSocket Version = model.Version; Type = model.Type; - if (User == null) - { - if (model.Member.IsSpecified && model.GuildId.IsSpecified) - { - User = SocketGuildUser.Create(Discord.State.GetGuild(model.GuildId.Value), Discord.State, model.Member.Value); - } - else - { - User = SocketGlobalUser.Create(Discord, Discord.State, model.User.Value); - } - } - UserLocale = model.UserLocale.IsSpecified ? model.UserLocale.Value : null; @@ -399,6 +398,28 @@ namespace Discord.WebSocket public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null); #endregion + /// + /// Attepts to get the channel this interaction was executed in. + /// + /// The request options for this request. + /// + /// A task that represents the asynchronous operation of fetching the channel. + /// + public async ValueTask GetChannelAsync(RequestOptions options = null) + { + if (Channel != null) + return Channel; + + if (!_channelId.HasValue) + return null; + + try + { + return (IMessageChannel)await Discord.GetChannelAsync(_channelId.Value, options).ConfigureAwait(false); + } + catch(HttpException ex) when (ex.DiscordCode == DiscordErrorCode.MissingPermissions) { return null; } // bot can't view that channel, return null instead of throwing. + } + #region IDiscordInteraction /// IUser IDiscordInteraction.User => User;