diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 2df1d805f..be3449b6f 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -197,22 +197,34 @@ namespace Discord.Commands var createInstance = ReflectionUtils.CreateBuilder(typeInfo, service); - builder.Callback = async (ctx, args, map, cmd) => + async Task ExecuteCallback(ICommandContext context, object[] args, IServiceProvider services, CommandInfo cmd) { - var instance = createInstance(map); - instance.SetContext(ctx); + var instance = createInstance(services); + instance.SetContext(context); + try { instance.BeforeExecute(cmd); + var task = method.Invoke(instance, args) as Task ?? Task.Delay(0); - await task.ConfigureAwait(false); + if (task is Task resultTask) + { + return await resultTask.ConfigureAwait(false); + } + else + { + await task.ConfigureAwait(false); + return ExecuteResult.FromSuccess(); + } } finally { instance.AfterExecute(cmd); (instance as IDisposable)?.Dispose(); } - }; + } + + builder.Callback = ExecuteCallback; } private static void BuildParameter(ParameterBuilder builder, System.Reflection.ParameterInfo paramInfo, int position, int count, CommandService service) @@ -293,7 +305,7 @@ namespace Discord.Commands private static bool IsValidCommandDefinition(MethodInfo methodInfo) { return methodInfo.IsDefined(typeof(CommandAttribute)) && - (methodInfo.ReturnType == typeof(Task) || methodInfo.ReturnType == typeof(void)) && + (methodInfo.ReturnType == typeof(Task) || methodInfo.ReturnType == typeof(Task)) && !methodInfo.IsStatic && !methodInfo.IsGenericMethod; } diff --git a/src/Discord.Net.Commands/CommandError.cs b/src/Discord.Net.Commands/CommandError.cs index 41b4822ad..abfc14e1d 100644 --- a/src/Discord.Net.Commands/CommandError.cs +++ b/src/Discord.Net.Commands/CommandError.cs @@ -18,6 +18,9 @@ UnmetPrecondition, //Execute - Exception + Exception, + + //Runtime + Unsuccessful } } diff --git a/src/Discord.Net.Commands/CommandMatch.cs b/src/Discord.Net.Commands/CommandMatch.cs index d2bd9ef03..d922a2229 100644 --- a/src/Discord.Net.Commands/CommandMatch.cs +++ b/src/Discord.Net.Commands/CommandMatch.cs @@ -20,9 +20,9 @@ namespace Discord.Commands => Command.CheckPreconditionsAsync(context, services); public Task ParseAsync(ICommandContext context, SearchResult searchResult, PreconditionResult preconditionResult = null, IServiceProvider services = null) => Command.ParseAsync(context, Alias.Length, searchResult, preconditionResult, services); - public Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) + public Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) => Command.ExecuteAsync(context, argList, paramList, services); - public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) + public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) => Command.ExecuteAsync(context, parseResult, services); } } diff --git a/src/Discord.Net.Commands/Info/CommandInfo.cs b/src/Discord.Net.Commands/Info/CommandInfo.cs index 00041f22d..60df2a6a9 100644 --- a/src/Discord.Net.Commands/Info/CommandInfo.cs +++ b/src/Discord.Net.Commands/Info/CommandInfo.cs @@ -36,14 +36,14 @@ namespace Discord.Commands internal CommandInfo(CommandBuilder builder, ModuleInfo module, CommandService service) { Module = module; - + Name = builder.Name; Summary = builder.Summary; Remarks = builder.Remarks; RunMode = (builder.RunMode == RunMode.Default ? service._defaultRunMode : builder.RunMode); Priority = builder.Priority; - + Aliases = module.Aliases .Permutate(builder.Aliases, (first, second) => { @@ -106,7 +106,7 @@ namespace Discord.Commands return PreconditionResult.FromSuccess(); } - + public async Task ParseAsync(ICommandContext context, int startIndex, SearchResult searchResult, PreconditionResult preconditionResult = null, IServiceProvider services = null) { services = services ?? EmptyServiceProvider.Instance; @@ -115,35 +115,35 @@ namespace Discord.Commands return ParseResult.FromError(searchResult); if (preconditionResult != null && !preconditionResult.IsSuccess) return ParseResult.FromError(preconditionResult); - + string input = searchResult.Text.Substring(startIndex); return await CommandParser.ParseArgs(this, context, services, input, 0).ConfigureAwait(false); } - public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) + public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) { if (!parseResult.IsSuccess) - return Task.FromResult(ExecuteResult.FromError(parseResult)); + return Task.FromResult((IResult)ExecuteResult.FromError(parseResult)); var argList = new object[parseResult.ArgValues.Count]; for (int i = 0; i < parseResult.ArgValues.Count; i++) { if (!parseResult.ArgValues[i].IsSuccess) - return Task.FromResult(ExecuteResult.FromError(parseResult.ArgValues[i])); + return Task.FromResult((IResult)ExecuteResult.FromError(parseResult.ArgValues[i])); argList[i] = parseResult.ArgValues[i].Values.First().Value; } - + var paramList = new object[parseResult.ParamValues.Count]; for (int i = 0; i < parseResult.ParamValues.Count; i++) { if (!parseResult.ParamValues[i].IsSuccess) - return Task.FromResult(ExecuteResult.FromError(parseResult.ParamValues[i])); + return Task.FromResult((IResult)ExecuteResult.FromError(parseResult.ParamValues[i])); paramList[i] = parseResult.ParamValues[i].Values.First().Value; } return ExecuteAsync(context, argList, paramList, services); } - public async Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) + public async Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) { services = services ?? EmptyServiceProvider.Instance; @@ -163,10 +163,9 @@ namespace Discord.Commands switch (RunMode) { case RunMode.Sync: //Always sync - await ExecuteAsyncInternal(context, args, services).ConfigureAwait(false); - break; + return await ExecuteAsyncInternal(context, args, services).ConfigureAwait(false); case RunMode.Async: //Always async - var t2 = Task.Run(async () => + var t2 = Task.Run(async () => { await ExecuteAsyncInternal(context, args, services).ConfigureAwait(false); }); @@ -180,12 +179,26 @@ namespace Discord.Commands } } - private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IServiceProvider services) + private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IServiceProvider services) { await Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false); try { - await _action(context, args, services, this).ConfigureAwait(false); + var task = _action(context, args, services, this); + if (task is Task resultTask) + { + var result = await resultTask.ConfigureAwait(false); + if (result is RuntimeResult execResult) + return execResult; + } + else if (task is Task execTask) + { + return await execTask.ConfigureAwait(false); + } + else + await task.ConfigureAwait(false); + + return ExecuteResult.FromSuccess(); } catch (Exception ex) { @@ -202,8 +215,13 @@ namespace Discord.Commands else ExceptionDispatchInfo.Capture(ex).Throw(); } + + return ExecuteResult.FromError(CommandError.Exception, ex.Message); + } + finally + { + await Module.Service._cmdLogger.VerboseAsync($"Executed {GetLogText(context)}").ConfigureAwait(false); } - await Module.Service._cmdLogger.VerboseAsync($"Executed {GetLogText(context)}").ConfigureAwait(false); } private object[] GenerateArgs(IEnumerable argList, IEnumerable paramsList) @@ -240,7 +258,7 @@ namespace Discord.Commands => paramsList.Cast().ToArray(); internal string GetLogText(ICommandContext context) - { + { if (context.Guild != null) return $"\"{Name}\" for {context.User} in {context.Guild}/{context.Channel}"; else diff --git a/src/Discord.Net.Commands/Results/RuntimeResult.cs b/src/Discord.Net.Commands/Results/RuntimeResult.cs new file mode 100644 index 000000000..2a326a7a3 --- /dev/null +++ b/src/Discord.Net.Commands/Results/RuntimeResult.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Discord.Commands +{ + [DebuggerDisplay(@"{DebuggerDisplay,nq}")] + public abstract class RuntimeResult : IResult + { + protected RuntimeResult(CommandError? error, string reason) + { + Error = error; + Reason = reason; + } + + public CommandError? Error { get; } + public string Reason { get; } + + public bool IsSuccess => !Error.HasValue; + + string IResult.ErrorReason => Reason; + + public override string ToString() => Reason ?? (IsSuccess ? "Successful" : "Unsuccessful"); + private string DebuggerDisplay => IsSuccess ? $"Success: {Reason ?? "No Reason"}" : $"{Error}: {Reason}"; + } +}