Browse Source

Fix channel being null in DMs on Interactions (#2098)

tags/3.3.1
Quin Lynch GitHub 3 years ago
parent
commit
7e1b8c9db0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 108 additions and 61 deletions
  1. +5
    -0
      src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs
  2. +3
    -5
      src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs
  3. +5
    -0
      src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs
  4. +26
    -8
      src/Discord.Net.WebSocket/DiscordSocketClient.cs
  5. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs
  6. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs
  7. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs
  8. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs
  9. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs
  10. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs
  11. +4
    -4
      src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs
  12. +41
    -20
      src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs

+ 5
- 0
src/Discord.Net.Core/Entities/Interactions/IDiscordInteraction.cs View File

@@ -62,6 +62,11 @@ namespace Discord
/// </remarks> /// </remarks>
string GuildLocale { get; } string GuildLocale { get; }


/// <summary>
/// Gets whether or not this interaction was executed in a dm channel.
/// </summary>
bool IsDMInteraction { get; }

/// <summary> /// <summary>
/// Responds to an Interaction with type <see cref="InteractionResponseType.ChannelMessageWithSource"/>. /// Responds to an Interaction with type <see cref="InteractionResponseType.ChannelMessageWithSource"/>.
/// </summary> /// </summary>


+ 3
- 5
src/Discord.Net.Interactions/Attributes/Preconditions/RequireContextAttribute.cs View File

@@ -57,11 +57,9 @@ namespace Discord.Interactions
bool isValid = false; bool isValid = false;


if ((Contexts & ContextType.Guild) != 0) if ((Contexts & ContextType.Guild) != 0)
isValid = context.Channel is IGuildChannel;
if ((Contexts & ContextType.DM) != 0)
isValid = isValid || context.Channel is IDMChannel;
if ((Contexts & ContextType.Group) != 0)
isValid = isValid || context.Channel is IGroupChannel;
isValid = !context.Interaction.IsDMInteraction;
if ((Contexts & ContextType.DM) != 0 && (Contexts & ContextType.Group) != 0)
isValid = context.Interaction.IsDMInteraction;


if (isValid) if (isValid)
return Task.FromResult(PreconditionResult.FromSuccess()); return Task.FromResult(PreconditionResult.FromSuccess());


+ 5
- 0
src/Discord.Net.Rest/Entities/Interactions/RestInteraction.cs View File

@@ -61,6 +61,9 @@ namespace Discord.Rest
/// <inheritdoc/> /// <inheritdoc/>
public bool HasResponded { get; protected set; } public bool HasResponded { get; protected set; }


/// <inheritdoc/>
public bool IsDMInteraction { get; private set; }

internal RestInteraction(BaseDiscordClient discord, ulong id) internal RestInteraction(BaseDiscordClient discord, ulong id)
: base(discord, id) : base(discord, id)
{ {
@@ -108,6 +111,8 @@ namespace Discord.Rest


internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model) internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model)
{ {
IsDMInteraction = !model.GuildId.IsSpecified;

Data = model.Data.IsSpecified Data = model.Data.IsSpecified
? model.Data.Value ? model.Data.Value
: null; : null;


+ 26
- 8
src/Discord.Net.WebSocket/DiscordSocketClient.cs View File

@@ -2233,24 +2233,42 @@ namespace Discord.WebSocket


var data = (payload as JToken).ToObject<API.Interaction>(_serializer); var data = (payload as JToken).ToObject<API.Interaction>(_serializer);


var guild = data.GuildId.IsSpecified ? GetGuild(data.GuildId.Value) : null;

if (guild != null && !guild.IsSynced)
{
await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false);
return;
}

SocketUser user = data.User.IsSpecified
? State.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, State, data.User.Value))
: guild.AddOrUpdateUser(data.Member.Value);

