@@ -8,6 +8,7 @@ using System.Linq;
using System.Reflection;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Discord.Commands
namespace Discord.Commands
{
{
@@ -17,7 +18,7 @@ namespace Discord.Commands
private static readonly System.Reflection.MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly System.Reflection.MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>();
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>();
private readonly Func<ICommandContext, object[], IDependencyMap , Task> _action;
private readonly Func<ICommandContext, object[], IServiceProvider , Task> _action;
public ModuleInfo Module { get; }
public ModuleInfo Module { get; }
public string Name { get; }
public string Name { get; }
@@ -63,21 +64,20 @@ namespace Discord.Commands
_action = builder.Callback;
_action = builder.Callback;
}
}
public async Task<PreconditionResult> CheckPreconditionsAsync(ICommandContext context, IDependencyMap map = null)
public async Task<PreconditionResult> CheckPreconditionsAsync(ICommandContext context, IServiceProvider services = null)
{
{
if (map == null)
map = DependencyMap.Empty;
services = services ?? EmptyServiceProvider.Instance;
foreach (PreconditionAttribute precondition in Module.Preconditions)
foreach (PreconditionAttribute precondition in Module.Preconditions)
{
{
var result = await precondition.CheckPermissions(context, this, map ).ConfigureAwait(false);
var result = await precondition.CheckPermissions(context, this, services ).ConfigureAwait(false);
if (!result.IsSuccess)
if (!result.IsSuccess)
return result;
return result;
}
}
foreach (PreconditionAttribute precondition in Preconditions)
foreach (PreconditionAttribute precondition in Preconditions)
{
{
var result = await precondition.CheckPermissions(context, this, map ).ConfigureAwait(false);
var result = await precondition.CheckPermissions(context, this, services ).ConfigureAwait(false);
if (!result.IsSuccess)
if (!result.IsSuccess)
return result;
return result;
}
}
@@ -96,7 +96,7 @@ namespace Discord.Commands
return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false);
return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false);
}
}
public Task<ExecuteResult> ExecuteAsync(ICommandContext context, ParseResult parseResult, IDependencyMap map )
public Task<ExecuteResult> ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services )
{
{
if (!parseResult.IsSuccess)
if (!parseResult.IsSuccess)
return Task.FromResult(ExecuteResult.FromError(parseResult));
return Task.FromResult(ExecuteResult.FromError(parseResult));
@@ -117,12 +117,11 @@ namespace Discord.Commands
paramList[i] = parseResult.ParamValues[i].Values.First().Value;
paramList[i] = parseResult.ParamValues[i].Values.First().Value;
}
}
return ExecuteAsync(context, argList, paramList, map );
return ExecuteAsync(context, argList, paramList, services );
}
}
public async Task<ExecuteResult> ExecuteAsync(ICommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList, IDependencyMap map )
public async Task<ExecuteResult> ExecuteAsync(ICommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList, IServiceProvider services )
{
{
if (map == null)
map = DependencyMap.Empty;
services = services ?? EmptyServiceProvider.Instance;
try
try
{
{
@@ -132,7 +131,7 @@ namespace Discord.Commands
{
{
var parameter = Parameters[position];
var parameter = Parameters[position];
var argument = args[position];
var argument = args[position];
var result = await parameter.CheckPreconditionsAsync(context, argument, map ).ConfigureAwait(false);
var result = await parameter.CheckPreconditionsAsync(context, argument, services ).ConfigureAwait(false);
if (!result.IsSuccess)
if (!result.IsSuccess)
return ExecuteResult.FromError(result);
return ExecuteResult.FromError(result);
}
}
@@ -140,12 +139,12 @@ namespace Discord.Commands
switch (RunMode)
switch (RunMode)
{
{
case RunMode.Sync: //Always sync
case RunMode.Sync: //Always sync
await ExecuteAsyncInternal(context, args, map ).ConfigureAwait(false);
await ExecuteAsyncInternal(context, args, services ).ConfigureAwait(false);
break;
break;
case RunMode.Async: //Always async
case RunMode.Async: //Always async
var t2 = Task.Run(async () =>
var t2 = Task.Run(async () =>
{
{
await ExecuteAsyncInternal(context, args, map ).ConfigureAwait(false);
await ExecuteAsyncInternal(context, args, services ).ConfigureAwait(false);
});
});
break;
break;
}
}
@@ -157,12 +156,12 @@ namespace Discord.Commands
}
}
}
}
private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IDependencyMap map )
private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IServiceProvider services )
{
{
await Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false);
await Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false);
try
try
{
{
await _action(context, args, map ).ConfigureAwait(false);
await _action(context, args, services ).ConfigureAwait(false);
}
}
catch (Exception ex)
catch (Exception ex)
{
{