diff --git a/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs b/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs
index 66ff6c6d0..0741ce138 100644
--- a/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs
+++ b/src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs
@@ -62,6 +62,11 @@ namespace Discord
///
string GuildLocale { get; }
+ ///
+ /// Gets whether or not this interaction was executed in a dm channel.
+ ///
+ bool IsDMInteraction { get; }
+
///
/// Responds to an Interaction with type .
///
diff --git a/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs b/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs
index 2f1b1df0d..9d1cee8d9 100644
--- a/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs
+++ b/src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs
@@ -57,11 +57,9 @@ namespace Discord.Interactions
bool isValid = false;
if ((Contexts & ContextType.Guild) != 0)
- isValid = context.Channel is IGuildChannel;
- if ((Contexts & ContextType.DM) != 0)
- isValid = isValid || context.Channel is IDMChannel;
- if ((Contexts & ContextType.Group) != 0)
- isValid = isValid || context.Channel is IGroupChannel;
+ isValid = !context.Interaction.IsDMInteraction;
+ if ((Contexts & ContextType.DM) != 0 && (Contexts & ContextType.Group) != 0)
+ isValid = context.Interaction.IsDMInteraction;
if (isValid)
return Task.FromResult(PreconditionResult.FromSuccess());
diff --git a/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs b/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs
index 5894ee264..6069cf0e2 100644
--- a/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs
+++ b/src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs
@@ -61,6 +61,9 @@ namespace Discord.Rest
///
public bool HasResponded { get; protected set; }
+ ///
+ public bool IsDMInteraction { get; private set; }
+
internal RestInteraction(BaseDiscordClient discord, ulong id)
: base(discord, id)
{
@@ -108,6 +111,8 @@ namespace Discord.Rest
internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model)
{
+ IsDMInteraction = !model.GuildId.IsSpecified;
+
Data = model.Data.IsSpecified
? model.Data.Value
: null;
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index e7f9b10ee..b0215d9ef 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -2233,24 +2233,42 @@ namespace Discord.WebSocket
var data = (payload as JToken).ToObject(_serializer);
+ var guild = data.GuildId.IsSpecified ? GetGuild(data.GuildId.Value) : null;
+
+ if (guild != null && !guild.IsSynced)
+ {
+ await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false);
+ return;
+ }
+
+ SocketUser user = data.User.IsSpecified
+ ? State.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, State, data.User.Value))
+ : guild.AddOrUpdateUser(data.Member.Value);
+
SocketChannel channel = null;
if(data.ChannelId.IsSpecified)
{
channel = State.GetChannel(data.ChannelId.Value);
+
+ if (channel == null)
+ {
+ if (!data.GuildId.IsSpecified) // assume it is a DM
+ {
+ channel = CreateDMChannel(data.ChannelId.Value, user, State);
+ }
+ else
+ {
+ await UnknownChannelAsync(type, data.ChannelId.Value).ConfigureAwait(false);
+ return;
+ }
+ }
}
else if (data.User.IsSpecified)
{
channel = State.GetDMChannel(data.User.Value.Id);
}
- var guild = (channel as SocketGuildChannel)?.Guild;
- if (guild != null && !guild.IsSynced)
- {
- await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false);
- return;
- }
-
- var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel);
+ var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel, user);
await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false);
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs
index fee33f8cb..e4198a183 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs
@@ -13,8 +13,8 @@ namespace Discord.WebSocket
///
public new SocketMessageCommandData Data { get; }
- internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model, channel)
+ internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId);
}
- internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketMessageCommand(client, model, channel);
+ var entity = new SocketMessageCommand(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs
index 75e8ebff9..c33c06f83 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs
@@ -13,8 +13,8 @@ namespace Discord.WebSocket
///
public new SocketUserCommandData Data { get; }
- internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model, channel)
+ internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId);
}
- internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketUserCommand(client, model, channel);
+ var entity = new SocketUserCommand(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs
index 17a5e0209..b06979381 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs
@@ -28,8 +28,8 @@ namespace Discord.WebSocket
private object _lock = new object();
public override bool HasResponded { get; internal set; } = false;
- internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model.Id, channel)
+ internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model.Id, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -38,9 +38,9 @@ namespace Discord.WebSocket
Data = new SocketMessageComponentData(dataModel);
}
- internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketMessageComponent(client, model, channel);
+ var entity = new SocketMessageComponent(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs b/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs
index 197882dae..e59bed6ec 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs
@@ -22,8 +22,8 @@ namespace Discord.WebSocket
///
public new SocketModalData Data { get; set; }
- internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel)
- : base(client, model.Id, channel)
+ internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model.Id, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketModalData(dataModel);
}
- internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel)
+ internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketModal(client, model, channel);
+ var entity = new SocketModal(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs
index d4cdc9cc1..449f074f5 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs
@@ -21,8 +21,8 @@ namespace Discord.WebSocket
public override bool HasResponded { get; internal set; }
private object _lock = new object();
- internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model.Id, channel)
+ internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model.Id, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketAutocompleteInteractionData(dataModel);
}
- internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketAutocompleteInteraction(client, model, channel);
+ var entity = new SocketAutocompleteInteraction(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs
index 5934a3864..5f7e72ba0 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs
@@ -13,8 +13,8 @@ namespace Discord.WebSocket
///
public new SocketSlashCommandData Data { get; }
- internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model, channel)
+ internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketSlashCommandData.Create(client, dataModel, guildId);
}
- internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketSlashCommand(client, model, channel);
+ var entity = new SocketSlashCommand(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs
index bc3ece20c..6cf74ce17 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs
@@ -35,8 +35,8 @@ namespace Discord.WebSocket
private object _lock = new object();
- internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
- : base(client, model.Id, channel)
+ internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
+ : base(client, model.Id, channel, user)
{
var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value
@@ -49,9 +49,9 @@ namespace Discord.WebSocket
Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId);
}
- internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
- var entity = new SocketCommandBase(client, model, channel);
+ var entity = new SocketCommandBase(client, model, channel, user);
entity.Update(model);
return entity;
}
diff --git a/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs b/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs
index 1c3563ab0..83b458e0c 100644
--- a/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs
+++ b/src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs
@@ -5,6 +5,7 @@ using Model = Discord.API.Interaction;
using DataModel = Discord.API.ApplicationCommandInteractionData;
using System.IO;
using System.Collections.Generic;
+using Discord.Net;
namespace Discord.WebSocket
{
@@ -72,17 +73,23 @@ namespace Discord.WebSocket
public bool IsValidToken
=> InteractionHelper.CanRespondOrFollowup(this);
- internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel)
+ ///
+ public bool IsDMInteraction { get; private set; }
+
+ private ulong? _channelId;
+
+ internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel, SocketUser user)
: base(client, id)
{
Channel = channel;
+ User = user;
CreatedAt = client.UseInteractionSnowflakeDate
? SnowflakeUtils.FromSnowflake(Id)
: DateTime.UtcNow;
}
- internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
+ internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{
if (model.Type == InteractionType.ApplicationCommand)
{
@@ -95,27 +102,31 @@ namespace Discord.WebSocket
return dataModel.Type switch
{
- ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel),
- ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel),
- ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel),
+ ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel, user),
+ ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel, user),
+ ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel, user),
_ => null
};
}
if (model.Type == InteractionType.MessageComponent)
- return SocketMessageComponent.Create(client, model, channel);
+ return SocketMessageComponent.Create(client, model, channel, user);
if (model.Type == InteractionType.ApplicationCommandAutocomplete)
- return SocketAutocompleteInteraction.Create(client, model, channel);
+ return SocketAutocompleteInteraction.Create(client, model, channel, user);
if (model.Type == InteractionType.ModalSubmit)
- return SocketModal.Create(client, model, channel);
+ return SocketModal.Create(client, model, channel, user);
return null;
}
internal virtual void Update(Model model)
{
+ IsDMInteraction = !model.GuildId.IsSpecified;
+
+ _channelId = model.ChannelId.ToNullable();
+
Data = model.Data.IsSpecified
? model.Data.Value
: null;
@@ -123,18 +134,6 @@ namespace Discord.WebSocket
Version = model.Version;
Type = model.Type;
- if (User == null)
- {
- if (model.Member.IsSpecified && model.GuildId.IsSpecified)
- {
- User = SocketGuildUser.Create(Discord.State.GetGuild(model.GuildId.Value), Discord.State, model.Member.Value);
- }
- else
- {
- User = SocketGlobalUser.Create(Discord, Discord.State, model.User.Value);
- }
- }
-
UserLocale = model.UserLocale.IsSpecified
? model.UserLocale.Value
: null;
@@ -399,6 +398,28 @@ namespace Discord.WebSocket
public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null);
#endregion
+ ///
+ /// Attepts to get the channel this interaction was executed in.
+ ///
+ /// The request options for this request.
+ ///
+ /// A task that represents the asynchronous operation of fetching the channel.
+ ///
+ public async ValueTask GetChannelAsync(RequestOptions options = null)
+ {
+ if (Channel != null)
+ return Channel;
+
+ if (!_channelId.HasValue)
+ return null;
+
+ try
+ {
+ return (IMessageChannel)await Discord.GetChannelAsync(_channelId.Value, options).ConfigureAwait(false);
+ }
+ catch(HttpException ex) when (ex.DiscordCode == DiscordErrorCode.MissingPermissions) { return null; } // bot can't view that channel, return null instead of throwing.
+ }
+
#region IDiscordInteraction
///
IUser IDiscordInteraction.User => User;