SocketChannel channel = null; SocketChannel channel = null;
if(data.ChannelId.IsSpecified) if(data.ChannelId.IsSpecified)
{ {
channel = State.GetChannel(data.ChannelId.Value); channel = State.GetChannel(data.ChannelId.Value);

if (channel == null)
{
if (!data.GuildId.IsSpecified) // assume it is a DM
{
channel = CreateDMChannel(data.ChannelId.Value, user, State);
}
else
{
await UnknownChannelAsync(type, data.ChannelId.Value).ConfigureAwait(false);
return;
}
}
} }
else if (data.User.IsSpecified) else if (data.User.IsSpecified)
{ {
channel = State.GetDMChannel(data.User.Value.Id); channel = State.GetDMChannel(data.User.Value.Id);
} }


var guild = (channel as SocketGuildChannel)?.Guild;
if (guild != null && !guild.IsSynced)
{
await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false);
return;
}

var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel);
var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel, user);


await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false); await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false);




+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/MessageCommands/SocketMessageCommand.cs View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketMessageCommandData Data { get; } public new SocketMessageCommandData Data { get; }


internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model, channel)
internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId); Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId);
} }


internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketMessageCommand(client, model, channel);
var entity = new SocketMessageCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/ContextMenuCommands/UserCommands/SocketUserCommand.cs View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketUserCommandData Data { get; } public new SocketUserCommandData Data { get; }


internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model, channel)
internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId); Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId);
} }


internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketUserCommand(client, model, channel);
var entity = new SocketUserCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/MessageComponents/SocketMessageComponent.cs View File

@@ -28,8 +28,8 @@ namespace Discord.WebSocket
private object _lock = new object(); private object _lock = new object();
public override bool HasResponded { get; internal set; } = false; public override bool HasResponded { get; internal set; } = false;


internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model.Id, channel)
internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -38,9 +38,9 @@ namespace Discord.WebSocket
Data = new SocketMessageComponentData(dataModel); Data = new SocketMessageComponentData(dataModel);
} }


internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketMessageComponent(client, model, channel);
var entity = new SocketMessageComponent(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/Modals/SocketModal.cs View File

@@ -22,8 +22,8 @@ namespace Discord.WebSocket
/// <value></value> /// <value></value>
public new SocketModalData Data { get; set; } public new SocketModalData Data { get; set; }


internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel)
: base(client, model.Id, channel)
internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketModalData(dataModel); Data = new SocketModalData(dataModel);
} }


internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel)
internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketModal(client, model, channel);
var entity = new SocketModal(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketAutocompleteInteraction.cs View File

@@ -21,8 +21,8 @@ namespace Discord.WebSocket
public override bool HasResponded { get; internal set; } public override bool HasResponded { get; internal set; }
private object _lock = new object(); private object _lock = new object();


internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model.Id, channel)
internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketAutocompleteInteractionData(dataModel); Data = new SocketAutocompleteInteractionData(dataModel);
} }


internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketAutocompleteInteraction(client, model, channel);
var entity = new SocketAutocompleteInteraction(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/SlashCommands/SocketSlashCommand.cs View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketSlashCommandData Data { get; } public new SocketSlashCommandData Data { get; }


internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model, channel)
internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketSlashCommandData.Create(client, dataModel, guildId); Data = SocketSlashCommandData.Create(client, dataModel, guildId);
} }


internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketSlashCommand(client, model, channel);
var entity = new SocketSlashCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 4
- 4
src/Discord.Net.WebSocket/Entities/Interaction/SocketBaseCommand/SocketCommandBase.cs View File

@@ -35,8 +35,8 @@ namespace Discord.WebSocket


private object _lock = new object(); private object _lock = new object();


internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
: base(client, model.Id, channel)
internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -49,9 +49,9 @@ namespace Discord.WebSocket
Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId); Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId);
} }


internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketCommandBase(client, model, channel);
var entity = new SocketCommandBase(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }


+ 41
- 20
src/Discord.Net.WebSocket/Entities/Interaction/SocketInteraction.cs View File

@@ -5,6 +5,7 @@ using Model = Discord.API.Interaction;
using DataModel = Discord.API.ApplicationCommandInteractionData; using DataModel = Discord.API.ApplicationCommandInteractionData;
using System.IO; using System.IO;
using System.Collections.Generic; using System.Collections.Generic;
using Discord.Net;


