diff --git a/Discord.Net.sln b/Discord.Net.sln index cac6c9064..daf902b96 100644 --- a/Discord.Net.sln +++ b/Discord.Net.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.27004.2009 +VisualStudioVersion = 15.0.27130.0 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Discord.Net.Core", "src\Discord.Net.Core\Discord.Net.Core.csproj", "{91E9E7BD-75C9-4E98-84AA-2C271922E5C2}" EndProject @@ -22,6 +22,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Discord.Net.Tests", "test\D EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Discord.Net.Webhook", "src\Discord.Net.Webhook\Discord.Net.Webhook.csproj", "{9AFAB80E-D2D3-4EDB-B58C-BACA78D1EA30}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Discord.Net.Analyzers", "src\Discord.Net.Analyzers\Discord.Net.Analyzers.csproj", "{BBA8E7FB-C834-40DC-822F-B112CB7F0140}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -116,6 +118,18 @@ Global {9AFAB80E-D2D3-4EDB-B58C-BACA78D1EA30}.Release|x64.Build.0 = Release|Any CPU {9AFAB80E-D2D3-4EDB-B58C-BACA78D1EA30}.Release|x86.ActiveCfg = Release|Any CPU {9AFAB80E-D2D3-4EDB-B58C-BACA78D1EA30}.Release|x86.Build.0 = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|x64.ActiveCfg = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|x64.Build.0 = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|x86.ActiveCfg = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Debug|x86.Build.0 = Debug|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|Any CPU.Build.0 = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|x64.ActiveCfg = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|x64.Build.0 = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|x86.ActiveCfg = Release|Any CPU + {BBA8E7FB-C834-40DC-822F-B112CB7F0140}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -126,6 +140,7 @@ Global {688FD1D8-7F01-4539-B2E9-F473C5D699C7} = {288C363D-A636-4EAE-9AC1-4698B641B26E} {6BDEEC08-417B-459F-9CA3-FF8BAB18CAC7} = {B0657AAE-DCC5-4FBF-8E5D-1FB578CF3012} {9AFAB80E-D2D3-4EDB-B58C-BACA78D1EA30} = {CC3D4B1C-9DE0-448B-8AE7-F3F1F3EC5C3A} + {BBA8E7FB-C834-40DC-822F-B112CB7F0140} = {CC3D4B1C-9DE0-448B-8AE7-F3F1F3EC5C3A} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {D2404771-EEC8-45F2-9D71-F3373F6C1495} diff --git a/Discord.Net.targets b/Discord.Net.targets index 3f623c619..958b2053f 100644 --- a/Discord.Net.targets +++ b/Discord.Net.targets @@ -1,7 +1,7 @@ 2.0.0 - beta + beta2 RogueException discord;discordapp https://github.com/RogueException/Discord.Net diff --git a/appveyor.yml b/appveyor.yml index 393485fee..54b9a1251 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -29,6 +29,7 @@ after_build: - ps: dotnet pack "src\Discord.Net.Commands\Discord.Net.Commands.csproj" -c "Release" -o "../../artifacts" --no-build /p:BuildNumber="$Env:BUILD" /p:IsTagBuild="$Env:APPVEYOR_REPO_TAG" - ps: dotnet pack "src\Discord.Net.Webhook\Discord.Net.Webhook.csproj" -c "Release" -o "../../artifacts" --no-build /p:BuildNumber="$Env:BUILD" /p:IsTagBuild="$Env:APPVEYOR_REPO_TAG" - ps: dotnet pack "src\Discord.Net.Providers.WS4Net\Discord.Net.Providers.WS4Net.csproj" -c "Release" -o "../../artifacts" --no-build /p:BuildNumber="$Env:BUILD" /p:IsTagBuild="$Env:APPVEYOR_REPO_TAG" +- ps: dotnet pack "src\Discord.Net.Analyzers\Discord.Net.Analyzers.csproj" -c "Release" -o "../../artifacts" --no-build /p:BuildNumber="$Env:BUILD" /p:IsTagBuild="$Env:APPVEYOR_REPO_TAG" - ps: >- if ($Env:APPVEYOR_REPO_TAG -eq "true") { nuget pack src\Discord.Net\Discord.Net.nuspec -OutputDirectory "artifacts" -properties suffix="" diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..a672330d4 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,16 @@ +# Instructions for Building Documentation + +The documentation for the Discord.NET library uses [DocFX][docfx-main]. [Instructions for installing this tool can be found here.][docfx-installing] + +1. Navigate to the root of the repository. +2. (Optional) If you intend to target a specific version, ensure that you +have the correct version checked out. +3. Build the library. Run `dotnet build` in the root of this repository. + Ensure that the build passes without errors. +4. Build the docs using `docfx .\docs\docfx.json`. Add the `--serve` parameter +to preview the site locally. Some elements of the page may appear incorrect +when not hosted by a server. + - Remarks: According to the docfx website, this tool does work on Linux under mono. + +[docfx-main]: https://dotnet.github.io/docfx/ +[docfx-installing]: https://dotnet.github.io/docfx/tutorial/docfx_getting_started.html diff --git a/docs/docfx.json b/docs/docfx.json index 3c0b0611e..50ae39092 100644 --- a/docs/docfx.json +++ b/docs/docfx.json @@ -67,8 +67,8 @@ "default" ], "globalMetadata": { - "_appFooter": "Discord.Net (c) 2015-2017" + "_appFooter": "Discord.Net (c) 2015-2018 2.0.0-beta" }, "noLangKeyword": false } -} \ No newline at end of file +} diff --git a/docs/guides/commands/samples/module.cs b/docs/guides/commands/samples/module.cs index 5014619da..1e3555501 100644 --- a/docs/guides/commands/samples/module.cs +++ b/docs/guides/commands/samples/module.cs @@ -3,7 +3,7 @@ public class Info : ModuleBase { // ~say hello -> hello [Command("say")] - [Summary("Echos a message.")] + [Summary("Echoes a message.")] public async Task SayAsync([Remainder] [Summary("The text to echo")] string echo) { // ReplyAsync is a method on ModuleBase @@ -38,4 +38,4 @@ public class Sample : ModuleBase var userInfo = user ?? Context.Client.CurrentUser; await ReplyAsync($"{userInfo.Username}#{userInfo.Discriminator}"); } -} \ No newline at end of file +} diff --git a/docs/guides/getting_started/samples/intro/structure.cs b/docs/guides/getting_started/samples/intro/structure.cs index bdfc12b67..a9a018c3a 100644 --- a/docs/guides/getting_started/samples/intro/structure.cs +++ b/docs/guides/getting_started/samples/intro/structure.cs @@ -19,10 +19,10 @@ class Program private readonly DiscordSocketClient _client; - // Keep the CommandService and IServiceCollection around for use with commands. + // Keep the CommandService and DI container around for use with commands. // These two types require you install the Discord.Net.Commands package. - private readonly IServiceCollection _map = new ServiceCollection(); - private readonly CommandService _commands = new CommandService(); + private readonly CommandService _commands; + private readonly IServiceProvider _services; private Program() { @@ -41,14 +41,45 @@ class Program // add the `using` at the top, and uncomment this line: //WebSocketProvider = WS4NetProvider.Instance }); + + _commands = new CommandService(new CommandServiceConfig + { + // Again, log level: + LogLevel = LogSeverity.Info, + + // There's a few more properties you can set, + // for example, case-insensitive commands. + CaseSensitiveCommands = false, + }); + // Subscribe the logging handler to both the client and the CommandService. - _client.Log += Logger; - _commands.Log += Logger; + _client.Log += Log; + _commands.Log += Log; + + // Setup your DI container. + _services = ConfigureServices(), + + } + + // If any services require the client, or the CommandService, or something else you keep on hand, + // pass them as parameters into this method as needed. + // If this method is getting pretty long, you can seperate it out into another file using partials. + private static IServiceProvider ConfigureServices() + { + var map = new ServiceCollection() + // Repeat this for all the service classes + // and other dependencies that your commands might need. + .AddSingleton(new SomeServiceClass()); + + // When all your required services are in the collection, build the container. + // Tip: There's an overload taking in a 'validateScopes' bool to make sure + // you haven't made any mistakes in your dependency graph. + return map.BuildServiceProvider(); } // Example of a logging handler. This can be re-used by addons // that ask for a Func. - private static Task Logger(LogMessage message) + private static Task Log(LogMessage message) { switch (message.Severity) { @@ -92,24 +123,15 @@ class Program await Task.Delay(Timeout.Infinite); } - private IServiceProvider _services; - private async Task InitCommands() { - // Repeat this for all the service classes - // and other dependencies that your commands might need. - _map.AddSingleton(new SomeServiceClass()); - - // When all your required services are in the collection, build the container. - // Tip: There's an overload taking in a 'validateScopes' bool to make sure - // you haven't made any mistakes in your dependency graph. - _services = _map.BuildServiceProvider(); - // Either search the program and add all Module classes that can be found. // Module classes MUST be marked 'public' or they will be ignored. - await _commands.AddModulesAsync(Assembly.GetEntryAssembly()); + // You also need to pass your 'IServiceProvider' instance now, + // so make sure that's done before you get here. + await _commands.AddModulesAsync(Assembly.GetEntryAssembly(), _services); // Or add Modules manually if you prefer to be a little more explicit: - await _commands.AddModuleAsync(); + await _commands.AddModuleAsync(_services); // Note that the first one is 'Modules' (plural) and the second is 'Module' (singular). // Subscribe a handler to see if a message invokes a command. @@ -123,8 +145,6 @@ class Program if (msg == null) return; // We don't want the bot to respond to itself or other bots. - // NOTE: Selfbots should invert this first check and remove the second - // as they should ONLY be allowed to respond to messages from the same account. if (msg.Author.Id == _client.CurrentUser.Id || msg.Author.IsBot) return; // Create a number to track where the prefix ends and the command begins @@ -140,10 +160,12 @@ class Program // Execute the command. (result does not indicate a return value, // rather an object stating if the command executed successfully). - var result = await _commands.ExecuteAsync(context, pos, _services); + var result = await _commands.ExecuteAsync(context, pos); // Uncomment the following lines if you want the bot - // to send a message if it failed (not advised for most situations). + // to send a message if it failed. + // This does not catch errors from commands with 'RunMode.Async', + // subscribe a handler for '_commands.CommandExecuted' to see those. //if (!result.IsSuccess && result.Error != CommandError.UnknownCommand) // await msg.Channel.SendMessageAsync(result.ErrorReason); } diff --git a/src/Discord.Net.Analyzers/Discord.Net.Analyzers.csproj b/src/Discord.Net.Analyzers/Discord.Net.Analyzers.csproj new file mode 100644 index 000000000..8ab398ff5 --- /dev/null +++ b/src/Discord.Net.Analyzers/Discord.Net.Analyzers.csproj @@ -0,0 +1,15 @@ + + + + Discord.Net.Analyzers + Discord.Analyzers + A Discord.Net extension adding support for design-time analysis of the API usage. + netstandard1.3 + + + + + + + + diff --git a/src/Discord.Net.Analyzers/GuildAccessAnalyzer.cs b/src/Discord.Net.Analyzers/GuildAccessAnalyzer.cs new file mode 100644 index 000000000..0760d019f --- /dev/null +++ b/src/Discord.Net.Analyzers/GuildAccessAnalyzer.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Discord.Commands; + +namespace Discord.Analyzers +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public sealed class GuildAccessAnalyzer : DiagnosticAnalyzer + { + private const string DiagnosticId = "DNET0001"; + private const string Title = "Limit command to Guild contexts."; + private const string MessageFormat = "Command method '{0}' is accessing 'Context.Guild' but is not restricted to Guild contexts."; + private const string Description = "Accessing 'Context.Guild' in a command without limiting the command to run only in guilds."; + private const string Category = "API Usage"; + + private static DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Warning, isEnabledByDefault: true, description: Description); + + public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(Rule); + + public override void Initialize(AnalysisContext context) + { + context.RegisterSyntaxNodeAction(AnalyzeMemberAccess, SyntaxKind.SimpleMemberAccessExpression); + } + + private static void AnalyzeMemberAccess(SyntaxNodeAnalysisContext context) + { + // Bail out if the accessed member isn't named 'Guild' + var memberAccessSymbol = context.SemanticModel.GetSymbolInfo(context.Node).Symbol; + if (memberAccessSymbol.Name != "Guild") + return; + + // Bail out if it happens to be 'ContextType.Guild' in the '[RequireContext]' argument + if (context.Node.Parent is AttributeArgumentSyntax) + return; + + // Bail out if the containing class doesn't derive from 'ModuleBase' + var typeNode = context.Node.FirstAncestorOrSelf(); + var typeSymbol = context.SemanticModel.GetDeclaredSymbol(typeNode); + if (!typeSymbol.DerivesFromModuleBase()) + return; + + // Bail out if the containing method isn't marked with '[Command]' + var methodNode = context.Node.FirstAncestorOrSelf(); + var methodSymbol = context.SemanticModel.GetDeclaredSymbol(methodNode); + var methodAttributes = methodSymbol.GetAttributes(); + if (!methodAttributes.Any(a => a.AttributeClass.Name == nameof(CommandAttribute))) + return; + + // Is the '[RequireContext]' attribute not applied to either the + // method or the class, or its argument isn't 'ContextType.Guild'? + var ctxAttribute = methodAttributes.SingleOrDefault(_attributeDataPredicate) + ?? typeSymbol.GetAttributes().SingleOrDefault(_attributeDataPredicate); + + if (ctxAttribute == null || ctxAttribute.ConstructorArguments.Any(arg => !arg.Value.Equals((int)ContextType.Guild))) + { + // Report the diagnostic + var diagnostic = Diagnostic.Create(Rule, context.Node.GetLocation(), methodSymbol.Name); + context.ReportDiagnostic(diagnostic); + } + } + + private static readonly Func _attributeDataPredicate = + (a => a.AttributeClass.Name == nameof(RequireContextAttribute)); + } +} diff --git a/src/Discord.Net.Analyzers/SymbolExtensions.cs b/src/Discord.Net.Analyzers/SymbolExtensions.cs new file mode 100644 index 000000000..680de66b5 --- /dev/null +++ b/src/Discord.Net.Analyzers/SymbolExtensions.cs @@ -0,0 +1,21 @@ +using System; +using Microsoft.CodeAnalysis; +using Discord.Commands; + +namespace Discord.Analyzers +{ + internal static class SymbolExtensions + { + private static readonly string _moduleBaseName = typeof(ModuleBase<>).Name; + + public static bool DerivesFromModuleBase(this ITypeSymbol symbol) + { + for (var bType = symbol.BaseType; bType != null; bType = bType.BaseType) + { + if (bType.MetadataName == _moduleBaseName) + return true; + } + return false; + } + } +} diff --git a/src/Discord.Net.Analyzers/docs/DNET0001.md b/src/Discord.Net.Analyzers/docs/DNET0001.md new file mode 100644 index 000000000..0c1b8098f --- /dev/null +++ b/src/Discord.Net.Analyzers/docs/DNET0001.md @@ -0,0 +1,30 @@ +# DNET0001 + + + + + + + + + + + + + + +
TypeNameGuildAccessAnalyzer
CheckIdDNET0001
CategoryAPI Usage
+ +## Cause + +A method identified as a command is accessing `Context.Guild` without the requisite precondition. + +## Rule description + +The value of `Context.Guild` is `null` if a command is invoked in a DM channel. Attempting to access +guild properties in such a case will result in a `NullReferenceException` at runtime. +This exception is entirely avoidable by using the library's provided preconditions. + +## How to fix violations + +Add the precondition `[RequireContext(ContextType.Guild)]` to the command or module class. diff --git a/src/Discord.Net.Commands/Attributes/AliasAttribute.cs b/src/Discord.Net.Commands/Attributes/AliasAttribute.cs index 6e115bd60..6cd0abbb7 100644 --- a/src/Discord.Net.Commands/Attributes/AliasAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/AliasAttribute.cs @@ -3,7 +3,7 @@ using System; namespace Discord.Commands { /// Provides aliases for a command. - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public class AliasAttribute : Attribute { /// The aliases which have been defined for the command. diff --git a/src/Discord.Net.Commands/Attributes/CommandAttribute.cs b/src/Discord.Net.Commands/Attributes/CommandAttribute.cs index 5ae6092eb..5f8e9ceaf 100644 --- a/src/Discord.Net.Commands/Attributes/CommandAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/CommandAttribute.cs @@ -1,8 +1,8 @@ -using System; +using System; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Method)] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public class CommandAttribute : Attribute { public string Text { get; } diff --git a/src/Discord.Net.Commands/Attributes/DontAutoLoadAttribute.cs b/src/Discord.Net.Commands/Attributes/DontAutoLoadAttribute.cs index d6a1c646e..cc23a6d15 100644 --- a/src/Discord.Net.Commands/Attributes/DontAutoLoadAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/DontAutoLoadAttribute.cs @@ -1,8 +1,8 @@ -using System; +using System; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Class)] + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)] public class DontAutoLoadAttribute : Attribute { } diff --git a/src/Discord.Net.Commands/Attributes/DontInjectAttribute.cs b/src/Discord.Net.Commands/Attributes/DontInjectAttribute.cs index bd966e129..c982d93a1 100644 --- a/src/Discord.Net.Commands/Attributes/DontInjectAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/DontInjectAttribute.cs @@ -2,7 +2,7 @@ using System; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Property)] + [AttributeUsage(AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public class DontInjectAttribute : Attribute { } diff --git a/src/Discord.Net.Commands/Attributes/GroupAttribute.cs b/src/Discord.Net.Commands/Attributes/GroupAttribute.cs index 105d256ec..b1760d149 100644 --- a/src/Discord.Net.Commands/Attributes/GroupAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/GroupAttribute.cs @@ -1,8 +1,8 @@ -using System; +using System; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Class)] + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)] public class GroupAttribute : Attribute { public string Prefix { get; } diff --git a/src/Discord.Net.Commands/Attributes/NameAttribute.cs b/src/Discord.Net.Commands/Attributes/NameAttribute.cs index 0a5156fee..4a4b2bfed 100644 --- a/src/Discord.Net.Commands/Attributes/NameAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/NameAttribute.cs @@ -3,7 +3,7 @@ using System; namespace Discord.Commands { // Override public name of command/module - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Parameter)] + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)] public class NameAttribute : Attribute { public string Text { get; } diff --git a/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs b/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs index 37f685c95..44ab6d214 100644 --- a/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs @@ -4,7 +4,7 @@ using System.Reflection; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Parameter)] + [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)] public class OverrideTypeReaderAttribute : Attribute { private static readonly TypeInfo _typeReaderTypeInfo = typeof(TypeReader).GetTypeInfo(); @@ -19,4 +19,4 @@ namespace Discord.Commands TypeReader = overridenTypeReader; } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs index 209822583..3c5e8cf92 100644 --- a/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs @@ -1,6 +1,5 @@ using System; using System.Threading.Tasks; -using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -9,4 +8,4 @@ namespace Discord.Commands { public abstract Task CheckPermissionsAsync(ICommandContext context, ParameterInfo parameter, object value, IServiceProvider services); } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs index 5fa0fb1b9..90af035e4 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -15,7 +15,7 @@ namespace Discord.Commands /// /// Require that the command be invoked in a specified context. /// - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public class RequireContextAttribute : PreconditionAttribute { public ContextType Contexts { get; } diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs index c8e3bfa82..273c764bd 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Threading.Tasks; namespace Discord.Commands @@ -6,7 +6,7 @@ namespace Discord.Commands /// /// Require that the command is invoked in a channel marked NSFW /// - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public class RequireNsfwAttribute : PreconditionAttribute { public override Task CheckPermissionsAsync(ICommandContext context, CommandInfo command, IServiceProvider services) diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs index e370aeec4..93e3cbe18 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs @@ -1,7 +1,5 @@ -#pragma warning disable CS0618 using System; using System.Threading.Tasks; -using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -9,7 +7,7 @@ namespace Discord.Commands /// Require that the command is invoked by the owner of the bot. /// /// This precondition will only work if the bot is a bot account. - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public class RequireOwnerAttribute : PreconditionAttribute { public override async Task CheckPermissionsAsync(ICommandContext context, CommandInfo command, IServiceProvider services) @@ -21,10 +19,6 @@ namespace Discord.Commands if (context.User.Id != application.Owner.Id) return PreconditionResult.FromError("Command can only be run by the owner of the bot"); return PreconditionResult.FromSuccess(); - case TokenType.User: - if (context.User.Id != context.Client.CurrentUser.Id) - return PreconditionResult.FromError("Command can only be run by the owner of the bot"); - return PreconditionResult.FromSuccess(); default: return PreconditionResult.FromError($"{nameof(RequireOwnerAttribute)} is not supported by this {nameof(TokenType)}."); } diff --git a/src/Discord.Net.Commands/Attributes/PriorityAttribute.cs b/src/Discord.Net.Commands/Attributes/PriorityAttribute.cs index 5120bb7d1..353e96e41 100644 --- a/src/Discord.Net.Commands/Attributes/PriorityAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/PriorityAttribute.cs @@ -3,7 +3,7 @@ using System; namespace Discord.Commands { /// Sets priority of commands - [AttributeUsage(AttributeTargets.Method)] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public class PriorityAttribute : Attribute { /// The priority which has been set for the command diff --git a/src/Discord.Net.Commands/Attributes/RemainderAttribute.cs b/src/Discord.Net.Commands/Attributes/RemainderAttribute.cs index 4aa16bebb..56938f167 100644 --- a/src/Discord.Net.Commands/Attributes/RemainderAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/RemainderAttribute.cs @@ -1,8 +1,8 @@ -using System; +using System; namespace Discord.Commands { - [AttributeUsage(AttributeTargets.Parameter)] + [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)] public class RemainderAttribute : Attribute { } diff --git a/src/Discord.Net.Commands/Attributes/RemarksAttribute.cs b/src/Discord.Net.Commands/Attributes/RemarksAttribute.cs index 44db18a79..c11f790a7 100644 --- a/src/Discord.Net.Commands/Attributes/RemarksAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/RemarksAttribute.cs @@ -1,9 +1,9 @@ -using System; +using System; namespace Discord.Commands { // Extension of the Cosmetic Summary, for Groups, Commands, and Parameters - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class)] + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = false, Inherited = true)] public class RemarksAttribute : Attribute { public string Text { get; } diff --git a/src/Discord.Net.Commands/Attributes/SummaryAttribute.cs b/src/Discord.Net.Commands/Attributes/SummaryAttribute.cs index 46d52f3d9..641163408 100644 --- a/src/Discord.Net.Commands/Attributes/SummaryAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/SummaryAttribute.cs @@ -1,9 +1,9 @@ -using System; +using System; namespace Discord.Commands { // Cosmetic Summary, for Groups and Commands - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Parameter)] + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)] public class SummaryAttribute : Attribute { public string Text { get; } diff --git a/src/Discord.Net.Commands/Builders/ModuleBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleBuilder.cs index 0a33c9e26..1809c2c63 100644 --- a/src/Discord.Net.Commands/Builders/ModuleBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleBuilder.cs @@ -1,5 +1,6 @@ -using System; +using System; using System.Collections.Generic; +using System.Reflection; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -18,6 +19,7 @@ namespace Discord.Commands.Builders public string Name { get; set; } public string Summary { get; set; } public string Remarks { get; set; } + public string Group { get; set; } public IReadOnlyList Commands => _commands; public IReadOnlyList Modules => _submodules; @@ -25,6 +27,8 @@ namespace Discord.Commands.Builders public IReadOnlyList Attributes => _attributes; public IReadOnlyList Aliases => _aliases; + internal TypeInfo TypeInfo { get; set; } + //Automatic internal ModuleBuilder(CommandService service, ModuleBuilder parent) { @@ -111,17 +115,23 @@ namespace Discord.Commands.Builders return this; } - private ModuleInfo BuildImpl(CommandService service, ModuleInfo parent = null) + private ModuleInfo BuildImpl(CommandService service, IServiceProvider services, ModuleInfo parent = null) { //Default name to first alias if (Name == null) Name = _aliases[0]; - return new ModuleInfo(this, service, parent); + if (TypeInfo != null) + { + var moduleInstance = ReflectionUtils.CreateObject(TypeInfo, service, services); + moduleInstance.OnModuleBuilding(service, this); + } + + return new ModuleInfo(this, service, services, parent); } - public ModuleInfo Build(CommandService service) => BuildImpl(service); + public ModuleInfo Build(CommandService service, IServiceProvider services) => BuildImpl(service, services); - internal ModuleInfo Build(CommandService service, ModuleInfo parent) => BuildImpl(service, parent); + internal ModuleInfo Build(CommandService service, IServiceProvider services, ModuleInfo parent) => BuildImpl(service, services, parent); } } diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 5a3a1f25a..1dd66cc77 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -42,14 +42,13 @@ namespace Discord.Commands } - public static Task> BuildAsync(CommandService service, params TypeInfo[] validTypes) => BuildAsync(validTypes, service); - public static async Task> BuildAsync(IEnumerable validTypes, CommandService service) + public static Task> BuildAsync(CommandService service, IServiceProvider services, params TypeInfo[] validTypes) => BuildAsync(validTypes, service, services); + public static async Task> BuildAsync(IEnumerable validTypes, CommandService service, IServiceProvider services) { /*if (!validTypes.Any()) throw new InvalidOperationException("Could not find any valid modules from the given selection");*/ - var topLevelGroups = validTypes.Where(x => x.DeclaringType == null); - var subGroups = validTypes.Intersect(topLevelGroups); + var topLevelGroups = validTypes.Where(x => x.DeclaringType == null || !IsValidModuleDefinition(x.DeclaringType.GetTypeInfo())); var builtTypes = new List(); @@ -63,11 +62,11 @@ namespace Discord.Commands var module = new ModuleBuilder(service, null); - BuildModule(module, typeInfo, service); - BuildSubTypes(module, typeInfo.DeclaredNestedTypes, builtTypes, service); + BuildModule(module, typeInfo, service, services); + BuildSubTypes(module, typeInfo.DeclaredNestedTypes, builtTypes, service, services); builtTypes.Add(typeInfo); - result[typeInfo.AsType()] = module.Build(service); + result[typeInfo.AsType()] = module.Build(service, services); } await service._cmdLogger.DebugAsync($"Successfully built {builtTypes.Count} modules.").ConfigureAwait(false); @@ -75,7 +74,7 @@ namespace Discord.Commands return result; } - private static void BuildSubTypes(ModuleBuilder builder, IEnumerable subTypes, List builtTypes, CommandService service) + private static void BuildSubTypes(ModuleBuilder builder, IEnumerable subTypes, List builtTypes, CommandService service, IServiceProvider services) { foreach (var typeInfo in subTypes) { @@ -87,17 +86,18 @@ namespace Discord.Commands builder.AddModule((module) => { - BuildModule(module, typeInfo, service); - BuildSubTypes(module, typeInfo.DeclaredNestedTypes, builtTypes, service); + BuildModule(module, typeInfo, service, services); + BuildSubTypes(module, typeInfo.DeclaredNestedTypes, builtTypes, service, services); }); builtTypes.Add(typeInfo); } } - private static void BuildModule(ModuleBuilder builder, TypeInfo typeInfo, CommandService service) + private static void BuildModule(ModuleBuilder builder, TypeInfo typeInfo, CommandService service, IServiceProvider services) { var attributes = typeInfo.GetCustomAttributes(); + builder.TypeInfo = typeInfo; foreach (var attribute in attributes) { @@ -117,6 +117,7 @@ namespace Discord.Commands break; case GroupAttribute group: builder.Name = builder.Name ?? group.Prefix; + builder.Group = group.Prefix; builder.AddAliases(group.Prefix); break; case PreconditionAttribute precondition: @@ -140,12 +141,12 @@ namespace Discord.Commands { builder.AddCommand((command) => { - BuildCommand(command, typeInfo, method, service); + BuildCommand(command, typeInfo, method, service, services); }); } } - private static void BuildCommand(CommandBuilder builder, TypeInfo typeInfo, MethodInfo method, CommandService service) + private static void BuildCommand(CommandBuilder builder, TypeInfo typeInfo, MethodInfo method, CommandService service, IServiceProvider serviceprovider) { var attributes = method.GetCustomAttributes(); @@ -191,7 +192,7 @@ namespace Discord.Commands { builder.AddParameter((parameter) => { - BuildParameter(parameter, paramInfo, pos++, count, service); + BuildParameter(parameter, paramInfo, pos++, count, service, serviceprovider); }); } @@ -227,7 +228,7 @@ namespace Discord.Commands builder.Callback = ExecuteCallback; } - private static void BuildParameter(ParameterBuilder builder, System.Reflection.ParameterInfo paramInfo, int position, int count, CommandService service) + private static void BuildParameter(ParameterBuilder builder, System.Reflection.ParameterInfo paramInfo, int position, int count, CommandService service, IServiceProvider services) { var attributes = paramInfo.GetCustomAttributes(); var paramType = paramInfo.ParameterType; @@ -245,7 +246,7 @@ namespace Discord.Commands builder.Summary = summary.Text; break; case OverrideTypeReaderAttribute typeReader: - builder.TypeReader = GetTypeReader(service, paramType, typeReader.TypeReader); + builder.TypeReader = GetTypeReader(service, paramType, typeReader.TypeReader, services); break; case ParamArrayAttribute _: builder.IsMultiple = true; @@ -273,19 +274,12 @@ namespace Discord.Commands if (builder.TypeReader == null) { - var readers = service.GetTypeReaders(paramType); - TypeReader reader = null; - - if (readers != null) - reader = readers.FirstOrDefault().Value; - else - reader = service.GetDefaultTypeReader(paramType); - - builder.TypeReader = reader; + builder.TypeReader = service.GetDefaultTypeReader(paramType) + ?? service.GetTypeReaders(paramType)?.FirstOrDefault().Value; } } - private static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType) + private static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) { var readers = service.GetTypeReaders(paramType); TypeReader reader = null; @@ -296,8 +290,8 @@ namespace Discord.Commands } //We dont have a cached type reader, create one - reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, EmptyServiceProvider.Instance); - service.AddTypeReader(paramType, reader); + reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); + service.AddTypeReader(paramType, reader, false); return reader; } @@ -305,7 +299,8 @@ namespace Discord.Commands private static bool IsValidModuleDefinition(TypeInfo typeInfo) { return _moduleTypeInfo.IsAssignableFrom(typeInfo) && - !typeInfo.IsAbstract; + !typeInfo.IsAbstract && + !typeInfo.ContainsGenericParameters; } private static bool IsValidCommandDefinition(MethodInfo methodInfo) diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 8e7dab898..c880dd454 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -1,5 +1,3 @@ -using Discord.Commands.Builders; -using Discord.Logging; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -8,6 +6,9 @@ using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Discord.Commands.Builders; +using Discord.Logging; namespace Discord.Commands { @@ -85,7 +86,8 @@ namespace Discord.Commands var builder = new ModuleBuilder(this, null, primaryAlias); buildFunc(builder); - var module = builder.Build(this); + var module = builder.Build(this, null); + return LoadModuleInternal(module); } finally @@ -93,9 +95,18 @@ namespace Discord.Commands _moduleLock.Release(); } } - public Task AddModuleAsync() => AddModuleAsync(typeof(T)); - public async Task AddModuleAsync(Type type) + + /// + /// Add a command module from a type + /// + /// The type of module + /// An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null + /// A built module + public Task AddModuleAsync(IServiceProvider services) => AddModuleAsync(typeof(T), services); + public async Task AddModuleAsync(Type type, IServiceProvider services) { + services = services ?? EmptyServiceProvider.Instance; + await _moduleLock.WaitAsync().ConfigureAwait(false); try { @@ -104,7 +115,7 @@ namespace Discord.Commands if (_typedModuleDefs.ContainsKey(type)) throw new ArgumentException($"This module has already been added."); - var module = (await ModuleClassBuilder.BuildAsync(this, typeInfo).ConfigureAwait(false)).FirstOrDefault(); + var module = (await ModuleClassBuilder.BuildAsync(this, services, typeInfo).ConfigureAwait(false)).FirstOrDefault(); if (module.Value == default(ModuleInfo)) throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?"); @@ -118,13 +129,21 @@ namespace Discord.Commands _moduleLock.Release(); } } - public async Task> AddModulesAsync(Assembly assembly) + /// + /// Add command modules from an assembly + /// + /// The assembly containing command modules + /// An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null + /// A collection of built modules + public async Task> AddModulesAsync(Assembly assembly, IServiceProvider services) { + services = services ?? EmptyServiceProvider.Instance; + await _moduleLock.WaitAsync().ConfigureAwait(false); try { var types = await ModuleClassBuilder.SearchAsync(assembly, this).ConfigureAwait(false); - var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this).ConfigureAwait(false); + var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services).ConfigureAwait(false); foreach (var info in moduleDefs) { @@ -196,10 +215,11 @@ namespace Discord.Commands return true; } - //Type Readers + //Type Readers /// /// Adds a custom to this for the supplied object type. /// If is a , a will also be added. + /// If a default exists for , a warning will be logged and the default will be replaced. /// /// The object type to be read by the . /// An instance of the to be added. @@ -207,24 +227,61 @@ namespace Discord.Commands => AddTypeReader(typeof(T), reader); /// /// Adds a custom to this for the supplied object type. - /// If is a , a for the value type will also be added. + /// If is a , a for the value type will also be added. + /// If a default exists for , a warning will be logged and the default will be replaced. /// /// A instance for the type to be read. /// An instance of the to be added. public void AddTypeReader(Type type, TypeReader reader) { - var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); - readers[reader.GetType()] = reader; + if (_defaultTypeReaders.ContainsKey(type)) + _ = _cmdLogger.WarningAsync($"The default TypeReader for {type.FullName} was replaced by {reader.GetType().FullName}." + + $"To suppress this message, use AddTypeReader(reader, true)."); + AddTypeReader(type, reader, true); + } + /// + /// Adds a custom to this for the supplied object type. + /// If is a , a will also be added. + /// + /// The object type to be read by the . + /// An instance of the to be added. + /// If should replace the default for if one exists. + public void AddTypeReader(TypeReader reader, bool replaceDefault) + => AddTypeReader(typeof(T), reader, replaceDefault); + /// + /// Adds a custom to this for the supplied object type. + /// If is a , a for the value type will also be added. + /// + /// A instance for the type to be read. + /// An instance of the to be added. + /// If should replace the default for if one exists. + public void AddTypeReader(Type type, TypeReader reader, bool replaceDefault) + { + if (replaceDefault && _defaultTypeReaders.ContainsKey(type)) + { + _defaultTypeReaders.AddOrUpdate(type, reader, (k, v) => reader); + if (type.GetTypeInfo().IsValueType) + { + var nullableType = typeof(Nullable<>).MakeGenericType(type); + var nullableReader = NullableTypeReader.Create(type, reader); + _defaultTypeReaders.AddOrUpdate(nullableType, nullableReader, (k, v) => nullableReader); + } + } + else + { + var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); + readers[reader.GetType()] = reader; - if (type.GetTypeInfo().IsValueType) - AddNullableTypeReader(type, reader); + if (type.GetTypeInfo().IsValueType) + AddNullableTypeReader(type, reader); + } } internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader) { var readers = _typeReaders.GetOrAdd(typeof(Nullable<>).MakeGenericType(valueType), x => new ConcurrentDictionary()); var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); readers[nullableReader.GetType()] = nullableReader; - } + } internal IDictionary GetTypeReaders(Type type) { if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) @@ -272,9 +329,9 @@ namespace Discord.Commands return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); } - public Task ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + public Task ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) => ExecuteAsync(context, context.Message.Content.Substring(argPos), services, multiMatchHandling); - public async Task ExecuteAsync(ICommandContext context, string input, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + public async Task ExecuteAsync(ICommandContext context, string input, IServiceProvider services, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) { services = services ?? EmptyServiceProvider.Instance; @@ -330,7 +387,7 @@ namespace Discord.Commands float CalculateScore(CommandMatch match, ParseResult parseResult) { float argValuesScore = 0, paramValuesScore = 0; - + if (match.Command.Parameters.Count > 0) { var argValuesSum = parseResult.ArgValues?.Sum(x => x.Values.OrderByDescending(y => y.Score).FirstOrDefault().Score) ?? 0; diff --git a/src/Discord.Net.Commands/CommandServiceConfig.cs b/src/Discord.Net.Commands/CommandServiceConfig.cs index 7fdbe368b..77c5b2262 100644 --- a/src/Discord.Net.Commands/CommandServiceConfig.cs +++ b/src/Discord.Net.Commands/CommandServiceConfig.cs @@ -1,4 +1,6 @@ -namespace Discord.Commands +using System; + +namespace Discord.Commands { public class CommandServiceConfig { @@ -18,5 +20,11 @@ /// Determines whether extra parameters should be ignored. public bool IgnoreExtraArgs { get; set; } = false; + + ///// Gets or sets the to use. + //public IServiceProvider ServiceProvider { get; set; } = null; + + ///// Gets or sets a factory function for the to use. + //public Func ServiceProviderFactory { get; set; } = null; } } diff --git a/src/Discord.Net.Commands/IModuleBase.cs b/src/Discord.Net.Commands/IModuleBase.cs index 479724ae3..3b641ec5f 100644 --- a/src/Discord.Net.Commands/IModuleBase.cs +++ b/src/Discord.Net.Commands/IModuleBase.cs @@ -1,4 +1,6 @@ -namespace Discord.Commands +using Discord.Commands.Builders; + +namespace Discord.Commands { internal interface IModuleBase { @@ -7,5 +9,7 @@ void BeforeExecute(CommandInfo command); void AfterExecute(CommandInfo command); + + void OnModuleBuilding(CommandService commandService, ModuleBuilder builder); } } diff --git a/src/Discord.Net.Commands/Info/CommandInfo.cs b/src/Discord.Net.Commands/Info/CommandInfo.cs index f0d406e8d..6e74c8abc 100644 --- a/src/Discord.Net.Commands/Info/CommandInfo.cs +++ b/src/Discord.Net.Commands/Info/CommandInfo.cs @@ -165,11 +165,11 @@ namespace Discord.Commands switch (RunMode) { case RunMode.Sync: //Always sync - return await ExecuteAsyncInternalAsync(context, args, services).ConfigureAwait(false); + return await ExecuteInternalAsync(context, args, services).ConfigureAwait(false); case RunMode.Async: //Always async var t2 = Task.Run(async () => { - await ExecuteAsyncInternalAsync(context, args, services).ConfigureAwait(false); + await ExecuteInternalAsync(context, args, services).ConfigureAwait(false); }); break; } @@ -181,7 +181,7 @@ namespace Discord.Commands } } - private async Task ExecuteAsyncInternalAsync(ICommandContext context, object[] args, IServiceProvider services) + private async Task ExecuteInternalAsync(ICommandContext context, object[] args, IServiceProvider services) { await Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false); try diff --git a/src/Discord.Net.Commands/Info/ModuleInfo.cs b/src/Discord.Net.Commands/Info/ModuleInfo.cs index 97b90bf4e..5a7f9208e 100644 --- a/src/Discord.Net.Commands/Info/ModuleInfo.cs +++ b/src/Discord.Net.Commands/Info/ModuleInfo.cs @@ -2,7 +2,7 @@ using System; using System.Linq; using System.Collections.Generic; using System.Collections.Immutable; - +using System.Reflection; using Discord.Commands.Builders; namespace Discord.Commands @@ -13,6 +13,7 @@ namespace Discord.Commands public string Name { get; } public string Summary { get; } public string Remarks { get; } + public string Group { get; } public IReadOnlyList Aliases { get; } public IReadOnlyList Commands { get; } @@ -22,21 +23,26 @@ namespace Discord.Commands public ModuleInfo Parent { get; } public bool IsSubmodule => Parent != null; - internal ModuleInfo(ModuleBuilder builder, CommandService service, ModuleInfo parent = null) + //public TypeInfo TypeInfo { get; } + + internal ModuleInfo(ModuleBuilder builder, CommandService service, IServiceProvider services, ModuleInfo parent = null) { Service = service; Name = builder.Name; Summary = builder.Summary; Remarks = builder.Remarks; + Group = builder.Group; Parent = parent; + //TypeInfo = builder.TypeInfo; + Aliases = BuildAliases(builder, service).ToImmutableArray(); Commands = builder.Commands.Select(x => x.Build(this, service)).ToImmutableArray(); Preconditions = BuildPreconditions(builder).ToImmutableArray(); Attributes = BuildAttributes(builder).ToImmutableArray(); - Submodules = BuildSubmodules(builder, service).ToImmutableArray(); + Submodules = BuildSubmodules(builder, service, services).ToImmutableArray(); } private static IEnumerable BuildAliases(ModuleBuilder builder, CommandService service) @@ -66,12 +72,12 @@ namespace Discord.Commands return result; } - private List BuildSubmodules(ModuleBuilder parent, CommandService service) + private List BuildSubmodules(ModuleBuilder parent, CommandService service, IServiceProvider services) { var result = new List(); foreach (var submodule in parent.Modules) - result.Add(submodule.Build(service, this)); + result.Add(submodule.Build(service, services, this)); return result; } diff --git a/src/Discord.Net.Commands/ModuleBase.cs b/src/Discord.Net.Commands/ModuleBase.cs index f51656e40..3e6fbbd9b 100644 --- a/src/Discord.Net.Commands/ModuleBase.cs +++ b/src/Discord.Net.Commands/ModuleBase.cs @@ -1,5 +1,6 @@ -using System; +using System; using System.Threading.Tasks; +using Discord.Commands.Builders; namespace Discord.Commands { @@ -10,7 +11,13 @@ namespace Discord.Commands { public T Context { get; private set; } - protected virtual async Task ReplyAsync(string message, bool isTTS = false, Embed embed = null, RequestOptions options = null) + /// + /// Sends a message to the source channel + /// + /// Contents of the message; optional only if is specified + /// Specifies if Discord should read this message aloud using TTS + /// An embed to be displayed alongside the message + protected virtual async Task ReplyAsync(string message = null, bool isTTS = false, Embed embed = null, RequestOptions options = null) { return await Context.Channel.SendMessageAsync(message, isTTS, embed, options).ConfigureAwait(false); } @@ -23,15 +30,18 @@ namespace Discord.Commands { } + protected virtual void OnModuleBuilding(CommandService commandService, ModuleBuilder builder) + { + } + //IModuleBase void IModuleBase.SetContext(ICommandContext context) { var newValue = context as T; Context = newValue ?? throw new InvalidOperationException($"Invalid context type. Expected {typeof(T).Name}, got {context.GetType().Name}"); } - void IModuleBase.BeforeExecute(CommandInfo command) => BeforeExecute(command); - void IModuleBase.AfterExecute(CommandInfo command) => AfterExecute(command); + void IModuleBase.OnModuleBuilding(CommandService commandService, ModuleBuilder builder) => OnModuleBuilding(commandService, builder); } } diff --git a/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs b/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs index ab88f66ae..30dd7c36b 100644 --- a/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs +++ b/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Reflection; diff --git a/src/Discord.Net.Core/CDN.cs b/src/Discord.Net.Core/CDN.cs index 070b965ee..52b9a2878 100644 --- a/src/Discord.Net.Core/CDN.cs +++ b/src/Discord.Net.Core/CDN.cs @@ -1,4 +1,4 @@ -using System; +using System; namespace Discord { @@ -13,6 +13,10 @@ namespace Discord string extension = FormatToExtension(format, avatarId); return $"{DiscordConfig.CDNUrl}avatars/{userId}/{avatarId}.{extension}?size={size}"; } + public static string GetDefaultUserAvatarUrl(ushort discriminator) + { + return $"{DiscordConfig.CDNUrl}embed/avatars/{discriminator % 5}.png"; + } public static string GetGuildIconUrl(ulong guildId, string iconId) => iconId != null ? $"{DiscordConfig.CDNUrl}icons/{guildId}/{iconId}.jpg" : null; public static string GetGuildSplashUrl(ulong guildId, string splashId) @@ -28,17 +32,25 @@ namespace Discord return $"{DiscordConfig.CDNUrl}app-assets/{appId}/{assetId}.{extension}?size={size}"; } + public static string GetSpotifyAlbumArtUrl(string albumArtId) + => $"https://i.scdn.co/image/{albumArtId}"; + private static string FormatToExtension(ImageFormat format, string imageId) { if (format == ImageFormat.Auto) format = imageId.StartsWith("a_") ? ImageFormat.Gif : ImageFormat.Png; switch (format) { - case ImageFormat.Gif: return "gif"; - case ImageFormat.Jpeg: return "jpeg"; - case ImageFormat.Png: return "png"; - case ImageFormat.WebP: return "webp"; - default: throw new ArgumentException(nameof(format)); + case ImageFormat.Gif: + return "gif"; + case ImageFormat.Jpeg: + return "jpeg"; + case ImageFormat.Png: + return "png"; + case ImageFormat.WebP: + return "webp"; + default: + throw new ArgumentException(nameof(format)); } } } diff --git a/src/Discord.Net.Core/Entities/Activities/ActivityType.cs b/src/Discord.Net.Core/Entities/Activities/ActivityType.cs new file mode 100644 index 000000000..c7db7b247 --- /dev/null +++ b/src/Discord.Net.Core/Entities/Activities/ActivityType.cs @@ -0,0 +1,10 @@ +namespace Discord +{ + public enum ActivityType + { + Playing = 0, + Streaming = 1, + Listening = 2, + Watching = 3 + } +} diff --git a/src/Discord.Net.Core/Entities/Activities/Game.cs b/src/Discord.Net.Core/Entities/Activities/Game.cs index f2b7e8eb6..179ad4eaa 100644 --- a/src/Discord.Net.Core/Entities/Activities/Game.cs +++ b/src/Discord.Net.Core/Entities/Activities/Game.cs @@ -1,4 +1,4 @@ -using System.Diagnostics; +using System.Diagnostics; namespace Discord { @@ -6,11 +6,13 @@ namespace Discord public class Game : IActivity { public string Name { get; internal set; } + public ActivityType Type { get; internal set; } internal Game() { } - public Game(string name) + public Game(string name, ActivityType type = ActivityType.Playing) { Name = name; + Type = type; } public override string ToString() => Name; diff --git a/src/Discord.Net.Core/Entities/Activities/GameAsset.cs b/src/Discord.Net.Core/Entities/Activities/GameAsset.cs index 385f37214..02c29ba41 100644 --- a/src/Discord.Net.Core/Entities/Activities/GameAsset.cs +++ b/src/Discord.Net.Core/Entities/Activities/GameAsset.cs @@ -1,15 +1,15 @@ -namespace Discord +namespace Discord { public class GameAsset { internal GameAsset() { } - internal ulong ApplicationId { get; set; } + internal ulong? ApplicationId { get; set; } public string Text { get; internal set; } public string ImageId { get; internal set; } public string GetImageUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128) - => CDN.GetRichAssetUrl(ApplicationId, ImageId, size, format); + => ApplicationId.HasValue ? CDN.GetRichAssetUrl(ApplicationId.Value, ImageId, size, format) : null; } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Core/Entities/Activities/GameParty.cs b/src/Discord.Net.Core/Entities/Activities/GameParty.cs index dbfe5b6ce..54e6deef4 100644 --- a/src/Discord.Net.Core/Entities/Activities/GameParty.cs +++ b/src/Discord.Net.Core/Entities/Activities/GameParty.cs @@ -1,11 +1,11 @@ -namespace Discord +namespace Discord { public class GameParty { internal GameParty() { } public string Id { get; internal set; } - public int Members { get; internal set; } - public int Capacity { get; internal set; } + public long Members { get; internal set; } + public long Capacity { get; internal set; } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Core/Entities/Activities/IActivity.cs b/src/Discord.Net.Core/Entities/Activities/IActivity.cs index 0dcf34273..1f158217d 100644 --- a/src/Discord.Net.Core/Entities/Activities/IActivity.cs +++ b/src/Discord.Net.Core/Entities/Activities/IActivity.cs @@ -1,13 +1,8 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Discord +namespace Discord { public interface IActivity { string Name { get; } + ActivityType Type { get; } } } diff --git a/src/Discord.Net.Core/Entities/Activities/RichGame.cs b/src/Discord.Net.Core/Entities/Activities/RichGame.cs index e66eac1d2..fc3f68cf0 100644 --- a/src/Discord.Net.Core/Entities/Activities/RichGame.cs +++ b/src/Discord.Net.Core/Entities/Activities/RichGame.cs @@ -1,4 +1,4 @@ -using System.Diagnostics; +using System.Diagnostics; namespace Discord { @@ -7,8 +7,8 @@ namespace Discord { internal RichGame() { } - public string Details { get; internal set;} - public string State { get; internal set;} + public string Details { get; internal set; } + public string State { get; internal set; } public ulong ApplicationId { get; internal set; } public GameAsset SmallAsset { get; internal set; } public GameAsset LargeAsset { get; internal set; } @@ -19,4 +19,4 @@ namespace Discord public override string ToString() => Name; private string DebuggerDisplay => $"{Name} (Rich)"; } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs b/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs new file mode 100644 index 000000000..a20384242 --- /dev/null +++ b/src/Discord.Net.Core/Entities/Activities/SpotifyGame.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Discord +{ + [DebuggerDisplay(@"{DebuggerDisplay,nq}")] + public class SpotifyGame : Game + { + public IEnumerable Artists { get; internal set; } + public string AlbumArt { get; internal set; } + public string AlbumTitle { get; internal set; } + public string TrackTitle { get; internal set; } + public string SyncId { get; internal set; } + public string SessionId { get; internal set; } + public TimeSpan? Duration { get; internal set; } + + internal SpotifyGame() { } + + public override string ToString() => Name; + private string DebuggerDisplay => $"{Name} (Spotify)"; + } +} diff --git a/src/Discord.Net.Core/Entities/Activities/StreamingGame.cs b/src/Discord.Net.Core/Entities/Activities/StreamingGame.cs index 140024272..afbc24cd9 100644 --- a/src/Discord.Net.Core/Entities/Activities/StreamingGame.cs +++ b/src/Discord.Net.Core/Entities/Activities/StreamingGame.cs @@ -6,15 +6,14 @@ namespace Discord public class StreamingGame : Game { public string Url { get; internal set; } - public StreamType StreamType { get; internal set; } - public StreamingGame(string name, string url, StreamType streamType) + public StreamingGame(string name, string url) { Name = name; Url = url; - StreamType = streamType; + Type = ActivityType.Streaming; } - + public override string ToString() => Name; private string DebuggerDisplay => $"{Name} ({Url})"; } diff --git a/src/Discord.Net.Core/Entities/Channels/IMessageChannel.cs b/src/Discord.Net.Core/Entities/Channels/IMessageChannel.cs index a1df5b4c7..5a6e5df59 100644 --- a/src/Discord.Net.Core/Entities/Channels/IMessageChannel.cs +++ b/src/Discord.Net.Core/Entities/Channels/IMessageChannel.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; @@ -11,10 +11,10 @@ namespace Discord Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null); #if FILESYSTEM /// Sends a file to this text channel, with an optional caption. - Task SendFileAsync(string filePath, string text = null, bool isTTS = false, RequestOptions options = null); + Task SendFileAsync(string filePath, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); #endif /// Sends a file to this text channel, with an optional caption. - Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, RequestOptions options = null); + Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); /// Gets a message from this message channel with the given id, or null if not found. Task GetMessageAsync(ulong id, CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null); diff --git a/src/Discord.Net.Rest/Entities/Messages/EmbedBuilder.cs b/src/Discord.Net.Core/Entities/Messages/EmbedBuilder.cs similarity index 80% rename from src/Discord.Net.Rest/Entities/Messages/EmbedBuilder.cs rename to src/Discord.Net.Core/Entities/Messages/EmbedBuilder.cs index f5663cea3..62834ebf3 100644 --- a/src/Discord.Net.Rest/Entities/Messages/EmbedBuilder.cs +++ b/src/Discord.Net.Core/Entities/Messages/EmbedBuilder.cs @@ -1,89 +1,105 @@ -using System; +using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; namespace Discord { public class EmbedBuilder { - private readonly Embed _embed; + private string _title; + private string _description; + private string _url; + private EmbedImage? _image; + private EmbedThumbnail? _thumbnail; + private List _fields; public const int MaxFieldCount = 25; public const int MaxTitleLength = 256; public const int MaxDescriptionLength = 2048; - public const int MaxEmbedLength = 6000; // user bot limit is 2000, but we don't validate that here. + public const int MaxEmbedLength = 6000; public EmbedBuilder() { - _embed = new Embed(EmbedType.Rich); Fields = new List(); } public string Title { - get => _embed.Title; + get => _title; set { if (value?.Length > MaxTitleLength) throw new ArgumentException($"Title length must be less than or equal to {MaxTitleLength}.", nameof(Title)); - _embed.Title = value; + _title = value; } } - public string Description { - get => _embed.Description; + get => _description; set { if (value?.Length > MaxDescriptionLength) throw new ArgumentException($"Description length must be less than or equal to {MaxDescriptionLength}.", nameof(Description)); - _embed.Description = value; + _description = value; } } public string Url { - get => _embed.Url; + get => _url; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(Url)); - _embed.Url = value; + _url = value; } } public string ThumbnailUrl { - get => _embed.Thumbnail?.Url; + get => _thumbnail?.Url; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(ThumbnailUrl)); - _embed.Thumbnail = new EmbedThumbnail(value, null, null, null); + _thumbnail = new EmbedThumbnail(value, null, null, null); } } public string ImageUrl { - get => _embed.Image?.Url; + get => _image?.Url; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(ImageUrl)); - _embed.Image = new EmbedImage(value, null, null, null); + _image = new EmbedImage(value, null, null, null); } } - public DateTimeOffset? Timestamp { get => _embed.Timestamp; set { _embed.Timestamp = value; } } - public Color? Color { get => _embed.Color; set { _embed.Color = value; } } - - public EmbedAuthorBuilder Author { get; set; } - public EmbedFooterBuilder Footer { get; set; } - private List _fields; public List Fields { get => _fields; set { - if (value == null) throw new ArgumentNullException("Cannot set an embed builder's fields collection to null", nameof(Fields)); if (value.Count > MaxFieldCount) throw new ArgumentException($"Field count must be less than or equal to {MaxFieldCount}.", nameof(Fields)); _fields = value; } } + public DateTimeOffset? Timestamp { get; set; } + public Color? Color { get; set; } + public EmbedAuthorBuilder Author { get; set; } + public EmbedFooterBuilder Footer { get; set; } + + public int Length + { + get + { + int titleLength = Title?.Length ?? 0; + int authorLength = Author?.Name?.Length ?? 0; + int descriptionLength = Description?.Length ?? 0; + int footerLength = Footer?.Text?.Length ?? 0; + int fieldSum = Fields.Sum(f => f.Name.Length + f.Value.ToString().Length); + + return titleLength + authorLength + descriptionLength + footerLength + fieldSum; + } + } + public EmbedBuilder WithTitle(string title) { Title = title; @@ -180,7 +196,6 @@ namespace Discord AddField(field); return this; } - public EmbedBuilder AddField(EmbedFieldBuilder field) { if (Fields.Count >= MaxFieldCount) @@ -195,63 +210,54 @@ namespace Discord { var field = new EmbedFieldBuilder(); action(field); - this.AddField(field); + AddField(field); return this; } public Embed Build() { - _embed.Footer = Footer?.Build(); - _embed.Author = Author?.Build(); + if (Length > MaxEmbedLength) + throw new InvalidOperationException($"Total embed length must be less than or equal to {MaxEmbedLength}"); + var fields = ImmutableArray.CreateBuilder(Fields.Count); for (int i = 0; i < Fields.Count; i++) fields.Add(Fields[i].Build()); - _embed.Fields = fields.ToImmutable(); - if (_embed.Length > MaxEmbedLength) - { - throw new InvalidOperationException($"Total embed length must be less than or equal to {MaxEmbedLength}"); - } - - return _embed; + return new Embed(EmbedType.Rich, Title, Description, Url, Timestamp, Color, _image, null, Author?.Build(), Footer?.Build(), null, _thumbnail, fields.ToImmutable()); } } public class EmbedFieldBuilder { + private string _name; + private string _value; private EmbedField _field; - public const int MaxFieldNameLength = 256; public const int MaxFieldValueLength = 1024; public string Name { - get => _field.Name; + get => _name; set { if (string.IsNullOrWhiteSpace(value)) throw new ArgumentException($"Field name must not be null, empty or entirely whitespace.", nameof(Name)); if (value.Length > MaxFieldNameLength) throw new ArgumentException($"Field name length must be less than or equal to {MaxFieldNameLength}.", nameof(Name)); - _field.Name = value; + _name = value; } } public object Value { - get => _field.Value; + get => _value; set { var stringValue = value?.ToString(); if (string.IsNullOrEmpty(stringValue)) throw new ArgumentException($"Field value must not be null or empty.", nameof(Value)); if (stringValue.Length > MaxFieldValueLength) throw new ArgumentException($"Field value length must be less than or equal to {MaxFieldValueLength}.", nameof(Value)); - _field.Value = stringValue; + _value = stringValue; } } - public bool IsInline { get => _field.Inline; set { _field.Inline = value; } } - - public EmbedFieldBuilder() - { - _field = new EmbedField(); - } + public bool IsInline { get; set; } public EmbedFieldBuilder WithName(string name) { @@ -270,48 +276,44 @@ namespace Discord } public EmbedField Build() - => _field; + => new EmbedField(Name, Value.ToString(), IsInline); } public class EmbedAuthorBuilder { - private EmbedAuthor _author; - + private string _name; + private string _url; + private string _iconUrl; public const int MaxAuthorNameLength = 256; public string Name { - get => _author.Name; + get => _name; set { if (value?.Length > MaxAuthorNameLength) throw new ArgumentException($"Author name length must be less than or equal to {MaxAuthorNameLength}.", nameof(Name)); - _author.Name = value; + _name = value; } } public string Url { - get => _author.Url; + get => _url; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(Url)); - _author.Url = value; + _url = value; } } public string IconUrl { - get => _author.IconUrl; + get => _iconUrl; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(IconUrl)); - _author.IconUrl = value; + _iconUrl = value; } } - public EmbedAuthorBuilder() - { - _author = new EmbedAuthor(); - } - public EmbedAuthorBuilder WithName(string name) { Name = name; @@ -329,39 +331,35 @@ namespace Discord } public EmbedAuthor Build() - => _author; + => new EmbedAuthor(Name, Url, IconUrl, null); } public class EmbedFooterBuilder { - private EmbedFooter _footer; + private string _text; + private string _iconUrl; public const int MaxFooterTextLength = 2048; public string Text { - get => _footer.Text; + get => _text; set { if (value?.Length > MaxFooterTextLength) throw new ArgumentException($"Footer text length must be less than or equal to {MaxFooterTextLength}.", nameof(Text)); - _footer.Text = value; + _text = value; } } public string IconUrl { - get => _footer.IconUrl; + get => _iconUrl; set { if (!value.IsNullOrUri()) throw new ArgumentException("Url must be a well-formed URI", nameof(IconUrl)); - _footer.IconUrl = value; + _iconUrl = value; } } - public EmbedFooterBuilder() - { - _footer = new EmbedFooter(); - } - public EmbedFooterBuilder WithText(string text) { Text = text; @@ -374,6 +372,6 @@ namespace Discord } public EmbedFooter Build() - => _footer; + => new EmbedFooter(Text, IconUrl, null); } } diff --git a/src/Discord.Net.Core/Entities/Messages/EmbedType.cs b/src/Discord.Net.Core/Entities/Messages/EmbedType.cs index 469e968a5..5bb2653e2 100644 --- a/src/Discord.Net.Core/Entities/Messages/EmbedType.cs +++ b/src/Discord.Net.Core/Entities/Messages/EmbedType.cs @@ -1,13 +1,15 @@ -namespace Discord +namespace Discord { public enum EmbedType { + Unknown = -1, Rich, Link, Video, Image, Gifv, Article, - Tweet + Tweet, + Html, } } diff --git a/src/Discord.Net.Core/Entities/Permissions/ChannelPermissions.cs b/src/Discord.Net.Core/Entities/Permissions/ChannelPermissions.cs index 1a8aad53c..ef10ee106 100644 --- a/src/Discord.Net.Core/Entities/Permissions/ChannelPermissions.cs +++ b/src/Discord.Net.Core/Entities/Permissions/ChannelPermissions.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Diagnostics; @@ -13,6 +13,8 @@ namespace Discord public static readonly ChannelPermissions Text = new ChannelPermissions(0b01100_0000000_1111111110001_010001); /// Gets a ChannelPermissions that grants all permissions for voice channels. public static readonly ChannelPermissions Voice = new ChannelPermissions(0b00100_1111110_0000000000000_010001); + /// Gets a ChannelPermissions that grants all permissions for category channels. + public static readonly ChannelPermissions Category = new ChannelPermissions(0b01100_1111110_1111111110001_010001); /// Gets a ChannelPermissions that grants all permissions for direct message channels. public static readonly ChannelPermissions DM = new ChannelPermissions(0b00000_1000110_1011100110000_000000); /// Gets a ChannelPermissions that grants all permissions for group channels. @@ -24,6 +26,7 @@ namespace Discord { case ITextChannel _: return Text; case IVoiceChannel _: return Voice; + case ICategoryChannel _: return Category; case IDMChannel _: return DM; case IGroupChannel _: return Group; default: throw new ArgumentException("Unknown channel type", nameof(channel)); @@ -157,4 +160,4 @@ namespace Discord public override string ToString() => RawValue.ToString(); private string DebuggerDisplay => $"{string.Join(", ", ToList())}"; } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Core/Entities/Users/IUser.cs b/src/Discord.Net.Core/Entities/Users/IUser.cs index e3f270f6f..c5cce7a25 100644 --- a/src/Discord.Net.Core/Entities/Users/IUser.cs +++ b/src/Discord.Net.Core/Entities/Users/IUser.cs @@ -8,6 +8,8 @@ namespace Discord string AvatarId { get; } /// Gets the url to this user's avatar. string GetAvatarUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128); + /// Gets the url to this user's default avatar. + string GetDefaultAvatarUrl(); /// Gets the per-username unique id for this user. string Discriminator { get; } /// Gets the per-username unique id for this user. diff --git a/src/Discord.Net.Core/Entities/Users/StreamType.cs b/src/Discord.Net.Core/Entities/Users/StreamType.cs deleted file mode 100644 index 7622e3d6e..000000000 --- a/src/Discord.Net.Core/Entities/Users/StreamType.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Discord -{ - public enum StreamType - { - NotStreaming = 0, - Twitch = 1 - } -} diff --git a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs index 345154f1d..dd16d2943 100644 --- a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs +++ b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs @@ -1,6 +1,5 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; namespace Discord @@ -20,45 +19,7 @@ namespace Discord public static IAsyncEnumerable Flatten(this IAsyncEnumerable> source) { - return new PagedCollectionEnumerator(source); - } - - internal class PagedCollectionEnumerator : IAsyncEnumerator, IAsyncEnumerable - { - readonly IAsyncEnumerator> _source; - IEnumerator _enumerator; - - public IAsyncEnumerator GetEnumerator() => this; - - internal PagedCollectionEnumerator(IAsyncEnumerable> source) - { - _source = source.GetEnumerator(); - } - - public T Current => _enumerator.Current; - - public void Dispose() - { - _enumerator?.Dispose(); - _source.Dispose(); - } - - public async Task MoveNext(CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - if(!_enumerator?.MoveNext() ?? true) - { - if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false)) - return false; - - _enumerator?.Dispose(); - _enumerator = _source.Current.GetEnumerator(); - return _enumerator.MoveNext(); - } - - return true; - } + return source.SelectMany(enumerable => enumerable.ToAsyncEnumerable()); } } } diff --git a/src/Discord.Net.Rest/Extensions/EmbedBuilderExtensions.cs b/src/Discord.Net.Core/Extensions/EmbedBuilderExtensions.cs similarity index 100% rename from src/Discord.Net.Rest/Extensions/EmbedBuilderExtensions.cs rename to src/Discord.Net.Core/Extensions/EmbedBuilderExtensions.cs diff --git a/src/Discord.Net.Core/Extensions/UserExtensions.cs b/src/Discord.Net.Core/Extensions/UserExtensions.cs index eac25391e..d3e968e39 100644 --- a/src/Discord.Net.Core/Extensions/UserExtensions.cs +++ b/src/Discord.Net.Core/Extensions/UserExtensions.cs @@ -1,4 +1,4 @@ -using System.Threading.Tasks; +using System.Threading.Tasks; using System.IO; namespace Discord @@ -8,10 +8,10 @@ namespace Discord /// /// Sends a message to the user via DM. /// - public static async Task SendMessageAsync(this IUser user, - string text, + public static async Task SendMessageAsync(this IUser user, + string text, bool isTTS = false, - Embed embed = null, + Embed embed = null, RequestOptions options = null) { return await (await user.GetOrCreateDMChannelAsync().ConfigureAwait(false)).SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); @@ -25,24 +25,29 @@ namespace Discord string filename, string text = null, bool isTTS = false, + Embed embed = null, RequestOptions options = null ) { - return await (await user.GetOrCreateDMChannelAsync().ConfigureAwait(false)).SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + return await (await user.GetOrCreateDMChannelAsync().ConfigureAwait(false)).SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); } #if FILESYSTEM /// /// Sends a file to the user via DM. /// - public static async Task SendFileAsync(this IUser user, - string filePath, - string text = null, - bool isTTS = false, + public static async Task SendFileAsync(this IUser user, + string filePath, + string text = null, + bool isTTS = false, + Embed embed = null, RequestOptions options = null) { - return await (await user.GetOrCreateDMChannelAsync().ConfigureAwait(false)).SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + return await (await user.GetOrCreateDMChannelAsync().ConfigureAwait(false)).SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); } #endif + + public static Task BanAsync(this IGuildUser user, int pruneDays = 0, string reason = null, RequestOptions options = null) + => user.Guild.AddBanAsync(user, pruneDays, reason, options); } } diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs index 9abb959b5..a383c37da 100644 --- a/src/Discord.Net.Core/IDiscordClient.cs +++ b/src/Discord.Net.Core/IDiscordClient.cs @@ -36,5 +36,7 @@ namespace Discord Task GetVoiceRegionAsync(string id, RequestOptions options = null); Task GetWebhookAsync(ulong id, RequestOptions options = null); + + Task GetRecommendedShardCountAsync(RequestOptions options = null); } } diff --git a/src/Discord.Net.Core/Net/HttpException.cs b/src/Discord.Net.Core/Net/HttpException.cs index 1c872245c..d0ee65b23 100644 --- a/src/Discord.Net.Core/Net/HttpException.cs +++ b/src/Discord.Net.Core/Net/HttpException.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Net; namespace Discord.Net @@ -8,11 +8,13 @@ namespace Discord.Net public HttpStatusCode HttpCode { get; } public int? DiscordCode { get; } public string Reason { get; } + public IRequest Request { get; } - public HttpException(HttpStatusCode httpCode, int? discordCode = null, string reason = null) + public HttpException(HttpStatusCode httpCode, IRequest request, int? discordCode = null, string reason = null) : base(CreateMessage(httpCode, discordCode, reason)) { HttpCode = httpCode; + Request = request; DiscordCode = discordCode; Reason = reason; } diff --git a/src/Discord.Net.Core/Net/IRequest.cs b/src/Discord.Net.Core/Net/IRequest.cs new file mode 100644 index 000000000..d3c708dd5 --- /dev/null +++ b/src/Discord.Net.Core/Net/IRequest.cs @@ -0,0 +1,10 @@ +using System; + +namespace Discord.Net +{ + public interface IRequest + { + DateTimeOffset? TimeoutAt { get; } + RequestOptions Options { get; } + } +} diff --git a/src/Discord.Net.Core/Net/RateLimitedException.cs b/src/Discord.Net.Core/Net/RateLimitedException.cs index e8572f911..2d34d7bc2 100644 --- a/src/Discord.Net.Core/Net/RateLimitedException.cs +++ b/src/Discord.Net.Core/Net/RateLimitedException.cs @@ -1,12 +1,15 @@ -using System; +using System; namespace Discord.Net { public class RateLimitedException : TimeoutException { - public RateLimitedException() + public IRequest Request { get; } + + public RateLimitedException(IRequest request) : base("You are being rate limited.") { + Request = request; } } } diff --git a/src/Discord.Net.Core/TokenType.cs b/src/Discord.Net.Core/TokenType.cs index c351b1c19..62181420a 100644 --- a/src/Discord.Net.Core/TokenType.cs +++ b/src/Discord.Net.Core/TokenType.cs @@ -1,10 +1,10 @@ -using System; +using System; namespace Discord { public enum TokenType { - [Obsolete("User logins are being deprecated and may result in a ToS strike against your account - please see https://github.com/RogueException/Discord.Net/issues/827")] + [Obsolete("User logins are deprecated and may result in a ToS strike against your account - please see https://github.com/RogueException/Discord.Net/issues/827", error: true)] User, Bearer, Bot, diff --git a/src/Discord.Net.Core/Utils/Comparers.cs b/src/Discord.Net.Core/Utils/Comparers.cs new file mode 100644 index 000000000..d7641e897 --- /dev/null +++ b/src/Discord.Net.Core/Utils/Comparers.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; + +namespace Discord +{ + public static class DiscordComparers + { + // TODO: simplify with '??=' slated for C# 8.0 + public static IEqualityComparer UserComparer => _userComparer ?? (_userComparer = new EntityEqualityComparer()); + public static IEqualityComparer GuildComparer => _guildComparer ?? (_guildComparer = new EntityEqualityComparer()); + public static IEqualityComparer ChannelComparer => _channelComparer ?? (_channelComparer = new EntityEqualityComparer()); + public static IEqualityComparer RoleComparer => _roleComparer ?? (_roleComparer = new EntityEqualityComparer()); + public static IEqualityComparer MessageComparer => _messageComparer ?? (_messageComparer = new EntityEqualityComparer()); + + private static IEqualityComparer _userComparer; + private static IEqualityComparer _guildComparer; + private static IEqualityComparer _channelComparer; + private static IEqualityComparer _roleComparer; + private static IEqualityComparer _messageComparer; + + private sealed class EntityEqualityComparer : EqualityComparer + where TEntity : IEntity + where TId : IEquatable + { + public override bool Equals(TEntity x, TEntity y) + { + bool xNull = x == null; + bool yNull = y == null; + + if (xNull && yNull) + return true; + + if (xNull ^ yNull) + return false; + + return x.Id.Equals(y.Id); + } + + public override int GetHashCode(TEntity obj) + { + return obj?.Id.GetHashCode() ?? 0; + } + } + } +} diff --git a/src/Discord.Net.Core/Utils/Permissions.cs b/src/Discord.Net.Core/Utils/Permissions.cs index 367926dd1..04e6784c3 100644 --- a/src/Discord.Net.Core/Utils/Permissions.cs +++ b/src/Discord.Net.Core/Utils/Permissions.cs @@ -80,7 +80,7 @@ namespace Discord } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool HasFlag(ulong value, ulong flag) => (value & flag) != 0; + private static bool HasFlag(ulong value, ulong flag) => (value & flag) == flag; [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void SetFlag(ref ulong value, ulong flag) => value |= flag; [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -133,9 +133,10 @@ namespace Discord ulong deniedPermissions = 0UL, allowedPermissions = 0UL; foreach (var roleId in user.RoleIds) { - if (roleId != guild.EveryoneRole.Id) + IRole role = null; + if (roleId != guild.EveryoneRole.Id && (role = guild.GetRole(roleId)) != null) { - perms = channel.GetPermissionOverwrite(guild.GetRole(roleId)); + perms = channel.GetPermissionOverwrite(role); if (perms != null) { allowedPermissions |= perms.Value.AllowValue; @@ -160,10 +161,10 @@ namespace Discord else if (!GetValue(resolvedPermissions, ChannelPermission.SendMessages)) { //No send permissions on a text channel removes all send-related permissions - resolvedPermissions &= ~(1UL << (int)ChannelPermission.SendTTSMessages); - resolvedPermissions &= ~(1UL << (int)ChannelPermission.MentionEveryone); - resolvedPermissions &= ~(1UL << (int)ChannelPermission.EmbedLinks); - resolvedPermissions &= ~(1UL << (int)ChannelPermission.AttachFiles); + resolvedPermissions &= ~(ulong)ChannelPermission.SendTTSMessages; + resolvedPermissions &= ~(ulong)ChannelPermission.MentionEveryone; + resolvedPermissions &= ~(ulong)ChannelPermission.EmbedLinks; + resolvedPermissions &= ~(ulong)ChannelPermission.AttachFiles; } } resolvedPermissions &= mask; //Ensure we didnt get any permissions this channel doesnt support (from guildPerms, for example) @@ -172,4 +173,4 @@ namespace Discord return resolvedPermissions; } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Providers.WS4Net/Discord.Net.Providers.WS4Net.csproj b/src/Discord.Net.Providers.WS4Net/Discord.Net.Providers.WS4Net.csproj index 78987e739..bfd0983ce 100644 --- a/src/Discord.Net.Providers.WS4Net/Discord.Net.Providers.WS4Net.csproj +++ b/src/Discord.Net.Providers.WS4Net/Discord.Net.Providers.WS4Net.csproj @@ -10,6 +10,6 @@ - +
diff --git a/src/Discord.Net.Providers.WS4Net/WS4NetClient.cs b/src/Discord.Net.Providers.WS4Net/WS4NetClient.cs index 93d6a83d6..1894a8906 100644 --- a/src/Discord.Net.Providers.WS4Net/WS4NetClient.cs +++ b/src/Discord.Net.Providers.WS4Net/WS4NetClient.cs @@ -66,7 +66,7 @@ namespace Discord.Net.Providers.WS4Net _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; - _client = new WS4NetSocket(host, customHeaderItems: _headers.ToList()) + _client = new WS4NetSocket(host, "", customHeaderItems: _headers.ToList()) { EnableAutoSendPing = false, NoDelay = true, @@ -163,4 +163,4 @@ namespace Discord.Net.Providers.WS4Net Closed(ex).GetAwaiter().GetResult(); } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Rest/API/Common/Game.cs b/src/Discord.Net.Rest/API/Common/Game.cs index bfb861692..4cde8444a 100644 --- a/src/Discord.Net.Rest/API/Common/Game.cs +++ b/src/Discord.Net.Rest/API/Common/Game.cs @@ -1,4 +1,4 @@ -#pragma warning disable CS1591 +#pragma warning disable CS1591 using Newtonsoft.Json; using Newtonsoft.Json.Serialization; using System.Runtime.Serialization; @@ -12,7 +12,7 @@ namespace Discord.API [JsonProperty("url")] public Optional StreamUrl { get; set; } [JsonProperty("type")] - public Optional StreamType { get; set; } + public Optional Type { get; set; } [JsonProperty("details")] public Optional Details { get; set; } [JsonProperty("state")] @@ -29,6 +29,10 @@ namespace Discord.API public Optional Timestamps { get; set; } [JsonProperty("instance")] public Optional Instance { get; set; } + [JsonProperty("sync_id")] + public Optional SyncId { get; set; } + [JsonProperty("session_id")] + public Optional SessionId { get; set; } [OnError] internal void OnError(StreamingContext context, ErrorContext errorContext) diff --git a/src/Discord.Net.Rest/API/Common/GameAssets.cs b/src/Discord.Net.Rest/API/Common/GameAssets.cs index b5928a8ab..94a540769 100644 --- a/src/Discord.Net.Rest/API/Common/GameAssets.cs +++ b/src/Discord.Net.Rest/API/Common/GameAssets.cs @@ -1,4 +1,4 @@ -using Newtonsoft.Json; +using Newtonsoft.Json; namespace Discord.API { @@ -8,9 +8,9 @@ namespace Discord.API public Optional SmallText { get; set; } [JsonProperty("small_image")] public Optional SmallImage { get; set; } - [JsonProperty("large_image")] - public Optional LargeText { get; set; } [JsonProperty("large_text")] + public Optional LargeText { get; set; } + [JsonProperty("large_image")] public Optional LargeImage { get; set; } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Rest/API/Common/GameParty.cs b/src/Discord.Net.Rest/API/Common/GameParty.cs index e0da4a098..4f8ce2654 100644 --- a/src/Discord.Net.Rest/API/Common/GameParty.cs +++ b/src/Discord.Net.Rest/API/Common/GameParty.cs @@ -1,4 +1,4 @@ -using Newtonsoft.Json; +using Newtonsoft.Json; namespace Discord.API { @@ -7,6 +7,6 @@ namespace Discord.API [JsonProperty("id")] public string Id { get; set; } [JsonProperty("size")] - public int[] Size { get; set; } + public long[] Size { get; set; } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Rest/API/Rest/UploadFileParams.cs b/src/Discord.Net.Rest/API/Rest/UploadFileParams.cs index 30bfc7f9a..9e909b50c 100644 --- a/src/Discord.Net.Rest/API/Rest/UploadFileParams.cs +++ b/src/Discord.Net.Rest/API/Rest/UploadFileParams.cs @@ -1,18 +1,25 @@ -#pragma warning disable CS1591 +#pragma warning disable CS1591 +using Discord.Net.Converters; using Discord.Net.Rest; +using Newtonsoft.Json; using System.Collections.Generic; +using System.Globalization; using System.IO; +using System.Text; namespace Discord.API.Rest { internal class UploadFileParams { + private static JsonSerializer _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; + public Stream File { get; } public Optional Filename { get; set; } public Optional Content { get; set; } public Optional Nonce { get; set; } public Optional IsTTS { get; set; } + public Optional Embed { get; set; } public UploadFileParams(Stream file) { @@ -23,12 +30,24 @@ namespace Discord.API.Rest { var d = new Dictionary(); d["file"] = new MultipartFile(File, Filename.GetValueOrDefault("unknown.dat")); + + var payload = new Dictionary(); if (Content.IsSpecified) - d["content"] = Content.Value; + payload["content"] = Content.Value; if (IsTTS.IsSpecified) - d["tts"] = IsTTS.Value.ToString(); + payload["tts"] = IsTTS.Value.ToString(); if (Nonce.IsSpecified) - d["nonce"] = Nonce.Value; + payload["nonce"] = Nonce.Value; + if (Embed.IsSpecified) + payload["embed"] = Embed.Value; + + var json = new StringBuilder(); + using (var text = new StringWriter(json)) + using (var writer = new JsonTextWriter(text)) + _serializer.Serialize(writer, payload); + + d["payload_json"] = json.ToString(); + return d; } } diff --git a/src/Discord.Net.Rest/AssemblyInfo.cs b/src/Discord.Net.Rest/AssemblyInfo.cs index a4f045ab5..9c90919af 100644 --- a/src/Discord.Net.Rest/AssemblyInfo.cs +++ b/src/Discord.Net.Rest/AssemblyInfo.cs @@ -1,7 +1,22 @@ -using System.Runtime.CompilerServices; +using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("Discord.Net.Rpc")] [assembly: InternalsVisibleTo("Discord.Net.WebSocket")] [assembly: InternalsVisibleTo("Discord.Net.Webhook")] [assembly: InternalsVisibleTo("Discord.Net.Commands")] -[assembly: InternalsVisibleTo("Discord.Net.Tests")] \ No newline at end of file +[assembly: InternalsVisibleTo("Discord.Net.Tests")] + +[assembly: TypeForwardedTo(typeof(Discord.Embed))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedBuilder))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedBuilderExtensions))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedAuthor))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedAuthorBuilder))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedField))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedFieldBuilder))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedFooter))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedFooterBuilder))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedImage))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedProvider))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedThumbnail))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedType))] +[assembly: TypeForwardedTo(typeof(Discord.EmbedVideo))] diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs index 269dedd71..f8642b96c 100644 --- a/src/Discord.Net.Rest/BaseDiscordClient.cs +++ b/src/Discord.Net.Rest/BaseDiscordClient.cs @@ -1,4 +1,4 @@ -using Discord.Logging; +using Discord.Logging; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -125,6 +125,10 @@ namespace Discord.Rest /// public void Dispose() => Dispose(true); + /// + public Task GetRecommendedShardCountAsync(RequestOptions options = null) + => ClientHelper.GetRecommendShardCountAsync(this, options); + //IDiscordClient ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; ISelfUser IDiscordClient.CurrentUser => CurrentUser; diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index 5c9e26433..08305f857 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -1,4 +1,4 @@ -using Discord.API.Rest; +using Discord.API.Rest; using System.Collections.Generic; using System.Collections.Immutable; using System.IO; @@ -163,5 +163,11 @@ namespace Discord.Rest var models = await client.ApiClient.GetVoiceRegionsAsync(options).ConfigureAwait(false); return models.Select(x => RestVoiceRegion.Create(client, x)).FirstOrDefault(x => x.Id == id); } + + public static async Task GetRecommendShardCountAsync(BaseDiscordClient client, RequestOptions options) + { + var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); + return response.Shards; + } } } diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs index 689cba9c3..f0c4358ad 100644 --- a/src/Discord.Net.Rest/DiscordRestApiClient.cs +++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs @@ -1,5 +1,4 @@ #pragma warning disable CS1591 -#pragma warning disable CS0618 using Discord.API.Rest; using Discord.Net; using Discord.Net.Converters; @@ -70,12 +69,12 @@ namespace Discord.API { switch (tokenType) { + case default(TokenType): + return token; case TokenType.Bot: return $"Bot {token}"; case TokenType.Bearer: return $"Bearer {token}"; - case TokenType.User: - return token; default: throw new ArgumentException("Unknown OAuth token type", nameof(tokenType)); } @@ -113,7 +112,6 @@ namespace Discord.API { _loginCancelToken = new CancellationTokenSource(); - AuthTokenType = TokenType.User; AuthToken = null; await RequestQueue.SetCancelTokenAsync(_loginCancelToken.Token).ConfigureAwait(false); RestClient.SetCancelToken(_loginCancelToken.Token); @@ -172,8 +170,7 @@ namespace Discord.API { options = options ?? new RequestOptions(); options.HeaderOnly = true; - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; var request = new RestRequest(RestClient, method, endpoint, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); @@ -187,8 +184,7 @@ namespace Discord.API { options = options ?? new RequestOptions(); options.HeaderOnly = true; - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; string json = payload != null ? SerializeJson(payload) : null; var request = new JsonRestRequest(RestClient, method, endpoint, json, options); @@ -203,8 +199,7 @@ namespace Discord.API { options = options ?? new RequestOptions(); options.HeaderOnly = true; - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; var request = new MultipartRestRequest(RestClient, method, endpoint, multipartArgs, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); @@ -217,8 +212,7 @@ namespace Discord.API string bucketId = null, ClientBucketType clientBucket = ClientBucketType.Unbucketed, RequestOptions options = null) where TResponse : class { options = options ?? new RequestOptions(); - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; var request = new RestRequest(RestClient, method, endpoint, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); @@ -231,8 +225,7 @@ namespace Discord.API string bucketId = null, ClientBucketType clientBucket = ClientBucketType.Unbucketed, RequestOptions options = null) where TResponse : class { options = options ?? new RequestOptions(); - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; string json = payload != null ? SerializeJson(payload) : null; var request = new JsonRestRequest(RestClient, method, endpoint, json, options); @@ -246,8 +239,7 @@ namespace Discord.API string bucketId = null, ClientBucketType clientBucket = ClientBucketType.Unbucketed, RequestOptions options = null) { options = options ?? new RequestOptions(); - options.BucketId = AuthTokenType == TokenType.User ? ClientBucket.Get(clientBucket).Id : bucketId; - options.IsClientBucket = AuthTokenType == TokenType.User; + options.BucketId = bucketId; var request = new MultipartRestRequest(RestClient, method, endpoint, multipartArgs, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); @@ -277,6 +269,18 @@ namespace Discord.API await SendAsync("GET", () => "auth/login", new BucketIds(), options: options).ConfigureAwait(false); } + //Gateway + public async Task GetGatewayAsync(RequestOptions options = null) + { + options = RequestOptions.CreateOrClone(options); + return await SendAsync("GET", () => "gateway", new BucketIds(), options: options).ConfigureAwait(false); + } + public async Task GetBotGatewayAsync(RequestOptions options = null) + { + options = RequestOptions.CreateOrClone(options); + return await SendAsync("GET", () => "gateway/bot", new BucketIds(), options: options).ConfigureAwait(false); + } + //Channels public async Task GetChannelAsync(ulong channelId, RequestOptions options = null) { @@ -466,7 +470,7 @@ namespace Discord.API if (!args.Embed.IsSpecified || args.Embed.Value == null) Preconditions.NotNullOrEmpty(args.Content, nameof(args.Content)); - if (args.Content.Length > DiscordConfig.MaxMessageSize) + if (args.Content?.Length > DiscordConfig.MaxMessageSize) throw new ArgumentException($"Message content is too long, length must be less or equal to {DiscordConfig.MaxMessageSize}.", nameof(args.Content)); options = RequestOptions.CreateOrClone(options); @@ -483,7 +487,7 @@ namespace Discord.API if (!args.Embeds.IsSpecified || args.Embeds.Value == null || args.Embeds.Value.Length == 0) Preconditions.NotNullOrEmpty(args.Content, nameof(args.Content)); - if (args.Content.Length > DiscordConfig.MaxMessageSize) + if (args.Content?.Length > DiscordConfig.MaxMessageSize) throw new ArgumentException($"Message content is too long, length must be less or equal to {DiscordConfig.MaxMessageSize}.", nameof(args.Content)); options = RequestOptions.CreateOrClone(options); @@ -564,7 +568,7 @@ namespace Discord.API { if (!args.Embed.IsSpecified) Preconditions.NotNullOrEmpty(args.Content, nameof(args.Content)); - if (args.Content.Value.Length > DiscordConfig.MaxMessageSize) + if (args.Content.Value?.Length > DiscordConfig.MaxMessageSize) throw new ArgumentOutOfRangeException($"Message content is too long, length must be less or equal to {DiscordConfig.MaxMessageSize}.", nameof(args.Content)); } options = RequestOptions.CreateOrClone(options); diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index 3d90b6c00..8850da3a5 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.Collections.Immutable; using System.IO; using System.Threading.Tasks; diff --git a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs index f4b6c7f23..6784f7f6a 100644 --- a/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs +++ b/src/Discord.Net.Rest/Entities/Channels/ChannelHelper.cs @@ -1,4 +1,4 @@ -using Discord.API.Rest; +using Discord.API.Rest; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -170,17 +170,17 @@ namespace Discord.Rest #if FILESYSTEM public static async Task SendFileAsync(IMessageChannel channel, BaseDiscordClient client, - string filePath, string text, bool isTTS, RequestOptions options) + string filePath, string text, bool isTTS, Embed embed, RequestOptions options) { string filename = Path.GetFileName(filePath); using (var file = File.OpenRead(filePath)) - return await SendFileAsync(channel, client, file, filename, text, isTTS, options).ConfigureAwait(false); + return await SendFileAsync(channel, client, file, filename, text, isTTS, embed, options).ConfigureAwait(false); } #endif public static async Task SendFileAsync(IMessageChannel channel, BaseDiscordClient client, - Stream stream, string filename, string text, bool isTTS, RequestOptions options) + Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) { - var args = new UploadFileParams(stream) { Filename = filename, Content = text, IsTTS = isTTS }; + var args = new UploadFileParams(stream) { Filename = filename, Content = text, IsTTS = isTTS, Embed = embed != null ? embed.ToModel() : Optional.Unspecified }; var model = await client.ApiClient.UploadFileAsync(channel.Id, args, options).ConfigureAwait(false); return RestUserMessage.Create(client, channel, client.CurrentUser, model); } diff --git a/src/Discord.Net.Rest/Entities/Channels/IRestMessageChannel.cs b/src/Discord.Net.Rest/Entities/Channels/IRestMessageChannel.cs index 3a104dd9c..8ab6e9819 100644 --- a/src/Discord.Net.Rest/Entities/Channels/IRestMessageChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/IRestMessageChannel.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.IO; using System.Threading.Tasks; @@ -10,10 +10,10 @@ namespace Discord.Rest new Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null); #if FILESYSTEM /// Sends a file to this text channel, with an optional caption. - new Task SendFileAsync(string filePath, string text = null, bool isTTS = false, RequestOptions options = null); + new Task SendFileAsync(string filePath, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); #endif /// Sends a file to this text channel, with an optional caption. - new Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, RequestOptions options = null); + new Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); /// Gets a message from this message channel with the given id, or null if not found. Task GetMessageAsync(ulong id, RequestOptions options = null); diff --git a/src/Discord.Net.Rest/Entities/Channels/RestDMChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RestDMChannel.cs index d41441967..08acdf32b 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RestDMChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RestDMChannel.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; @@ -37,7 +37,7 @@ namespace Discord.Rest public override async Task UpdateAsync(RequestOptions options = null) { var model = await Discord.ApiClient.GetChannelAsync(Id, options).ConfigureAwait(false); - Update(model); + Update(model); } public Task CloseAsync(RequestOptions options = null) => ChannelHelper.DeleteAsync(this, Discord, options); @@ -66,11 +66,11 @@ namespace Discord.Rest public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task TriggerTypingAsync(RequestOptions options = null) => ChannelHelper.TriggerTypingAsync(this, Discord, options); @@ -122,11 +122,11 @@ namespace Discord.Rest => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) diff --git a/src/Discord.Net.Rest/Entities/Channels/RestGroupChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RestGroupChannel.cs index aa5c0f7dc..a1868573e 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RestGroupChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RestGroupChannel.cs @@ -1,4 +1,4 @@ -using Discord.Audio; +using Discord.Audio; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -19,7 +19,7 @@ namespace Discord.Rest public string Name { get; private set; } public IReadOnlyCollection Users => _users.ToReadOnlyCollection(); - public IReadOnlyCollection Recipients + public IReadOnlyCollection Recipients => _users.Select(x => x.Value).Where(x => x.Id != Discord.CurrentUser.Id).ToReadOnlyCollection(() => _users.Count - 1); internal RestGroupChannel(BaseDiscordClient discord, ulong id) @@ -79,11 +79,11 @@ namespace Discord.Rest public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task TriggerTypingAsync(RequestOptions options = null) => ChannelHelper.TriggerTypingAsync(this, Discord, options); @@ -132,11 +132,11 @@ namespace Discord.Rest => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) diff --git a/src/Discord.Net.Rest/Entities/Channels/RestTextChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RestTextChannel.cs index 9c29624c1..600b197d6 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RestTextChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RestTextChannel.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -61,11 +61,11 @@ namespace Discord.Rest public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task DeleteMessagesAsync(IEnumerable messages, RequestOptions options = null) => ChannelHelper.DeleteMessagesAsync(this, Discord, messages.Select(x => x.Id), options); @@ -123,18 +123,18 @@ namespace Discord.Rest else return AsyncEnumerable.Empty>(); } - async Task> IMessageChannel.GetPinnedMessagesAsync(RequestOptions options) + async Task> IMessageChannel.GetPinnedMessagesAsync(RequestOptions options) => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); - async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); + async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); - IDisposable IMessageChannel.EnterTypingState(RequestOptions options) + IDisposable IMessageChannel.EnterTypingState(RequestOptions options) => EnterTypingState(options); //IGuildChannel diff --git a/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs index eb807423f..3b6a68193 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -21,7 +21,7 @@ namespace Discord.Rest { return new RestVirtualMessageChannel(discord, id); } - + public Task GetMessageAsync(ulong id, RequestOptions options = null) => ChannelHelper.GetMessageAsync(this, Discord, id, options); public IAsyncEnumerable> GetMessagesAsync(int limit = DiscordConfig.MaxMessagesPerBatch, RequestOptions options = null) @@ -36,11 +36,11 @@ namespace Discord.Rest public Task SendMessageAsync(string text, bool isTTS, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task TriggerTypingAsync(RequestOptions options = null) => ChannelHelper.TriggerTypingAsync(this, Discord, options); @@ -82,11 +82,11 @@ namespace Discord.Rest => await GetPinnedMessagesAsync(options); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) diff --git a/src/Discord.Net.Rest/Entities/Messages/MessageHelper.cs b/src/Discord.Net.Rest/Entities/Messages/MessageHelper.cs index 47bb6f926..8ae41cc37 100644 --- a/src/Discord.Net.Rest/Entities/Messages/MessageHelper.cs +++ b/src/Discord.Net.Rest/Entities/Messages/MessageHelper.cs @@ -1,4 +1,4 @@ -using Discord.API.Rest; +using Discord.API.Rest; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -13,6 +13,9 @@ namespace Discord.Rest public static async Task ModifyAsync(IMessage msg, BaseDiscordClient client, Action func, RequestOptions options) { + if (msg.Author.Id != client.CurrentUser.Id) + throw new InvalidOperationException("Only the author of a message may change it."); + var args = new MessageProperties(); func(args); var apiArgs = new API.Rest.ModifyMessageParams diff --git a/src/Discord.Net.Rest/Entities/Messages/RestUserMessage.cs b/src/Discord.Net.Rest/Entities/Messages/RestUserMessage.cs index e5eed874e..0d1f3be2b 100644 --- a/src/Discord.Net.Rest/Entities/Messages/RestUserMessage.cs +++ b/src/Discord.Net.Rest/Entities/Messages/RestUserMessage.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; diff --git a/src/Discord.Net.Rest/Entities/Users/RestUser.cs b/src/Discord.Net.Rest/Entities/Users/RestUser.cs index c6cf6103a..c484986b1 100644 --- a/src/Discord.Net.Rest/Entities/Users/RestUser.cs +++ b/src/Discord.Net.Rest/Entities/Users/RestUser.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Diagnostics; using System.Threading.Tasks; using Model = Discord.API.User; @@ -60,6 +60,9 @@ namespace Discord.Rest public string GetAvatarUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128) => CDN.GetUserAvatarUrl(Id, AvatarId, size, format); + public string GetDefaultAvatarUrl() + => CDN.GetDefaultUserAvatarUrl(DiscriminatorValue); + public override string ToString() => $"{Username}#{Discriminator}"; private string DebuggerDisplay => $"{Username}#{Discriminator} ({Id}{(IsBot ? ", Bot" : "")})"; diff --git a/src/Discord.Net.Rest/Net/Converters/OptionalConverter.cs b/src/Discord.Net.Rest/Net/Converters/OptionalConverter.cs index 18b2a9e1c..d3d6191c0 100644 --- a/src/Discord.Net.Rest/Net/Converters/OptionalConverter.cs +++ b/src/Discord.Net.Rest/Net/Converters/OptionalConverter.cs @@ -1,4 +1,4 @@ -using Newtonsoft.Json; +using Newtonsoft.Json; using System; namespace Discord.Net.Converters @@ -19,10 +19,18 @@ namespace Discord.Net.Converters public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer) { T obj; + // custom converters need to be able to safely fail; move this check in here to prevent wasteful casting when parsing primitives if (_innerConverter != null) - obj = (T)_innerConverter.ReadJson(reader, typeof(T), null, serializer); + { + object o = _innerConverter.ReadJson(reader, typeof(T), null, serializer); + if (o is Optional) + return o; + + obj = (T)o; + } else obj = serializer.Deserialize(reader); + return new Optional(obj); } diff --git a/src/Discord.Net.Rest/Net/Converters/UnixTimestampConverter.cs b/src/Discord.Net.Rest/Net/Converters/UnixTimestampConverter.cs index d4660dc44..0b50cb166 100644 --- a/src/Discord.Net.Rest/Net/Converters/UnixTimestampConverter.cs +++ b/src/Discord.Net.Rest/Net/Converters/UnixTimestampConverter.cs @@ -1,4 +1,4 @@ -using System; +using System; using Newtonsoft.Json; namespace Discord.Net.Converters @@ -11,13 +11,18 @@ namespace Discord.Net.Converters public override bool CanRead => true; public override bool CanWrite => true; + // 1e13 unix ms = year 2286 + // necessary to prevent discord.js from sending values in the e15 and overflowing a DTO + private const long MaxSaneMs = 1_000_000_000_000_0; + public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer) { - // Discord doesn't validate if timestamps contain decimals or not - if (reader.Value is double d) + // Discord doesn't validate if timestamps contain decimals or not, and they also don't validate if timestamps are reasonably sized + if (reader.Value is double d && d < MaxSaneMs) return new DateTimeOffset(1970, 1, 1, 0, 0, 0, 0, TimeSpan.Zero).AddMilliseconds(d); - long offset = (long)reader.Value; - return new DateTimeOffset(1970, 1, 1, 0, 0, 0, 0, TimeSpan.Zero).AddMilliseconds(offset); + else if (reader.Value is long l && l < MaxSaneMs) + return new DateTimeOffset(1970, 1, 1, 0, 0, 0, 0, TimeSpan.Zero).AddMilliseconds(l); + return Optional.Unspecified; } public override void WriteJson(JsonWriter writer, object value, JsonSerializer serializer) @@ -25,4 +30,4 @@ namespace Discord.Net.Converters throw new NotImplementedException(); } } -} \ No newline at end of file +} diff --git a/src/Discord.Net.Rest/Net/DefaultRestClient.cs b/src/Discord.Net.Rest/Net/DefaultRestClient.cs index 637099fd6..ec789be59 100644 --- a/src/Discord.Net.Rest/Net/DefaultRestClient.cs +++ b/src/Discord.Net.Rest/Net/DefaultRestClient.cs @@ -1,4 +1,5 @@ -using Newtonsoft.Json; +using Discord.Net.Converters; +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Globalization; diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs index 2cc4b8a10..2d96ca796 100644 --- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs @@ -1,4 +1,4 @@ -using Newtonsoft.Json; +using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; #if DEBUG_LIMITS @@ -86,7 +86,7 @@ namespace Discord.Net.Queue Debug.WriteLine($"[{id}] (!) 502"); #endif if ((request.Options.RetryMode & RetryMode.Retry502) == 0) - throw new HttpException(HttpStatusCode.BadGateway, null); + throw new HttpException(HttpStatusCode.BadGateway, request, null); continue; //Retry default: @@ -106,7 +106,7 @@ namespace Discord.Net.Queue } catch { } } - throw new HttpException(response.StatusCode, code, reason); + throw new HttpException(response.StatusCode, request, code, reason); } } else @@ -163,7 +163,7 @@ namespace Discord.Net.Queue if (!isRateLimited) throw new TimeoutException(); else - throw new RateLimitedException(); + throw new RateLimitedException(request); } lock (_lock) @@ -182,12 +182,12 @@ namespace Discord.Net.Queue } if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0) - throw new RateLimitedException(); + throw new RateLimitedException(request); if (resetAt.HasValue) { if (resetAt > timeoutAt) - throw new RateLimitedException(); + throw new RateLimitedException(request); int millis = (int)Math.Ceiling((resetAt.Value - DateTimeOffset.UtcNow).TotalMilliseconds); #if DEBUG_LIMITS Debug.WriteLine($"[{id}] Sleeping {millis} ms (Pre-emptive)"); @@ -198,7 +198,7 @@ namespace Discord.Net.Queue else { if ((timeoutAt.Value - DateTimeOffset.UtcNow).TotalMilliseconds < 500.0) - throw new RateLimitedException(); + throw new RateLimitedException(request); #if DEBUG_LIMITS Debug.WriteLine($"[{id}] Sleeping 500* ms (Pre-emptive)"); #endif diff --git a/src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs b/src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs index 8f160273a..bb5840ce2 100644 --- a/src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs +++ b/src/Discord.Net.Rest/Net/Queue/Requests/RestRequest.cs @@ -1,11 +1,11 @@ -using Discord.Net.Rest; +using Discord.Net.Rest; using System; using System.IO; using System.Threading.Tasks; namespace Discord.Net.Queue { - public class RestRequest + public class RestRequest : IRequest { public IRestClient Client { get; } public string Method { get; } diff --git a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs index 478289b59..81eb40b31 100644 --- a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs +++ b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs @@ -1,4 +1,4 @@ -using Discord.Net.WebSockets; +using Discord.Net.WebSockets; using System; using System.IO; using System.Threading; @@ -6,7 +6,7 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class WebSocketRequest + public class WebSocketRequest : IRequest { public IWebSocketClient Client { get; } public string BucketId { get; } diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 2ab244aeb..923b2c23b 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.IO; using System.Threading.Tasks; using Discord.API; @@ -44,7 +44,7 @@ namespace Discord.WebSocket /// public abstract Task StopAsync(); public abstract Task SetStatusAsync(UserStatus status); - public abstract Task SetGameAsync(string name, string streamUrl = null, StreamType streamType = StreamType.NotStreaming); + public abstract Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing); public abstract Task SetActivityAsync(IActivity activity); public abstract Task DownloadUsersAsync(IEnumerable guilds); diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 4e99ae28d..ad89067df 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -75,8 +75,8 @@ namespace Discord.WebSocket { if (_automaticShards) { - var response = await ApiClient.GetBotGatewayAsync().ConfigureAwait(false); - _shardIds = Enumerable.Range(0, response.Shards).ToArray(); + var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false); + _shardIds = Enumerable.Range(0, shardCount).ToArray(); _totalShards = _shardIds.Length; _shards = new DiscordSocketClient[_shardIds.Length]; for (int i = 0; i < _shardIds.Length; i++) @@ -238,13 +238,13 @@ namespace Discord.WebSocket for (int i = 0; i < _shards.Length; i++) await _shards[i].SetStatusAsync(status).ConfigureAwait(false); } - public override async Task SetGameAsync(string name, string streamUrl = null, StreamType streamType = StreamType.NotStreaming) + public override async Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing) { IActivity activity = null; - if (streamUrl != null) - activity = new StreamingGame(name, streamUrl, streamType); - else if (name != null) - activity = new Game(name); + if (!string.IsNullOrEmpty(streamUrl)) + activity = new StreamingGame(name, streamUrl); + else if (!string.IsNullOrEmpty(name)) + activity = new Game(name, type); await SetActivityAsync(activity).ConfigureAwait(false); } public override async Task SetActivityAsync(IActivity activity) diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 72781204c..8ae41cc59 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -1,4 +1,4 @@ -#pragma warning disable CS1591 +#pragma warning disable CS1591 using Discord.API.Gateway; using Discord.API.Rest; using Discord.Net.Queue; @@ -207,18 +207,7 @@ namespace Discord.API await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false); await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); } - - //Gateway - public async Task GetGatewayAsync(RequestOptions options = null) - { - options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", () => "gateway", new BucketIds(), options: options).ConfigureAwait(false); - } - public async Task GetBotGatewayAsync(RequestOptions options = null) - { - options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", () => "gateway/bot", new BucketIds(), options: options).ConfigureAwait(false); - } + public async Task SendIdentifyAsync(int largeThreshold = 100, int shardID = 0, int totalShards = 1, RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index cb3f23c57..593f796c2 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -1,4 +1,3 @@ -#pragma warning disable CS0618 using Discord.API; using Discord.API.Gateway; using Discord.Logging; @@ -326,12 +325,12 @@ namespace Discord.WebSocket _statusSince = null; await SendStatusAsync().ConfigureAwait(false); } - public override async Task SetGameAsync(string name, string streamUrl = null, StreamType streamType = StreamType.NotStreaming) + public override async Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing) { if (!string.IsNullOrEmpty(streamUrl)) - Activity = new StreamingGame(name, streamUrl, streamType); + Activity = new StreamingGame(name, streamUrl); else if (!string.IsNullOrEmpty(name)) - Activity = new Game(name); + Activity = new Game(name, type); else Activity = null; await SendStatusAsync().ConfigureAwait(false); @@ -354,15 +353,13 @@ namespace Discord.WebSocket // Discord only accepts rich presence over RPC, don't even bother building a payload if (Activity is RichGame game) throw new NotSupportedException("Outgoing Rich Presences are not supported"); - else if (Activity is StreamingGame stream) - { - gameModel.StreamUrl = stream.Url; - gameModel.StreamType = stream.StreamType; - } - else if (Activity != null) + + if (Activity != null) { gameModel.Name = Activity.Name; - gameModel.StreamType = StreamType.NotStreaming; + gameModel.Type = Activity.Type; + if (Activity is StreamingGame streamGame) + gameModel.StreamUrl = streamGame.Url; } await ApiClient.SendStatusUpdateAsync( @@ -418,11 +415,8 @@ namespace Discord.WebSocket _sessionId = null; _lastSeq = 0; - bool retry = (bool)payload; - if (retry) - _connection.Reconnect(); //TODO: Untested - else - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); + + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); } break; case GatewayOpCode.Reconnect: @@ -451,7 +445,7 @@ namespace Discord.WebSocket { var model = data.Guilds[i]; var guild = AddGuild(model, state); - if (!guild.IsAvailable || ApiClient.AuthTokenType == TokenType.User) + if (!guild.IsAvailable) unavailableGuilds++; else await GuildAvailableAsync(guild).ConfigureAwait(false); @@ -470,9 +464,6 @@ namespace Discord.WebSocket return; } - if (ApiClient.AuthTokenType == TokenType.User) - await SyncGuildsAsync().ConfigureAwait(false); - _lastGuildAvailableTime = Environment.TickCount; _guildDownloadTask = WaitForGuildsAsync(_connection.CancelToken, _gatewayLogger) .ContinueWith(async x => @@ -547,8 +538,6 @@ namespace Discord.WebSocket var guild = AddGuild(data, State); if (guild != null) { - if (ApiClient.AuthTokenType == TokenType.User) - await SyncGuildsAsync().ConfigureAwait(false); await TimedInvokeAsync(_joinedGuildEvent, nameof(JoinedGuild), guild).ConfigureAwait(false); } else diff --git a/src/Discord.Net.WebSocket/Entities/Channels/ISocketMessageChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/ISocketMessageChannel.cs index e2119e7a2..026bd8378 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/ISocketMessageChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/ISocketMessageChannel.cs @@ -1,4 +1,4 @@ -using Discord.Rest; +using Discord.Rest; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; @@ -14,10 +14,10 @@ namespace Discord.WebSocket new Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null); #if FILESYSTEM /// Sends a file to this text channel, with an optional caption. - new Task SendFileAsync(string filePath, string text = null, bool isTTS = false, RequestOptions options = null); + new Task SendFileAsync(string filePath, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); #endif /// Sends a file to this text channel, with an optional caption. - new Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, RequestOptions options = null); + new Task SendFileAsync(Stream stream, string filename, string text = null, bool isTTS = false, Embed embed = null, RequestOptions options = null); SocketMessage GetCachedMessage(ulong id); /// Gets the last N messages from this message channel. diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs index d5a183b1e..e7a165c2f 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketCategoryChannel.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; @@ -15,10 +15,12 @@ namespace Discord.WebSocket public class SocketCategoryChannel : SocketGuildChannel, ICategoryChannel { public override IReadOnlyCollection Users - => Guild.Users.Where(x => x.VoiceChannel?.Id == Id).ToImmutableArray(); + => Guild.Users.Where(x => Permissions.GetValue( + Permissions.ResolveChannel(Guild, x, this, Permissions.ResolveGuild(Guild, x)), + ChannelPermission.ViewChannel)).ToImmutableArray(); public IReadOnlyCollection Channels - => Guild.Channels.Where(x => x.CategoryId == CategoryId).ToImmutableArray(); + => Guild.Channels.Where(x => x.CategoryId == Id).ToImmutableArray(); internal SocketCategoryChannel(DiscordSocketClient discord, ulong id, SocketGuild guild) : base(discord, id, guild) @@ -31,14 +33,28 @@ namespace Discord.WebSocket return entity; } + //Users + public override SocketGuildUser GetUser(ulong id) + { + var user = Guild.GetUser(id); + if (user != null) + { + var guildPerms = Permissions.ResolveGuild(Guild, user); + var channelPerms = Permissions.ResolveChannel(Guild, user, this, guildPerms); + if (Permissions.GetValue(channelPerms, ChannelPermission.ViewChannel)) + return user; + } + return null; + } + private string DebuggerDisplay => $"{Name} ({Id}, Category)"; internal new SocketCategoryChannel Clone() => MemberwiseClone() as SocketCategoryChannel; // IGuildChannel IAsyncEnumerable> IGuildChannel.GetUsersAsync(CacheMode mode, RequestOptions options) - => throw new NotSupportedException(); + => ImmutableArray.Create>(Users).ToAsyncEnumerable(); Task IGuildChannel.GetUserAsync(ulong id, CacheMode mode, RequestOptions options) - => throw new NotSupportedException(); + => Task.FromResult(GetUser(id)); Task IGuildChannel.CreateInviteAsync(int? maxAge, int? maxUses, bool isTemporary, bool isUnique, RequestOptions options) => throw new NotSupportedException(); Task> IGuildChannel.GetInvitesAsync(RequestOptions options) @@ -46,8 +62,8 @@ namespace Discord.WebSocket //IChannel IAsyncEnumerable> IChannel.GetUsersAsync(CacheMode mode, RequestOptions options) - => throw new NotSupportedException(); + => ImmutableArray.Create>(Users).ToAsyncEnumerable(); Task IChannel.GetUserAsync(ulong id, CacheMode mode, RequestOptions options) - => throw new NotSupportedException(); + => Task.FromResult(GetUser(id)); } } diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs index 00cef60f8..451240e66 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketDMChannel.cs @@ -1,4 +1,4 @@ -using Discord.Rest; +using Discord.Rest; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -70,11 +70,11 @@ namespace Discord.WebSocket public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task TriggerTypingAsync(RequestOptions options = null) => ChannelHelper.TriggerTypingAsync(this, Discord, options); @@ -113,7 +113,7 @@ namespace Discord.WebSocket //IPrivateChannel IReadOnlyCollection IPrivateChannel.Recipients => ImmutableArray.Create(Recipient); - + //IMessageChannel async Task IMessageChannel.GetMessageAsync(ulong id, CacheMode mode, RequestOptions options) { @@ -131,11 +131,11 @@ namespace Discord.WebSocket async Task> IMessageChannel.GetPinnedMessagesAsync(RequestOptions options) => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs index 92a93a903..d46c5fc59 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketGroupChannel.cs @@ -1,4 +1,4 @@ -using Discord.Audio; +using Discord.Audio; using Discord.Rest; using System; using System.Collections.Concurrent; @@ -61,7 +61,7 @@ namespace Discord.WebSocket users[models[i].Id] = SocketGroupUser.Create(this, state, models[i]); _users = users; } - + public Task LeaveAsync(RequestOptions options = null) => ChannelHelper.DeleteAsync(this, Discord, options); @@ -98,11 +98,11 @@ namespace Discord.WebSocket public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task TriggerTypingAsync(RequestOptions options = null) => ChannelHelper.TriggerTypingAsync(this, Discord, options); @@ -195,11 +195,11 @@ namespace Discord.WebSocket async Task> IMessageChannel.GetPinnedMessagesAsync(RequestOptions options) => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs index 7b8f572d2..ec7615b55 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketTextChannel.cs @@ -1,4 +1,4 @@ -using Discord.Rest; +using Discord.Rest; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -16,7 +16,7 @@ namespace Discord.WebSocket private readonly MessageCache _messages; public string Topic { get; private set; } - + private bool _nsfw; public bool IsNsfw => _nsfw || ChannelHelper.IsNsfw(this); @@ -24,9 +24,9 @@ namespace Discord.WebSocket public IReadOnlyCollection CachedMessages => _messages?.Messages ?? ImmutableArray.Create(); public override IReadOnlyCollection Users => Guild.Users.Where(x => Permissions.GetValue( - Permissions.ResolveChannel(Guild, x, this, Permissions.ResolveGuild(Guild, x)), + Permissions.ResolveChannel(Guild, x, this, Permissions.ResolveGuild(Guild, x)), ChannelPermission.ViewChannel)).ToImmutableArray(); - + internal SocketTextChannel(DiscordSocketClient discord, ulong id, SocketGuild guild) : base(discord, id, guild) { @@ -78,11 +78,11 @@ namespace Discord.WebSocket public Task SendMessageAsync(string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) => ChannelHelper.SendMessageAsync(this, Discord, text, isTTS, embed, options); #if FILESYSTEM - public Task SendFileAsync(string filePath, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, options); + public Task SendFileAsync(string filePath, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, filePath, text, isTTS, embed, options); #endif - public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, RequestOptions options = null) - => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, options); + public Task SendFileAsync(Stream stream, string filename, string text, bool isTTS = false, Embed embed = null, RequestOptions options = null) + => ChannelHelper.SendFileAsync(this, Discord, stream, filename, text, isTTS, embed, options); public Task DeleteMessagesAsync(IEnumerable messages, RequestOptions options = null) => ChannelHelper.DeleteMessagesAsync(this, Discord, messages.Select(x => x.Id), options); @@ -155,14 +155,14 @@ namespace Discord.WebSocket async Task> IMessageChannel.GetPinnedMessagesAsync(RequestOptions options) => await GetPinnedMessagesAsync(options).ConfigureAwait(false); #if FILESYSTEM - async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(filePath, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(string filePath, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(filePath, text, isTTS, embed, options).ConfigureAwait(false); #endif - async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, RequestOptions options) - => await SendFileAsync(stream, filename, text, isTTS, options).ConfigureAwait(false); + async Task IMessageChannel.SendFileAsync(Stream stream, string filename, string text, bool isTTS, Embed embed, RequestOptions options) + => await SendFileAsync(stream, filename, text, isTTS, embed, options).ConfigureAwait(false); async Task IMessageChannel.SendMessageAsync(string text, bool isTTS, Embed embed, RequestOptions options) => await SendMessageAsync(text, isTTS, embed, options).ConfigureAwait(false); IDisposable IMessageChannel.EnterTypingState(RequestOptions options) => EnterTypingState(options); } -} \ No newline at end of file +} diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index ea68a8f54..259dae5a8 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -1,4 +1,3 @@ -#pragma warning disable CS0618 using Discord.Audio; using Discord.Rest; using System; @@ -64,7 +63,7 @@ namespace Discord.WebSocket public Task DownloaderPromise => _downloaderPromise.Task; public IAudioClient AudioClient => _audioClient; public SocketTextChannel DefaultChannel => TextChannels - .Where(c => CurrentUser.GetPermissions(c).ReadMessages) + .Where(c => CurrentUser.GetPermissions(c).ViewChannel) .OrderBy(c => c.Position) .FirstOrDefault(); public SocketVoiceChannel AFKChannel @@ -192,12 +191,9 @@ namespace Discord.WebSocket _syncPromise = new TaskCompletionSource(); _downloaderPromise = new TaskCompletionSource(); - if (Discord.ApiClient.AuthTokenType != TokenType.User) - { - var _ = _syncPromise.TrySetResultAsync(true); - /*if (!model.Large) - _ = _downloaderPromise.TrySetResultAsync(true);*/ - } + var _ = _syncPromise.TrySetResultAsync(true); + /*if (!model.Large) + _ = _downloaderPromise.TrySetResultAsync(true);*/ } internal void Update(ClientState state, Model model) { @@ -696,7 +692,6 @@ namespace Discord.WebSocket => Task.FromResult(CurrentUser); Task IGuild.GetOwnerAsync(CacheMode mode, RequestOptions options) => Task.FromResult(Owner); - Task IGuild.DownloadUsersAsync() { throw new NotSupportedException(); } async Task IGuild.GetWebhookAsync(ulong id, RequestOptions options) => await GetWebhookAsync(id, options); diff --git a/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs b/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs index b240645e5..5489ad2bb 100644 --- a/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs +++ b/src/Discord.Net.WebSocket/Entities/Messages/SocketUserMessage.cs @@ -1,4 +1,4 @@ -using Discord.Rest; +using Discord.Rest; using System; using System.Collections.Generic; using System.Collections.Immutable; diff --git a/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs b/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs index 58d5c62a1..00899d47e 100644 --- a/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs +++ b/src/Discord.Net.WebSocket/Entities/Users/SocketUser.cs @@ -1,4 +1,4 @@ -using Discord.Rest; +using Discord.Rest; using System; using System.Threading.Tasks; using Model = Discord.API.User; @@ -37,23 +37,23 @@ namespace Discord.WebSocket { var newVal = ushort.Parse(model.Discriminator.Value); if (newVal != DiscriminatorValue) - { + { DiscriminatorValue = ushort.Parse(model.Discriminator.Value); hasChanges = true; } } if (model.Bot.IsSpecified && model.Bot.Value != IsBot) - { + { IsBot = model.Bot.Value; hasChanges = true; } if (model.Username.IsSpecified && model.Username.Value != Username) - { + { Username = model.Username.Value; hasChanges = true; } return hasChanges; - } + } public async Task GetOrCreateDMChannelAsync(RequestOptions options = null) => GlobalUser.DMChannel ?? await UserHelper.CreateDMChannelAsync(this, Discord, options) as IDMChannel; @@ -61,6 +61,9 @@ namespace Discord.WebSocket public string GetAvatarUrl(ImageFormat format = ImageFormat.Auto, ushort size = 128) => CDN.GetUserAvatarUrl(Id, AvatarId, size, format); + public string GetDefaultAvatarUrl() + => CDN.GetDefaultUserAvatarUrl(DiscriminatorValue); + public override string ToString() => $"{Username}#{Discriminator}"; private string DebuggerDisplay => $"{Username}#{Discriminator} ({Id}{(IsBot ? ", Bot" : "")})"; internal SocketUser Clone() => MemberwiseClone() as SocketUser; diff --git a/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs b/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs index c66163610..ff395a932 100644 --- a/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs +++ b/src/Discord.Net.WebSocket/Extensions/EntityExtensions.cs @@ -1,9 +1,30 @@ -namespace Discord.WebSocket +namespace Discord.WebSocket { internal static class EntityExtensions { public static IActivity ToEntity(this API.Game model) { + // Spotify Game + if (model.SyncId.IsSpecified) + { + var assets = model.Assets.GetValueOrDefault()?.ToEntity(); + string albumText = assets?[1]?.Text; + string albumArtId = assets?[1]?.ImageId?.Replace("spotify:",""); + var timestamps = model.Timestamps.IsSpecified ? model.Timestamps.Value.ToEntity() : null; + return new SpotifyGame + { + Name = model.Name, + SessionId = model.SessionId.GetValueOrDefault(), + SyncId = model.SyncId.Value, + AlbumTitle = albumText, + TrackTitle = model.Details.GetValueOrDefault(), + Artists = model.State.GetValueOrDefault()?.Split(';'), + Duration = timestamps?.End - timestamps?.Start, + AlbumArt = albumArtId != null ? CDN.GetSpotifyAlbumArtUrl(albumArtId) : null, + Type = ActivityType.Listening + }; + } + // Rich Game if (model.ApplicationId.IsSpecified) { @@ -27,15 +48,14 @@ { return new StreamingGame( model.Name, - model.StreamUrl.Value, - model.StreamType.Value.GetValueOrDefault()); + model.StreamUrl.Value); } // Normal Game - return new Game(model.Name); + return new Game(model.Name, model.Type.GetValueOrDefault() ?? ActivityType.Playing); } // (Small, Large) - public static GameAsset[] ToEntity(this API.GameAssets model, ulong appId) + public static GameAsset[] ToEntity(this API.GameAssets model, ulong? appId = null) { return new GameAsset[] { @@ -57,7 +77,7 @@ public static GameParty ToEntity(this API.GameParty model) { // Discord will probably send bad data since they don't validate anything - int current = 0, cap = 0; + long current = 0, cap = 0; if (model.Size?.Length == 2) { current = model.Size[0]; diff --git a/src/Discord.Net.Webhook/WebhookClientHelper.cs b/src/Discord.Net.Webhook/WebhookClientHelper.cs index f3a3984cf..1116662a6 100644 --- a/src/Discord.Net.Webhook/WebhookClientHelper.cs +++ b/src/Discord.Net.Webhook/WebhookClientHelper.cs @@ -49,7 +49,7 @@ namespace Discord.Webhook if (username != null) args.Username = username; if (avatarUrl != null) - args.AvatarUrl = username; + args.AvatarUrl = avatarUrl; if (embeds != null) args.Embeds = embeds.Select(x => x.ToModel()).ToArray(); var msg = await client.ApiClient.UploadWebhookFileAsync(client.Webhook.Id, args, options).ConfigureAwait(false); diff --git a/src/Discord.Net/Discord.Net.nuspec b/src/Discord.Net/Discord.Net.nuspec index cd57d2fcf..2bf531cb1 100644 --- a/src/Discord.Net/Discord.Net.nuspec +++ b/src/Discord.Net/Discord.Net.nuspec @@ -2,7 +2,7 @@ Discord.Net - 2.0.0-beta$suffix$ + 2.0.0-beta2$suffix$ Discord.Net Discord.Net Contributors RogueException @@ -13,25 +13,25 @@ false - - - - - + + + + + - - - - - + + + + + - - - - - + + + + + diff --git a/test/Discord.Net.Tests/AnalyzerTests/Extensions/AppDomainPolyfill.cs b/test/Discord.Net.Tests/AnalyzerTests/Extensions/AppDomainPolyfill.cs new file mode 100644 index 000000000..729bc385c --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Extensions/AppDomainPolyfill.cs @@ -0,0 +1,30 @@ +using System.Linq; +using System.Reflection; +using Microsoft.DotNet.PlatformAbstractions; +using Microsoft.Extensions.DependencyModel; + +namespace System +{ + /// Polyfill of the AppDomain class from full framework. + internal class AppDomain + { + public static AppDomain CurrentDomain { get; private set; } + + private AppDomain() + { + } + + static AppDomain() + { + CurrentDomain = new AppDomain(); + } + + public Assembly[] GetAssemblies() + { + var rid = RuntimeEnvironment.GetRuntimeIdentifier(); + var ass = DependencyContext.Default.GetRuntimeAssemblyNames(rid); + + return ass.Select(xan => Assembly.Load(xan)).ToArray(); + } + } +} \ No newline at end of file diff --git a/test/Discord.Net.Tests/AnalyzerTests/GuildAccessTests.cs b/test/Discord.Net.Tests/AnalyzerTests/GuildAccessTests.cs new file mode 100644 index 000000000..073cc1de7 --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/GuildAccessTests.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; +using Discord.Analyzers; +using TestHelper; +using Xunit; + +namespace Discord +{ + public partial class AnalyserTests + { + public class GuildAccessTests : DiagnosticVerifier + { + [Fact] + public void VerifyDiagnosticWhenLackingRequireContext() + { + string source = @"using System; +using System.Threading.Tasks; +using Discord.Commands; + +namespace Test +{ + public class TestModule : ModuleBase + { + [Command(""test"")] + public Task TestCmd() => ReplyAsync(Context.Guild.Name); + } +}"; + var expected = new DiagnosticResult() + { + Id = "DNET0001", + Locations = new[] { new DiagnosticResultLocation("Test0.cs", line: 10, column: 45) }, + Message = "Command method 'TestCmd' is accessing 'Context.Guild' but is not restricted to Guild contexts.", + Severity = DiagnosticSeverity.Warning + }; + VerifyCSharpDiagnostic(source, expected); + } + + [Fact] + public void VerifyDiagnosticWhenWrongRequireContext() + { + string source = @"using System; +using System.Threading.Tasks; +using Discord.Commands; + +namespace Test +{ + public class TestModule : ModuleBase + { + [Command(""test""), RequireContext(ContextType.Group)] + public Task TestCmd() => ReplyAsync(Context.Guild.Name); + } +}"; + var expected = new DiagnosticResult() + { + Id = "DNET0001", + Locations = new[] { new DiagnosticResultLocation("Test0.cs", line: 10, column: 45) }, + Message = "Command method 'TestCmd' is accessing 'Context.Guild' but is not restricted to Guild contexts.", + Severity = DiagnosticSeverity.Warning + }; + VerifyCSharpDiagnostic(source, expected); + } + + [Fact] + public void VerifyNoDiagnosticWhenRequireContextOnMethod() + { + string source = @"using System; +using System.Threading.Tasks; +using Discord.Commands; + +namespace Test +{ + public class TestModule : ModuleBase + { + [Command(""test""), RequireContext(ContextType.Guild)] + public Task TestCmd() => ReplyAsync(Context.Guild.Name); + } +}"; + + VerifyCSharpDiagnostic(source, Array.Empty()); + } + + [Fact] + public void VerifyNoDiagnosticWhenRequireContextOnClass() + { + string source = @"using System; +using System.Threading.Tasks; +using Discord.Commands; + +namespace Test +{ + [RequireContext(ContextType.Guild)] + public class TestModule : ModuleBase + { + [Command(""test"")] + public Task TestCmd() => ReplyAsync(Context.Guild.Name); + } +}"; + + VerifyCSharpDiagnostic(source, Array.Empty()); + } + + protected override DiagnosticAnalyzer GetCSharpDiagnosticAnalyzer() + => new GuildAccessAnalyzer(); + } + } +} diff --git a/test/Discord.Net.Tests/AnalyzerTests/Helpers/CodeFixVerifier.Helper.cs b/test/Discord.Net.Tests/AnalyzerTests/Helpers/CodeFixVerifier.Helper.cs new file mode 100644 index 000000000..0f73d0643 --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Helpers/CodeFixVerifier.Helper.cs @@ -0,0 +1,85 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.Formatting; +using Microsoft.CodeAnalysis.Simplification; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace TestHelper +{ + /// + /// Diagnostic Producer class with extra methods dealing with applying codefixes + /// All methods are static + /// + public abstract partial class CodeFixVerifier : DiagnosticVerifier + { + /// + /// Apply the inputted CodeAction to the inputted document. + /// Meant to be used to apply codefixes. + /// + /// The Document to apply the fix on + /// A CodeAction that will be applied to the Document. + /// A Document with the changes from the CodeAction + private static Document ApplyFix(Document document, CodeAction codeAction) + { + var operations = codeAction.GetOperationsAsync(CancellationToken.None).Result; + var solution = operations.OfType().Single().ChangedSolution; + return solution.GetDocument(document.Id); + } + + /// + /// Compare two collections of Diagnostics,and return a list of any new diagnostics that appear only in the second collection. + /// Note: Considers Diagnostics to be the same if they have the same Ids. In the case of multiple diagnostics with the same Id in a row, + /// this method may not necessarily return the new one. + /// + /// The Diagnostics that existed in the code before the CodeFix was applied + /// The Diagnostics that exist in the code after the CodeFix was applied + /// A list of Diagnostics that only surfaced in the code after the CodeFix was applied + private static IEnumerable GetNewDiagnostics(IEnumerable diagnostics, IEnumerable newDiagnostics) + { + var oldArray = diagnostics.OrderBy(d => d.Location.SourceSpan.Start).ToArray(); + var newArray = newDiagnostics.OrderBy(d => d.Location.SourceSpan.Start).ToArray(); + + int oldIndex = 0; + int newIndex = 0; + + while (newIndex < newArray.Length) + { + if (oldIndex < oldArray.Length && oldArray[oldIndex].Id == newArray[newIndex].Id) + { + ++oldIndex; + ++newIndex; + } + else + { + yield return newArray[newIndex++]; + } + } + } + + /// + /// Get the existing compiler diagnostics on the inputted document. + /// + /// The Document to run the compiler diagnostic analyzers on + /// The compiler diagnostics that were found in the code + private static IEnumerable GetCompilerDiagnostics(Document document) + { + return document.GetSemanticModelAsync().Result.GetDiagnostics(); + } + + /// + /// Given a document, turn it into a string based on the syntax root + /// + /// The Document to be converted to a string + /// A string containing the syntax of the Document after formatting + private static string GetStringFromDocument(Document document) + { + var simplifiedDoc = Simplifier.ReduceAsync(document, Simplifier.Annotation).Result; + var root = simplifiedDoc.GetSyntaxRootAsync().Result; + root = Formatter.Format(root, Formatter.Annotation, simplifiedDoc.Project.Solution.Workspace); + return root.GetText().ToString(); + } + } +} + diff --git a/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticResult.cs b/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticResult.cs new file mode 100644 index 000000000..5ae6f528e --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticResult.cs @@ -0,0 +1,87 @@ +using Microsoft.CodeAnalysis; +using System; + +namespace TestHelper +{ + /// + /// Location where the diagnostic appears, as determined by path, line number, and column number. + /// + public struct DiagnosticResultLocation + { + public DiagnosticResultLocation(string path, int line, int column) + { + if (line < -1) + { + throw new ArgumentOutOfRangeException(nameof(line), "line must be >= -1"); + } + + if (column < -1) + { + throw new ArgumentOutOfRangeException(nameof(column), "column must be >= -1"); + } + + this.Path = path; + this.Line = line; + this.Column = column; + } + + public string Path { get; } + public int Line { get; } + public int Column { get; } + } + + /// + /// Struct that stores information about a Diagnostic appearing in a source + /// + public struct DiagnosticResult + { + private DiagnosticResultLocation[] locations; + + public DiagnosticResultLocation[] Locations + { + get + { + if (this.locations == null) + { + this.locations = new DiagnosticResultLocation[] { }; + } + return this.locations; + } + + set + { + this.locations = value; + } + } + + public DiagnosticSeverity Severity { get; set; } + + public string Id { get; set; } + + public string Message { get; set; } + + public string Path + { + get + { + return this.Locations.Length > 0 ? this.Locations[0].Path : ""; + } + } + + public int Line + { + get + { + return this.Locations.Length > 0 ? this.Locations[0].Line : -1; + } + } + + public int Column + { + get + { + return this.Locations.Length > 0 ? this.Locations[0].Column : -1; + } + } + } +} \ No newline at end of file diff --git a/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticVerifier.Helper.cs b/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticVerifier.Helper.cs new file mode 100644 index 000000000..7a8eb2e9c --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Helpers/DiagnosticVerifier.Helper.cs @@ -0,0 +1,207 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Reflection; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Text; +using Discord; +using Discord.Commands; + +namespace TestHelper +{ + /// + /// Class for turning strings into documents and getting the diagnostics on them + /// All methods are static + /// + public abstract partial class DiagnosticVerifier + { + private static readonly MetadataReference CorlibReference = MetadataReference.CreateFromFile(typeof(object).GetTypeInfo().Assembly.Location); + private static readonly MetadataReference SystemCoreReference = MetadataReference.CreateFromFile(typeof(Enumerable).GetTypeInfo().Assembly.Location); + private static readonly MetadataReference CSharpSymbolsReference = MetadataReference.CreateFromFile(typeof(CSharpCompilation).GetTypeInfo().Assembly.Location); + private static readonly MetadataReference CodeAnalysisReference = MetadataReference.CreateFromFile(typeof(Compilation).GetTypeInfo().Assembly.Location); + //private static readonly MetadataReference DiscordNetReference = MetadataReference.CreateFromFile(typeof(IDiscordClient).GetTypeInfo().Assembly.Location); + //private static readonly MetadataReference DiscordCommandsReference = MetadataReference.CreateFromFile(typeof(CommandAttribute).GetTypeInfo().Assembly.Location); + private static readonly Assembly DiscordCommandsAssembly = typeof(CommandAttribute).GetTypeInfo().Assembly; + + internal static string DefaultFilePathPrefix = "Test"; + internal static string CSharpDefaultFileExt = "cs"; + internal static string VisualBasicDefaultExt = "vb"; + internal static string TestProjectName = "TestProject"; + + #region Get Diagnostics + + /// + /// Given classes in the form of strings, their language, and an IDiagnosticAnlayzer to apply to it, return the diagnostics found in the string after converting it to a document. + /// + /// Classes in the form of strings + /// The language the source classes are in + /// The analyzer to be run on the sources + /// An IEnumerable of Diagnostics that surfaced in the source code, sorted by Location + private static Diagnostic[] GetSortedDiagnostics(string[] sources, string language, DiagnosticAnalyzer analyzer) + { + return GetSortedDiagnosticsFromDocuments(analyzer, GetDocuments(sources, language)); + } + + /// + /// Given an analyzer and a document to apply it to, run the analyzer and gather an array of diagnostics found in it. + /// The returned diagnostics are then ordered by location in the source document. + /// + /// The analyzer to run on the documents + /// The Documents that the analyzer will be run on + /// An IEnumerable of Diagnostics that surfaced in the source code, sorted by Location + protected static Diagnostic[] GetSortedDiagnosticsFromDocuments(DiagnosticAnalyzer analyzer, Document[] documents) + { + var projects = new HashSet(); + foreach (var document in documents) + { + projects.Add(document.Project); + } + + var diagnostics = new List(); + foreach (var project in projects) + { + var compilationWithAnalyzers = project.GetCompilationAsync().Result.WithAnalyzers(ImmutableArray.Create(analyzer)); + var diags = compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync().Result; + foreach (var diag in diags) + { + if (diag.Location == Location.None || diag.Location.IsInMetadata) + { + diagnostics.Add(diag); + } + else + { + for (int i = 0; i < documents.Length; i++) + { + var document = documents[i]; + var tree = document.GetSyntaxTreeAsync().Result; + if (tree == diag.Location.SourceTree) + { + diagnostics.Add(diag); + } + } + } + } + } + + var results = SortDiagnostics(diagnostics); + diagnostics.Clear(); + return results; + } + + /// + /// Sort diagnostics by location in source document + /// + /// The list of Diagnostics to be sorted + /// An IEnumerable containing the Diagnostics in order of Location + private static Diagnostic[] SortDiagnostics(IEnumerable diagnostics) + { + return diagnostics.OrderBy(d => d.Location.SourceSpan.Start).ToArray(); + } + + #endregion + + #region Set up compilation and documents + /// + /// Given an array of strings as sources and a language, turn them into a project and return the documents and spans of it. + /// + /// Classes in the form of strings + /// The language the source code is in + /// A Tuple containing the Documents produced from the sources and their TextSpans if relevant + private static Document[] GetDocuments(string[] sources, string language) + { + if (language != LanguageNames.CSharp && language != LanguageNames.VisualBasic) + { + throw new ArgumentException("Unsupported Language"); + } + + var project = CreateProject(sources, language); + var documents = project.Documents.ToArray(); + + if (sources.Length != documents.Length) + { + throw new Exception("Amount of sources did not match amount of Documents created"); + } + + return documents; + } + + /// + /// Create a Document from a string through creating a project that contains it. + /// + /// Classes in the form of a string + /// The language the source code is in + /// A Document created from the source string + protected static Document CreateDocument(string source, string language = LanguageNames.CSharp) + { + return CreateProject(new[] { source }, language).Documents.First(); + } + + /// + /// Create a project using the inputted strings as sources. + /// + /// Classes in the form of strings + /// The language the source code is in + /// A Project created out of the Documents created from the source strings + private static Project CreateProject(string[] sources, string language = LanguageNames.CSharp) + { + string fileNamePrefix = DefaultFilePathPrefix; + string fileExt = language == LanguageNames.CSharp ? CSharpDefaultFileExt : VisualBasicDefaultExt; + + var projectId = ProjectId.CreateNewId(debugName: TestProjectName); + + var solution = new AdhocWorkspace() + .CurrentSolution + .AddProject(projectId, TestProjectName, TestProjectName, language) + .AddMetadataReference(projectId, CorlibReference) + .AddMetadataReference(projectId, SystemCoreReference) + .AddMetadataReference(projectId, CSharpSymbolsReference) + .AddMetadataReference(projectId, CodeAnalysisReference) + .AddMetadataReferences(projectId, Transitive(DiscordCommandsAssembly)); + + int count = 0; + foreach (var source in sources) + { + var newFileName = fileNamePrefix + count + "." + fileExt; + var documentId = DocumentId.CreateNewId(projectId, debugName: newFileName); + solution = solution.AddDocument(documentId, newFileName, SourceText.From(source)); + count++; + } + return solution.GetProject(projectId); + } + #endregion + + /// + /// Get the for and all assemblies referenced by + /// + /// The assembly. + /// s. + private static IEnumerable Transitive(Assembly assembly) + { + foreach (var a in RecursiveReferencedAssemblies(assembly)) + { + yield return MetadataReference.CreateFromFile(a.Location); + } + } + + private static HashSet RecursiveReferencedAssemblies(Assembly a, HashSet assemblies = null) + { + assemblies = assemblies ?? new HashSet(); + if (assemblies.Add(a)) + { + foreach (var referencedAssemblyName in a.GetReferencedAssemblies()) + { + var referencedAssembly = AppDomain.CurrentDomain.GetAssemblies() + .SingleOrDefault(x => x.GetName() == referencedAssemblyName) ?? + Assembly.Load(referencedAssemblyName); + RecursiveReferencedAssemblies(referencedAssembly, assemblies); + } + } + + return assemblies; + } + } +} + diff --git a/test/Discord.Net.Tests/AnalyzerTests/Verifiers/CodeFixVerifier.cs b/test/Discord.Net.Tests/AnalyzerTests/Verifiers/CodeFixVerifier.cs new file mode 100644 index 000000000..5d057b610 --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Verifiers/CodeFixVerifier.cs @@ -0,0 +1,129 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Formatting; +//using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Xunit; + +namespace TestHelper +{ + /// + /// Superclass of all Unit tests made for diagnostics with codefixes. + /// Contains methods used to verify correctness of codefixes + /// + public abstract partial class CodeFixVerifier : DiagnosticVerifier + { + /// + /// Returns the codefix being tested (C#) - to be implemented in non-abstract class + /// + /// The CodeFixProvider to be used for CSharp code + protected virtual CodeFixProvider GetCSharpCodeFixProvider() + { + return null; + } + + /// + /// Returns the codefix being tested (VB) - to be implemented in non-abstract class + /// + /// The CodeFixProvider to be used for VisualBasic code + protected virtual CodeFixProvider GetBasicCodeFixProvider() + { + return null; + } + + /// + /// Called to test a C# codefix when applied on the inputted string as a source + /// + /// A class in the form of a string before the CodeFix was applied to it + /// A class in the form of a string after the CodeFix was applied to it + /// Index determining which codefix to apply if there are multiple + /// A bool controlling whether or not the test will fail if the CodeFix introduces other warnings after being applied + protected void VerifyCSharpFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false) + { + VerifyFix(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics); + } + + /// + /// Called to test a VB codefix when applied on the inputted string as a source + /// + /// A class in the form of a string before the CodeFix was applied to it + /// A class in the form of a string after the CodeFix was applied to it + /// Index determining which codefix to apply if there are multiple + /// A bool controlling whether or not the test will fail if the CodeFix introduces other warnings after being applied + protected void VerifyBasicFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false) + { + VerifyFix(LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), GetBasicCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics); + } + + /// + /// General verifier for codefixes. + /// Creates a Document from the source string, then gets diagnostics on it and applies the relevant codefixes. + /// Then gets the string after the codefix is applied and compares it with the expected result. + /// Note: If any codefix causes new diagnostics to show up, the test fails unless allowNewCompilerDiagnostics is set to true. + /// + /// The language the source code is in + /// The analyzer to be applied to the source code + /// The codefix to be applied to the code wherever the relevant Diagnostic is found + /// A class in the form of a string before the CodeFix was applied to it + /// A class in the form of a string after the CodeFix was applied to it + /// Index determining which codefix to apply if there are multiple + /// A bool controlling whether or not the test will fail if the CodeFix introduces other warnings after being applied + private void VerifyFix(string language, DiagnosticAnalyzer analyzer, CodeFixProvider codeFixProvider, string oldSource, string newSource, int? codeFixIndex, bool allowNewCompilerDiagnostics) + { + var document = CreateDocument(oldSource, language); + var analyzerDiagnostics = GetSortedDiagnosticsFromDocuments(analyzer, new[] { document }); + var compilerDiagnostics = GetCompilerDiagnostics(document); + var attempts = analyzerDiagnostics.Length; + + for (int i = 0; i < attempts; ++i) + { + var actions = new List(); + var context = new CodeFixContext(document, analyzerDiagnostics[0], (a, d) => actions.Add(a), CancellationToken.None); + codeFixProvider.RegisterCodeFixesAsync(context).Wait(); + + if (!actions.Any()) + { + break; + } + + if (codeFixIndex != null) + { + document = ApplyFix(document, actions.ElementAt((int)codeFixIndex)); + break; + } + + document = ApplyFix(document, actions.ElementAt(0)); + analyzerDiagnostics = GetSortedDiagnosticsFromDocuments(analyzer, new[] { document }); + + var newCompilerDiagnostics = GetNewDiagnostics(compilerDiagnostics, GetCompilerDiagnostics(document)); + + //check if applying the code fix introduced any new compiler diagnostics + if (!allowNewCompilerDiagnostics && newCompilerDiagnostics.Any()) + { + // Format and get the compiler diagnostics again so that the locations make sense in the output + document = document.WithSyntaxRoot(Formatter.Format(document.GetSyntaxRootAsync().Result, Formatter.Annotation, document.Project.Solution.Workspace)); + newCompilerDiagnostics = GetNewDiagnostics(compilerDiagnostics, GetCompilerDiagnostics(document)); + + Assert.True(false, + string.Format("Fix introduced new compiler diagnostics:\r\n{0}\r\n\r\nNew document:\r\n{1}\r\n", + string.Join("\r\n", newCompilerDiagnostics.Select(d => d.ToString())), + document.GetSyntaxRootAsync().Result.ToFullString())); + } + + //check if there are analyzer diagnostics left after the code fix + if (!analyzerDiagnostics.Any()) + { + break; + } + } + + //after applying all of the code fixes, compare the resulting string to the inputted one + var actual = GetStringFromDocument(document); + Assert.Equal(newSource, actual); + } + } +} \ No newline at end of file diff --git a/test/Discord.Net.Tests/AnalyzerTests/Verifiers/DiagnosticVerifier.cs b/test/Discord.Net.Tests/AnalyzerTests/Verifiers/DiagnosticVerifier.cs new file mode 100644 index 000000000..498e5ef27 --- /dev/null +++ b/test/Discord.Net.Tests/AnalyzerTests/Verifiers/DiagnosticVerifier.cs @@ -0,0 +1,271 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +//using Microsoft.VisualStudio.TestTools.UnitTesting; +using Xunit; + +namespace TestHelper +{ + /// + /// Superclass of all Unit Tests for DiagnosticAnalyzers + /// + public abstract partial class DiagnosticVerifier + { + #region To be implemented by Test classes + /// + /// Get the CSharp analyzer being tested - to be implemented in non-abstract class + /// + protected virtual DiagnosticAnalyzer GetCSharpDiagnosticAnalyzer() + { + return null; + } + + /// + /// Get the Visual Basic analyzer being tested (C#) - to be implemented in non-abstract class + /// + protected virtual DiagnosticAnalyzer GetBasicDiagnosticAnalyzer() + { + return null; + } + #endregion + + #region Verifier wrappers + + /// + /// Called to test a C# DiagnosticAnalyzer when applied on the single inputted string as a source + /// Note: input a DiagnosticResult for each Diagnostic expected + /// + /// A class in the form of a string to run the analyzer on + /// DiagnosticResults that should appear after the analyzer is run on the source + protected void VerifyCSharpDiagnostic(string source, params DiagnosticResult[] expected) + { + VerifyDiagnostics(new[] { source }, LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), expected); + } + + /// + /// Called to test a VB DiagnosticAnalyzer when applied on the single inputted string as a source + /// Note: input a DiagnosticResult for each Diagnostic expected + /// + /// A class in the form of a string to run the analyzer on + /// DiagnosticResults that should appear after the analyzer is run on the source + protected void VerifyBasicDiagnostic(string source, params DiagnosticResult[] expected) + { + VerifyDiagnostics(new[] { source }, LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), expected); + } + + /// + /// Called to test a C# DiagnosticAnalyzer when applied on the inputted strings as a source + /// Note: input a DiagnosticResult for each Diagnostic expected + /// + /// An array of strings to create source documents from to run the analyzers on + /// DiagnosticResults that should appear after the analyzer is run on the sources + protected void VerifyCSharpDiagnostic(string[] sources, params DiagnosticResult[] expected) + { + VerifyDiagnostics(sources, LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), expected); + } + + /// + /// Called to test a VB DiagnosticAnalyzer when applied on the inputted strings as a source + /// Note: input a DiagnosticResult for each Diagnostic expected + /// + /// An array of strings to create source documents from to run the analyzers on + /// DiagnosticResults that should appear after the analyzer is run on the sources + protected void VerifyBasicDiagnostic(string[] sources, params DiagnosticResult[] expected) + { + VerifyDiagnostics(sources, LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), expected); + } + + /// + /// General method that gets a collection of actual diagnostics found in the source after the analyzer is run, + /// then verifies each of them. + /// + /// An array of strings to create source documents from to run the analyzers on + /// The language of the classes represented by the source strings + /// The analyzer to be run on the source code + /// DiagnosticResults that should appear after the analyzer is run on the sources + private void VerifyDiagnostics(string[] sources, string language, DiagnosticAnalyzer analyzer, params DiagnosticResult[] expected) + { + var diagnostics = GetSortedDiagnostics(sources, language, analyzer); + VerifyDiagnosticResults(diagnostics, analyzer, expected); + } + + #endregion + + #region Actual comparisons and verifications + /// + /// Checks each of the actual Diagnostics found and compares them with the corresponding DiagnosticResult in the array of expected results. + /// Diagnostics are considered equal only if the DiagnosticResultLocation, Id, Severity, and Message of the DiagnosticResult match the actual diagnostic. + /// + /// The Diagnostics found by the compiler after running the analyzer on the source code + /// The analyzer that was being run on the sources + /// Diagnostic Results that should have appeared in the code + private static void VerifyDiagnosticResults(IEnumerable actualResults, DiagnosticAnalyzer analyzer, params DiagnosticResult[] expectedResults) + { + int expectedCount = expectedResults.Count(); + int actualCount = actualResults.Count(); + + if (expectedCount != actualCount) + { + string diagnosticsOutput = actualResults.Any() ? FormatDiagnostics(analyzer, actualResults.ToArray()) : " NONE."; + + Assert.True(false, + string.Format("Mismatch between number of diagnostics returned, expected \"{0}\" actual \"{1}\"\r\n\r\nDiagnostics:\r\n{2}\r\n", expectedCount, actualCount, diagnosticsOutput)); + } + + for (int i = 0; i < expectedResults.Length; i++) + { + var actual = actualResults.ElementAt(i); + var expected = expectedResults[i]; + + if (expected.Line == -1 && expected.Column == -1) + { + if (actual.Location != Location.None) + { + Assert.True(false, + string.Format("Expected:\nA project diagnostic with No location\nActual:\n{0}", + FormatDiagnostics(analyzer, actual))); + } + } + else + { + VerifyDiagnosticLocation(analyzer, actual, actual.Location, expected.Locations.First()); + var additionalLocations = actual.AdditionalLocations.ToArray(); + + if (additionalLocations.Length != expected.Locations.Length - 1) + { + Assert.True(false, + string.Format("Expected {0} additional locations but got {1} for Diagnostic:\r\n {2}\r\n", + expected.Locations.Length - 1, additionalLocations.Length, + FormatDiagnostics(analyzer, actual))); + } + + for (int j = 0; j < additionalLocations.Length; ++j) + { + VerifyDiagnosticLocation(analyzer, actual, additionalLocations[j], expected.Locations[j + 1]); + } + } + + if (actual.Id != expected.Id) + { + Assert.True(false, + string.Format("Expected diagnostic id to be \"{0}\" was \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Id, actual.Id, FormatDiagnostics(analyzer, actual))); + } + + if (actual.Severity != expected.Severity) + { + Assert.True(false, + string.Format("Expected diagnostic severity to be \"{0}\" was \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Severity, actual.Severity, FormatDiagnostics(analyzer, actual))); + } + + if (actual.GetMessage() != expected.Message) + { + Assert.True(false, + string.Format("Expected diagnostic message to be \"{0}\" was \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Message, actual.GetMessage(), FormatDiagnostics(analyzer, actual))); + } + } + } + + /// + /// Helper method to VerifyDiagnosticResult that checks the location of a diagnostic and compares it with the location in the expected DiagnosticResult. + /// + /// The analyzer that was being run on the sources + /// The diagnostic that was found in the code + /// The Location of the Diagnostic found in the code + /// The DiagnosticResultLocation that should have been found + private static void VerifyDiagnosticLocation(DiagnosticAnalyzer analyzer, Diagnostic diagnostic, Location actual, DiagnosticResultLocation expected) + { + var actualSpan = actual.GetLineSpan(); + + Assert.True(actualSpan.Path == expected.Path || (actualSpan.Path != null && actualSpan.Path.Contains("Test0.") && expected.Path.Contains("Test.")), + string.Format("Expected diagnostic to be in file \"{0}\" was actually in file \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Path, actualSpan.Path, FormatDiagnostics(analyzer, diagnostic))); + + var actualLinePosition = actualSpan.StartLinePosition; + + // Only check line position if there is an actual line in the real diagnostic + if (actualLinePosition.Line > 0) + { + if (actualLinePosition.Line + 1 != expected.Line) + { + Assert.True(false, + string.Format("Expected diagnostic to be on line \"{0}\" was actually on line \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Line, actualLinePosition.Line + 1, FormatDiagnostics(analyzer, diagnostic))); + } + } + + // Only check column position if there is an actual column position in the real diagnostic + if (actualLinePosition.Character > 0) + { + if (actualLinePosition.Character + 1 != expected.Column) + { + Assert.True(false, + string.Format("Expected diagnostic to start at column \"{0}\" was actually at column \"{1}\"\r\n\r\nDiagnostic:\r\n {2}\r\n", + expected.Column, actualLinePosition.Character + 1, FormatDiagnostics(analyzer, diagnostic))); + } + } + } + #endregion + + #region Formatting Diagnostics + /// + /// Helper method to format a Diagnostic into an easily readable string + /// + /// The analyzer that this verifier tests + /// The Diagnostics to be formatted + /// The Diagnostics formatted as a string + private static string FormatDiagnostics(DiagnosticAnalyzer analyzer, params Diagnostic[] diagnostics) + { + var builder = new StringBuilder(); + for (int i = 0; i < diagnostics.Length; ++i) + { + builder.AppendLine("// " + diagnostics[i].ToString()); + + var analyzerType = analyzer.GetType(); + var rules = analyzer.SupportedDiagnostics; + + foreach (var rule in rules) + { + if (rule != null && rule.Id == diagnostics[i].Id) + { + var location = diagnostics[i].Location; + if (location == Location.None) + { + builder.AppendFormat("GetGlobalResult({0}.{1})", analyzerType.Name, rule.Id); + } + else + { + Assert.True(location.IsInSource, + $"Test base does not currently handle diagnostics in metadata locations. Diagnostic in metadata: {diagnostics[i]}\r\n"); + + string resultMethodName = diagnostics[i].Location.SourceTree.FilePath.EndsWith(".cs") ? "GetCSharpResultAt" : "GetBasicResultAt"; + var linePosition = diagnostics[i].Location.GetLineSpan().StartLinePosition; + + builder.AppendFormat("{0}({1}, {2}, {3}.{4})", + resultMethodName, + linePosition.Line + 1, + linePosition.Character + 1, + analyzerType.Name, + rule.Id); + } + + if (i != diagnostics.Length - 1) + { + builder.Append(','); + } + + builder.AppendLine(); + break; + } + } + } + return builder.ToString(); + } + #endregion + } +} diff --git a/test/Discord.Net.Tests/Discord.Net.Tests.csproj b/test/Discord.Net.Tests/Discord.Net.Tests.csproj index 9e734641c..204dca5c4 100644 --- a/test/Discord.Net.Tests/Discord.Net.Tests.csproj +++ b/test/Discord.Net.Tests/Discord.Net.Tests.csproj @@ -10,10 +10,14 @@ PreserveNewest + + + + diff --git a/test/Discord.Net.Tests/Tests.ChannelPermissions.cs b/test/Discord.Net.Tests/Tests.ChannelPermissions.cs index ac8ede4e4..b37a1195e 100644 --- a/test/Discord.Net.Tests/Tests.ChannelPermissions.cs +++ b/test/Discord.Net.Tests/Tests.ChannelPermissions.cs @@ -1,10 +1,10 @@ -using System; +using System; using System.Threading.Tasks; using Xunit; namespace Discord { - public partial class Tests + public class ChannelPermissionsTests { [Fact] public Task TestChannelPermission() diff --git a/test/Discord.Net.Tests/Tests.GuildPermissions.cs b/test/Discord.Net.Tests/Tests.GuildPermissions.cs index bb113d221..a562f4afb 100644 --- a/test/Discord.Net.Tests/Tests.GuildPermissions.cs +++ b/test/Discord.Net.Tests/Tests.GuildPermissions.cs @@ -1,10 +1,10 @@ -using System; +using System; using System.Threading.Tasks; using Xunit; namespace Discord { - public partial class Tests + public class GuidPermissionsTests { [Fact] public Task TestGuildPermission() diff --git a/test/Discord.Net.Tests/Tests.Permissions.cs b/test/Discord.Net.Tests/Tests.Permissions.cs new file mode 100644 index 000000000..e22659d15 --- /dev/null +++ b/test/Discord.Net.Tests/Tests.Permissions.cs @@ -0,0 +1,706 @@ +using System; +using System.Threading.Tasks; +using Xunit; + +namespace Discord +{ + public class PermissionsTests + { + private void TestHelper(ChannelPermissions value, ChannelPermission permission, bool expected = false) + => TestHelper(value.RawValue, (ulong)permission, expected); + + private void TestHelper(GuildPermissions value, GuildPermission permission, bool expected = false) + => TestHelper(value.RawValue, (ulong)permission, expected); + + /// + /// Tests the flag of the given permissions value to the expected output + /// and then tries to toggle the flag on and off + /// + /// + /// + /// + private void TestHelper(ulong rawValue, ulong flagValue, bool expected) + { + Assert.Equal(expected, Permissions.GetValue(rawValue, flagValue)); + + // check that toggling the bit works + Permissions.UnsetFlag(ref rawValue, flagValue); + Assert.Equal(false, Permissions.GetValue(rawValue, flagValue)); + Permissions.SetFlag(ref rawValue, flagValue); + Assert.Equal(true, Permissions.GetValue(rawValue, flagValue)); + + // do the same, but with the SetValue method + Permissions.SetValue(ref rawValue, true, flagValue); + Assert.Equal(true, Permissions.GetValue(rawValue, flagValue)); + Permissions.SetValue(ref rawValue, false, flagValue); + Assert.Equal(false, Permissions.GetValue(rawValue, flagValue)); + } + + /// + /// Tests that flag of the given permissions value to be the expected output + /// and then tries cycling through the states of the allow and deny values + /// for that flag + /// + /// + /// + /// + private void TestHelper(OverwritePermissions value, ChannelPermission flag, PermValue expected) + { + // check that the value matches + Assert.Equal(expected, Permissions.GetValue(value.AllowValue, value.DenyValue, flag)); + + // check toggling bits for both allow and deny + // have to make copies to get around read only property + ulong allow = value.AllowValue; + ulong deny = value.DenyValue; + + // both unset should be inherit + Permissions.UnsetFlag(ref allow, (ulong)flag); + Permissions.UnsetFlag(ref deny, (ulong)flag); + Assert.Equal(PermValue.Inherit, Permissions.GetValue(allow, deny, flag)); + + // allow set should be allow + Permissions.SetFlag(ref allow, (ulong)flag); + Permissions.UnsetFlag(ref deny, (ulong)flag); + Assert.Equal(PermValue.Allow, Permissions.GetValue(allow, deny, flag)); + + // deny should be deny + Permissions.UnsetFlag(ref allow, (ulong)flag); + Permissions.SetFlag(ref deny, (ulong)flag); + Assert.Equal(PermValue.Deny, Permissions.GetValue(allow, deny, flag)); + + // allow takes precedence + Permissions.SetFlag(ref allow, (ulong)flag); + Permissions.SetFlag(ref deny, (ulong)flag); + Assert.Equal(PermValue.Allow, Permissions.GetValue(allow, deny, flag)); + } + + /// + /// Tests for the class. + /// + /// Tests that text channel permissions get the right value + /// from the Has method. + /// + /// + [Fact] + public Task TestPermissionsHasChannelPermissionText() + { + var value = ChannelPermissions.Text; + // check that the result of GetValue matches for all properties of text channel + TestHelper(value, ChannelPermission.CreateInstantInvite, true); + TestHelper(value, ChannelPermission.ManageChannels, true); + TestHelper(value, ChannelPermission.AddReactions, true); + TestHelper(value, ChannelPermission.ViewChannel, true); + TestHelper(value, ChannelPermission.SendMessages, true); + TestHelper(value, ChannelPermission.SendTTSMessages, true); + TestHelper(value, ChannelPermission.ManageMessages, true); + TestHelper(value, ChannelPermission.EmbedLinks, true); + TestHelper(value, ChannelPermission.AttachFiles, true); + TestHelper(value, ChannelPermission.ReadMessageHistory, true); + TestHelper(value, ChannelPermission.MentionEveryone, true); + TestHelper(value, ChannelPermission.UseExternalEmojis, true); + TestHelper(value, ChannelPermission.ManageRoles, true); + TestHelper(value, ChannelPermission.ManageWebhooks, true); + + TestHelper(value, ChannelPermission.Connect, false); + TestHelper(value, ChannelPermission.Speak, false); + TestHelper(value, ChannelPermission.MuteMembers, false); + TestHelper(value, ChannelPermission.DeafenMembers, false); + TestHelper(value, ChannelPermission.MoveMembers, false); + TestHelper(value, ChannelPermission.UseVAD, false); + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Tests that no channel permissions get the right value + /// from the Has method. + /// + /// + [Fact] + public Task TestPermissionsHasChannelPermissionNone() + { + // check that none will fail all + var value = ChannelPermissions.None; + + TestHelper(value, ChannelPermission.CreateInstantInvite, false); + TestHelper(value, ChannelPermission.ManageChannels, false); + TestHelper(value, ChannelPermission.AddReactions, false); + TestHelper(value, ChannelPermission.ViewChannel, false); + TestHelper(value, ChannelPermission.SendMessages, false); + TestHelper(value, ChannelPermission.SendTTSMessages, false); + TestHelper(value, ChannelPermission.ManageMessages, false); + TestHelper(value, ChannelPermission.EmbedLinks, false); + TestHelper(value, ChannelPermission.AttachFiles, false); + TestHelper(value, ChannelPermission.ReadMessageHistory, false); + TestHelper(value, ChannelPermission.MentionEveryone, false); + TestHelper(value, ChannelPermission.UseExternalEmojis, false); + TestHelper(value, ChannelPermission.ManageRoles, false); + TestHelper(value, ChannelPermission.ManageWebhooks, false); + TestHelper(value, ChannelPermission.Connect, false); + TestHelper(value, ChannelPermission.Speak, false); + TestHelper(value, ChannelPermission.MuteMembers, false); + TestHelper(value, ChannelPermission.DeafenMembers, false); + TestHelper(value, ChannelPermission.MoveMembers, false); + TestHelper(value, ChannelPermission.UseVAD, false); + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Tests that the dm channel permissions get the right value + /// from the Has method. + /// + /// + [Fact] + public Task TestPermissionsHasChannelPermissionDM() + { + // check that none will fail all + var value = ChannelPermissions.DM; + + TestHelper(value, ChannelPermission.CreateInstantInvite, false); + TestHelper(value, ChannelPermission.ManageChannels, false); + TestHelper(value, ChannelPermission.AddReactions, false); + TestHelper(value, ChannelPermission.ViewChannel, true); + TestHelper(value, ChannelPermission.SendMessages, true); + TestHelper(value, ChannelPermission.SendTTSMessages, false); + TestHelper(value, ChannelPermission.ManageMessages, false); + TestHelper(value, ChannelPermission.EmbedLinks, true); + TestHelper(value, ChannelPermission.AttachFiles, true); + TestHelper(value, ChannelPermission.ReadMessageHistory, true); + TestHelper(value, ChannelPermission.MentionEveryone, false); + TestHelper(value, ChannelPermission.UseExternalEmojis, true); + TestHelper(value, ChannelPermission.ManageRoles, false); + TestHelper(value, ChannelPermission.ManageWebhooks, false); + TestHelper(value, ChannelPermission.Connect, true); + TestHelper(value, ChannelPermission.Speak, true); + TestHelper(value, ChannelPermission.MuteMembers, false); + TestHelper(value, ChannelPermission.DeafenMembers, false); + TestHelper(value, ChannelPermission.MoveMembers, false); + TestHelper(value, ChannelPermission.UseVAD, true); + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Tests that the group channel permissions get the right value + /// from the Has method. + /// + /// + [Fact] + public Task TestPermissionsHasChannelPermissionGroup() + { + var value = ChannelPermissions.Group; + + TestHelper(value, ChannelPermission.CreateInstantInvite, false); + TestHelper(value, ChannelPermission.ManageChannels, false); + TestHelper(value, ChannelPermission.AddReactions, false); + TestHelper(value, ChannelPermission.ViewChannel, false); + TestHelper(value, ChannelPermission.SendMessages, true); + TestHelper(value, ChannelPermission.SendTTSMessages, true); + TestHelper(value, ChannelPermission.ManageMessages, false); + TestHelper(value, ChannelPermission.EmbedLinks, true); + TestHelper(value, ChannelPermission.AttachFiles, true); + TestHelper(value, ChannelPermission.ReadMessageHistory, false); + TestHelper(value, ChannelPermission.MentionEveryone, false); + TestHelper(value, ChannelPermission.UseExternalEmojis, false); + TestHelper(value, ChannelPermission.ManageRoles, false); + TestHelper(value, ChannelPermission.ManageWebhooks, false); + TestHelper(value, ChannelPermission.Connect, true); + TestHelper(value, ChannelPermission.Speak, true); + TestHelper(value, ChannelPermission.MuteMembers, false); + TestHelper(value, ChannelPermission.DeafenMembers, false); + TestHelper(value, ChannelPermission.MoveMembers, false); + TestHelper(value, ChannelPermission.UseVAD, true); + + return Task.CompletedTask; + } + + + /// + /// Tests for the class. + /// + /// Tests that the voice channel permissions get the right value + /// from the Has method. + /// + /// + [Fact] + public Task TestPermissionsHasChannelPermissionVoice() + { + // make a flag with all possible values for Voice channel permissions + var value = ChannelPermissions.Voice; + + TestHelper(value, ChannelPermission.CreateInstantInvite, true); + TestHelper(value, ChannelPermission.ManageChannels, true); + TestHelper(value, ChannelPermission.AddReactions, false); + TestHelper(value, ChannelPermission.ViewChannel, false); + TestHelper(value, ChannelPermission.SendMessages, false); + TestHelper(value, ChannelPermission.SendTTSMessages, false); + TestHelper(value, ChannelPermission.ManageMessages, false); + TestHelper(value, ChannelPermission.EmbedLinks, false); + TestHelper(value, ChannelPermission.AttachFiles, false); + TestHelper(value, ChannelPermission.ReadMessageHistory, false); + TestHelper(value, ChannelPermission.MentionEveryone, false); + TestHelper(value, ChannelPermission.UseExternalEmojis, false); + TestHelper(value, ChannelPermission.ManageRoles, true); + TestHelper(value, ChannelPermission.ManageWebhooks, false); + TestHelper(value, ChannelPermission.Connect, true); + TestHelper(value, ChannelPermission.Speak, true); + TestHelper(value, ChannelPermission.MuteMembers, true); + TestHelper(value, ChannelPermission.DeafenMembers, true); + TestHelper(value, ChannelPermission.MoveMembers, true); + TestHelper(value, ChannelPermission.UseVAD, true); + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Test that that the Has method of + /// returns the correct value when no permissions are set. + /// + /// + [Fact] + public Task TestPermissionsHasGuildPermissionNone() + { + var value = GuildPermissions.None; + + TestHelper(value, GuildPermission.CreateInstantInvite, false); + TestHelper(value, GuildPermission.KickMembers, false); + TestHelper(value, GuildPermission.BanMembers, false); + TestHelper(value, GuildPermission.Administrator, false); + TestHelper(value, GuildPermission.ManageChannels, false); + TestHelper(value, GuildPermission.ManageGuild, false); + TestHelper(value, GuildPermission.AddReactions, false); + TestHelper(value, GuildPermission.ViewAuditLog, false); + TestHelper(value, GuildPermission.ReadMessages, false); + TestHelper(value, GuildPermission.SendMessages, false); + TestHelper(value, GuildPermission.SendTTSMessages, false); + TestHelper(value, GuildPermission.ManageMessages, false); + TestHelper(value, GuildPermission.EmbedLinks, false); + TestHelper(value, GuildPermission.AttachFiles, false); + TestHelper(value, GuildPermission.ReadMessageHistory, false); + TestHelper(value, GuildPermission.MentionEveryone, false); + TestHelper(value, GuildPermission.UseExternalEmojis, false); + TestHelper(value, GuildPermission.Connect, false); + TestHelper(value, GuildPermission.Speak, false); + TestHelper(value, GuildPermission.MuteMembers, false); + TestHelper(value, GuildPermission.MoveMembers, false); + TestHelper(value, GuildPermission.UseVAD, false); + TestHelper(value, GuildPermission.ChangeNickname, false); + TestHelper(value, GuildPermission.ManageNicknames, false); + TestHelper(value, GuildPermission.ManageRoles, false); + TestHelper(value, GuildPermission.ManageWebhooks, false); + TestHelper(value, GuildPermission.ManageEmojis, false); + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Test that that the Has method of + /// returns the correct value when all permissions are set. + /// + /// + [Fact] + public Task TestPermissionsHasGuildPermissionAll() + { + var value = GuildPermissions.All; + + TestHelper(value, GuildPermission.CreateInstantInvite, true); + TestHelper(value, GuildPermission.KickMembers, true); + TestHelper(value, GuildPermission.BanMembers, true); + TestHelper(value, GuildPermission.Administrator, true); + TestHelper(value, GuildPermission.ManageChannels, true); + TestHelper(value, GuildPermission.ManageGuild, true); + TestHelper(value, GuildPermission.AddReactions, true); + TestHelper(value, GuildPermission.ViewAuditLog, true); + TestHelper(value, GuildPermission.ReadMessages, true); + TestHelper(value, GuildPermission.SendMessages, true); + TestHelper(value, GuildPermission.SendTTSMessages, true); + TestHelper(value, GuildPermission.ManageMessages, true); + TestHelper(value, GuildPermission.EmbedLinks, true); + TestHelper(value, GuildPermission.AttachFiles, true); + TestHelper(value, GuildPermission.ReadMessageHistory, true); + TestHelper(value, GuildPermission.MentionEveryone, true); + TestHelper(value, GuildPermission.UseExternalEmojis, true); + TestHelper(value, GuildPermission.Connect, true); + TestHelper(value, GuildPermission.Speak, true); + TestHelper(value, GuildPermission.MuteMembers, true); + TestHelper(value, GuildPermission.MoveMembers, true); + TestHelper(value, GuildPermission.UseVAD, true); + TestHelper(value, GuildPermission.ChangeNickname, true); + TestHelper(value, GuildPermission.ManageNicknames, true); + TestHelper(value, GuildPermission.ManageRoles, true); + TestHelper(value, GuildPermission.ManageWebhooks, true); + TestHelper(value, GuildPermission.ManageEmojis, true); + + + return Task.CompletedTask; + } + + /// + /// Tests for the class. + /// + /// Test that that the Has method of + /// returns the correct value when webhook permissions are set. + /// + /// + [Fact] + public Task TestPermissionsHasGuildPermissionWebhook() + { + var value = GuildPermissions.Webhook; + + TestHelper(value, GuildPermission.CreateInstantInvite, false); + TestHelper(value, GuildPermission.KickMembers, false); + TestHelper(value, GuildPermission.BanMembers, false); + TestHelper(value, GuildPermission.Administrator, false); + TestHelper(value, GuildPermission.ManageChannels, false); + TestHelper(value, GuildPermission.ManageGuild, false); + TestHelper(value, GuildPermission.AddReactions, false); + TestHelper(value, GuildPermission.ViewAuditLog, false); + TestHelper(value, GuildPermission.ReadMessages, false); + TestHelper(value, GuildPermission.SendMessages, true); + TestHelper(value, GuildPermission.SendTTSMessages, true); + TestHelper(value, GuildPermission.ManageMessages, false); + TestHelper(value, GuildPermission.EmbedLinks, true); + TestHelper(value, GuildPermission.AttachFiles, true); + TestHelper(value, GuildPermission.ReadMessageHistory, false); + TestHelper(value, GuildPermission.MentionEveryone, false); + TestHelper(value, GuildPermission.UseExternalEmojis, false); + TestHelper(value, GuildPermission.Connect, false); + TestHelper(value, GuildPermission.Speak, false); + TestHelper(value, GuildPermission.MuteMembers, false); + TestHelper(value, GuildPermission.MoveMembers, false); + TestHelper(value, GuildPermission.UseVAD, false); + TestHelper(value, GuildPermission.ChangeNickname, false); + TestHelper(value, GuildPermission.ManageNicknames, false); + TestHelper(value, GuildPermission.ManageRoles, false); + TestHelper(value, GuildPermission.ManageWebhooks, false); + TestHelper(value, GuildPermission.ManageEmojis, false); + + return Task.CompletedTask; + } + + /// + /// Test + /// for when all text permissions are allowed and denied + /// + /// + [Fact] + public Task TestOverwritePermissionsText() + { + // allow all for text channel + var value = new OverwritePermissions(ChannelPermissions.Text.RawValue, ChannelPermissions.None.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Allow); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Allow); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Allow); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Allow); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Allow); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Allow); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Allow); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Allow); + TestHelper(value, ChannelPermission.Connect, PermValue.Inherit); + TestHelper(value, ChannelPermission.Speak, PermValue.Inherit); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Inherit); + + value = new OverwritePermissions(ChannelPermissions.None.RawValue, ChannelPermissions.Text.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Deny); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Deny); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Deny); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Deny); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Deny); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Deny); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Deny); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Deny); + TestHelper(value, ChannelPermission.Connect, PermValue.Inherit); + TestHelper(value, ChannelPermission.Speak, PermValue.Inherit); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Inherit); + + return Task.CompletedTask; + } + + /// + /// Test + /// for when none of the permissions are set. + /// + /// + [Fact] + public Task TestOverwritePermissionsNone() + { + // allow all for text channel + var value = new OverwritePermissions(ChannelPermissions.None.RawValue, ChannelPermissions.None.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Inherit); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Inherit); + TestHelper(value, ChannelPermission.Speak, PermValue.Inherit); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Inherit); + + value = new OverwritePermissions(); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Inherit); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Inherit); + TestHelper(value, ChannelPermission.Speak, PermValue.Inherit); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Inherit); + + value = OverwritePermissions.InheritAll; + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Inherit); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Inherit); + TestHelper(value, ChannelPermission.Speak, PermValue.Inherit); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Inherit); + + return Task.CompletedTask; + } + + /// + /// Test + /// for when all dm permissions are allowed and denied + /// + /// + [Fact] + public Task TestOverwritePermissionsDM() + { + // allow all for text channel + var value = new OverwritePermissions(ChannelPermissions.DM.RawValue, ChannelPermissions.None.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Allow); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Allow); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Allow); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Allow); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Allow); + TestHelper(value, ChannelPermission.Speak, PermValue.Allow); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Allow); + + value = new OverwritePermissions(ChannelPermissions.None.RawValue, ChannelPermissions.DM.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Deny); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Deny); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Deny); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Deny); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Deny); + TestHelper(value, ChannelPermission.Speak, PermValue.Deny); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Deny); + + return Task.CompletedTask; + } + + /// + /// Test + /// for when all group permissions are allowed and denied + /// + /// + [Fact] + public Task TestOverwritePermissionsGroup() + { + // allow all for group channels + var value = new OverwritePermissions(ChannelPermissions.Group.RawValue, ChannelPermissions.None.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Allow); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Allow); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Allow); + TestHelper(value, ChannelPermission.Speak, PermValue.Allow); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Allow); + + value = new OverwritePermissions(ChannelPermissions.None.RawValue, ChannelPermissions.Group.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Inherit); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Deny); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Deny); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Deny); + TestHelper(value, ChannelPermission.Speak, PermValue.Deny); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Deny); + + return Task.CompletedTask; + } + + /// + /// Test + /// for when all group permissions are allowed and denied + /// + /// + [Fact] + public Task TestOverwritePermissionsVoice() + { + // allow all for group channels + var value = new OverwritePermissions(ChannelPermissions.Voice.RawValue, ChannelPermissions.None.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Allow); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Inherit); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Allow); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Allow); + TestHelper(value, ChannelPermission.Speak, PermValue.Allow); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Allow); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Allow); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Allow); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Allow); + + value = new OverwritePermissions(ChannelPermissions.None.RawValue, ChannelPermissions.Voice.RawValue); + + TestHelper(value, ChannelPermission.CreateInstantInvite, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageChannels, PermValue.Deny); + TestHelper(value, ChannelPermission.AddReactions, PermValue.Inherit); + TestHelper(value, ChannelPermission.ViewChannel, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.SendTTSMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageMessages, PermValue.Inherit); + TestHelper(value, ChannelPermission.EmbedLinks, PermValue.Inherit); + TestHelper(value, ChannelPermission.AttachFiles, PermValue.Inherit); + TestHelper(value, ChannelPermission.ReadMessageHistory, PermValue.Inherit); + TestHelper(value, ChannelPermission.MentionEveryone, PermValue.Inherit); + TestHelper(value, ChannelPermission.UseExternalEmojis, PermValue.Inherit); + TestHelper(value, ChannelPermission.ManageRoles, PermValue.Deny); + TestHelper(value, ChannelPermission.ManageWebhooks, PermValue.Inherit); + TestHelper(value, ChannelPermission.Connect, PermValue.Deny); + TestHelper(value, ChannelPermission.Speak, PermValue.Deny); + TestHelper(value, ChannelPermission.MuteMembers, PermValue.Deny); + TestHelper(value, ChannelPermission.DeafenMembers, PermValue.Deny); + TestHelper(value, ChannelPermission.MoveMembers, PermValue.Deny); + TestHelper(value, ChannelPermission.UseVAD, PermValue.Deny); + + return Task.CompletedTask; + } + } +}