namespace Discord.WebSocket namespace Discord.WebSocket
{ {
@@ -72,17 +73,23 @@ namespace Discord.WebSocket
public bool IsValidToken public bool IsValidToken
=> InteractionHelper.CanRespondOrFollowup(this); => InteractionHelper.CanRespondOrFollowup(this);


internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel)
/// <inheritdoc/>
public bool IsDMInteraction { get; private set; }

private ulong? _channelId;

internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel, SocketUser user)
: base(client, id) : base(client, id)
{ {
Channel = channel; Channel = channel;
User = user;


CreatedAt = client.UseInteractionSnowflakeDate CreatedAt = client.UseInteractionSnowflakeDate
? SnowflakeUtils.FromSnowflake(Id) ? SnowflakeUtils.FromSnowflake(Id)
: DateTime.UtcNow; : DateTime.UtcNow;
} }


internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel)
internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
if (model.Type == InteractionType.ApplicationCommand) if (model.Type == InteractionType.ApplicationCommand)
{ {
@@ -95,27 +102,31 @@ namespace Discord.WebSocket


return dataModel.Type switch return dataModel.Type switch
{ {
ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel),
ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel),
ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel),
ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel, user),
ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel, user),
ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel, user),
_ => null _ => null
}; };
} }


if (model.Type == InteractionType.MessageComponent) if (model.Type == InteractionType.MessageComponent)
return SocketMessageComponent.Create(client, model, channel);
return SocketMessageComponent.Create(client, model, channel, user);


if (model.Type == InteractionType.ApplicationCommandAutocomplete) if (model.Type == InteractionType.ApplicationCommandAutocomplete)
return SocketAutocompleteInteraction.Create(client, model, channel);
return SocketAutocompleteInteraction.Create(client, model, channel, user);


if (model.Type == InteractionType.ModalSubmit) if (model.Type == InteractionType.ModalSubmit)
return SocketModal.Create(client, model, channel);
return SocketModal.Create(client, model, channel, user);


return null; return null;
} }


internal virtual void Update(Model model) internal virtual void Update(Model model)
{ {
IsDMInteraction = !model.GuildId.IsSpecified;

_channelId = model.ChannelId.ToNullable();

Data = model.Data.IsSpecified Data = model.Data.IsSpecified
? model.Data.Value ? model.Data.Value
: null; : null;
@@ -123,18 +134,6 @@ namespace Discord.WebSocket
Version = model.Version; Version = model.Version;
Type = model.Type; Type = model.Type;


if (User == null)
{
if (model.Member.IsSpecified && model.GuildId.IsSpecified)
{
User = SocketGuildUser.Create(Discord.State.GetGuild(model.GuildId.Value), Discord.State, model.Member.Value);
}
else
{
User = SocketGlobalUser.Create(Discord, Discord.State, model.User.Value);
}
}

UserLocale = model.UserLocale.IsSpecified UserLocale = model.UserLocale.IsSpecified
? model.UserLocale.Value ? model.UserLocale.Value
: null; : null;
@@ -399,6 +398,28 @@ namespace Discord.WebSocket
public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null); public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null);
#endregion #endregion


/// <summary>
/// Attepts to get the channel this interaction was executed in.
/// </summary>
/// <param name="options">The request options for this <see langword="async"/> request.</param>
/// <returns>
/// A task that represents the asynchronous operation of fetching the channel.
/// </returns>
public async ValueTask<IMessageChannel> GetChannelAsync(RequestOptions options = null)
{
if (Channel != null)
return Channel;

if (!_channelId.HasValue)
return null;

try
{
return (IMessageChannel)await Discord.GetChannelAsync(_channelId.Value, options).ConfigureAwait(false);
}
catch(HttpException ex) when (ex.DiscordCode == DiscordErrorCode.MissingPermissions) { return null; } // bot can't view that channel, return null instead of throwing.
}

#region IDiscordInteraction #region IDiscordInteraction
/// <inheritdoc/> /// <inheritdoc/>
IUser IDiscordInteraction.User => User; IUser IDiscordInteraction.User => User;


Loading…
Cancel
Save