diff --git a/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs index 168d15e5f..49dae6080 100644 --- a/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs @@ -1,11 +1,12 @@ using System; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = true, Inherited = true)] public abstract class ParameterPreconditionAttribute : Attribute { - public abstract Task CheckPermissions(ICommandContext context, ParameterInfo parameter, object value, IDependencyMap map); + public abstract Task CheckPermissions(ICommandContext context, ParameterInfo parameter, object value, IServiceProvider services); } } \ No newline at end of file diff --git a/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs b/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs index 7755d459b..e099380f6 100644 --- a/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs @@ -6,6 +6,6 @@ namespace Discord.Commands [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = true, Inherited = true)] public abstract class PreconditionAttribute : Attribute { - public abstract Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map); + public abstract Task CheckPermissions(ICommandContext context, CommandInfo command, IServiceProvider services); } } diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireBotPermissionAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireBotPermissionAttribute.cs index 520cfa6fd..82975a2f6 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireBotPermissionAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireBotPermissionAttribute.cs @@ -1,5 +1,6 @@ using System; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -41,7 +42,7 @@ namespace Discord.Commands GuildPermission = null; } - public override async Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map) + public override async Task CheckPermissions(ICommandContext context, CommandInfo command, IServiceProvider services) { var guildUser = await context.Guild.GetCurrentUserAsync(); diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs index 42d835c30..a221eb4a9 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs @@ -1,5 +1,6 @@ using System; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -37,7 +38,7 @@ namespace Discord.Commands Contexts = contexts; } - public override Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map) + public override Task CheckPermissions(ICommandContext context, CommandInfo command, IServiceProvider services) { bool isValid = false; diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs index 0f4e8255d..0852ce39c 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireOwnerAttribute.cs @@ -1,5 +1,6 @@ using System; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -10,7 +11,7 @@ namespace Discord.Commands [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public class RequireOwnerAttribute : PreconditionAttribute { - public override async Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map) + public override async Task CheckPermissions(ICommandContext context, CommandInfo command, IServiceProvider services) { switch (context.Client.TokenType) { diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireUserPermissionAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireUserPermissionAttribute.cs index c5b79c5b9..44c69d76a 100644 --- a/src/Discord.Net.Commands/Attributes/Preconditions/RequireUserPermissionAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireUserPermissionAttribute.cs @@ -1,5 +1,6 @@ using System; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -42,7 +43,7 @@ namespace Discord.Commands GuildPermission = null; } - public override Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map) + public override Task CheckPermissions(ICommandContext context, CommandInfo command, IServiceProvider services) { var guildUser = context.User as IGuildUser; diff --git a/src/Discord.Net.Commands/Builders/CommandBuilder.cs b/src/Discord.Net.Commands/Builders/CommandBuilder.cs index c13ca10d4..ff89b7559 100644 --- a/src/Discord.Net.Commands/Builders/CommandBuilder.cs +++ b/src/Discord.Net.Commands/Builders/CommandBuilder.cs @@ -2,6 +2,7 @@ using System.Linq; using System.Threading.Tasks; using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands.Builders { @@ -12,7 +13,7 @@ namespace Discord.Commands.Builders private readonly List _aliases; public ModuleBuilder Module { get; } - internal Func Callback { get; set; } + internal Func Callback { get; set; } public string Name { get; set; } public string Summary { get; set; } @@ -35,7 +36,7 @@ namespace Discord.Commands.Builders _aliases = new List(); } //User-defined - internal CommandBuilder(ModuleBuilder module, string primaryAlias, Func callback) + internal CommandBuilder(ModuleBuilder module, string primaryAlias, Func callback) : this(module) { Discord.Preconditions.NotNull(primaryAlias, nameof(primaryAlias)); diff --git a/src/Discord.Net.Commands/Builders/ModuleBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleBuilder.cs index 45c0034f2..d79239057 100644 --- a/src/Discord.Net.Commands/Builders/ModuleBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleBuilder.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands.Builders { @@ -73,7 +74,7 @@ namespace Discord.Commands.Builders _preconditions.Add(precondition); return this; } - public ModuleBuilder AddCommand(string primaryAlias, Func callback, Action createFunc) + public ModuleBuilder AddCommand(string primaryAlias, Func callback, Action createFunc) { var builder = new CommandBuilder(this, primaryAlias, callback); createFunc(builder); diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 25b6e034b..d8464ea72 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -243,7 +243,7 @@ namespace Discord.Commands } //We dont have a cached type reader, create one - reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, DependencyMap.Empty); + reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, EmptyServiceProvider.Instance); service.AddTypeReader(paramType, reader); return reader; diff --git a/src/Discord.Net.Commands/CommandMatch.cs b/src/Discord.Net.Commands/CommandMatch.cs index 6e78b8509..04a2d040f 100644 --- a/src/Discord.Net.Commands/CommandMatch.cs +++ b/src/Discord.Net.Commands/CommandMatch.cs @@ -1,5 +1,7 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -14,13 +16,13 @@ namespace Discord.Commands Alias = alias; } - public Task CheckPreconditionsAsync(ICommandContext context, IDependencyMap map = null) - => Command.CheckPreconditionsAsync(context, map); + public Task CheckPreconditionsAsync(ICommandContext context, IServiceProvider services = null) + => Command.CheckPreconditionsAsync(context, services); public Task ParseAsync(ICommandContext context, SearchResult searchResult, PreconditionResult? preconditionResult = null) => Command.ParseAsync(context, Alias.Length, searchResult, preconditionResult); - public Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IDependencyMap map) - => Command.ExecuteAsync(context, argList, paramList, map); - public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IDependencyMap map) - => Command.ExecuteAsync(context, parseResult, map); + public Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) + => Command.ExecuteAsync(context, argList, paramList, services); + public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) + => Command.ExecuteAsync(context, parseResult, services); } } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 945db33a8..bcfb54d96 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -247,11 +248,11 @@ namespace Discord.Commands return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); } - public Task ExecuteAsync(ICommandContext context, int argPos, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) - => ExecuteAsync(context, context.Message.Content.Substring(argPos), dependencyMap, multiMatchHandling); - public async Task ExecuteAsync(ICommandContext context, string input, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + public Task ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + => ExecuteAsync(context, context.Message.Content.Substring(argPos), services, multiMatchHandling); + public async Task ExecuteAsync(ICommandContext context, string input, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) { - dependencyMap = dependencyMap ?? DependencyMap.Empty; + services = services ?? EmptyServiceProvider.Instance; var searchResult = Search(context, input); if (!searchResult.IsSuccess) @@ -260,7 +261,7 @@ namespace Discord.Commands var commands = searchResult.Commands; for (int i = 0; i < commands.Count; i++) { - var preconditionResult = await commands[i].CheckPreconditionsAsync(context, dependencyMap).ConfigureAwait(false); + var preconditionResult = await commands[i].CheckPreconditionsAsync(context, services).ConfigureAwait(false); if (!preconditionResult.IsSuccess) { if (commands.Count == 1) @@ -294,10 +295,19 @@ namespace Discord.Commands } } - return await commands[i].ExecuteAsync(context, parseResult, dependencyMap).ConfigureAwait(false); + return await commands[i].ExecuteAsync(context, parseResult, services).ConfigureAwait(false); } return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); } + + public ServiceCollection CreateServiceCollection() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(this); + serviceCollection.AddSingleton(serviceCollection); + serviceCollection.AddSingleton(serviceCollection); + return serviceCollection; + } } } diff --git a/src/Discord.Net.Commands/Dependencies/DependencyMap.cs b/src/Discord.Net.Commands/Dependencies/DependencyMap.cs deleted file mode 100644 index 55092961a..000000000 --- a/src/Discord.Net.Commands/Dependencies/DependencyMap.cs +++ /dev/null @@ -1,102 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Discord.Commands -{ - public class DependencyMap : IDependencyMap - { - private static readonly Type[] _typeBlacklist = new[] { - typeof(IDependencyMap), - typeof(CommandService) - }; - - private Dictionary> map; - - public static DependencyMap Empty => new DependencyMap(); - - public DependencyMap() - { - map = new Dictionary>(); - } - - /// - public void Add(T obj) where T : class - => AddFactory(() => obj); - /// - public bool TryAdd(T obj) where T : class - => TryAddFactory(() => obj); - /// - public void AddTransient() where T : class, new() - => AddFactory(() => new T()); - /// - public bool TryAddTransient() where T : class, new() - => TryAddFactory(() => new T()); - /// - public void AddTransient() where TKey : class - where TImpl : class, TKey, new() - => AddFactory(() => new TImpl()); - public bool TryAddTransient() where TKey : class - where TImpl : class, TKey, new() - => TryAddFactory(() => new TImpl()); - - /// - public void AddFactory(Func factory) where T : class - { - if (!TryAddFactory(factory)) - throw new InvalidOperationException($"The dependency map already contains \"{typeof(T).FullName}\""); - } - /// - public bool TryAddFactory(Func factory) where T : class - { - var type = typeof(T); - if (_typeBlacklist.Contains(type) || map.ContainsKey(type)) - return false; - map.Add(type, factory); - return true; - } - - /// - public T Get() where T : class - { - return (T)Get(typeof(T)); - } - /// - public object Get(Type t) - { - object result; - if (!TryGet(t, out result)) - throw new KeyNotFoundException($"The dependency map does not contain \"{t.FullName}\""); - else - return result; - } - - /// - public bool TryGet(out T result) where T : class - { - object untypedResult; - if (TryGet(typeof(T), out untypedResult)) - { - result = (T)untypedResult; - return true; - } - else - { - result = default(T); - return false; - } - } - /// - public bool TryGet(Type t, out object result) - { - Func func; - if (map.TryGetValue(t, out func)) - { - result = func(); - return true; - } - result = null; - return false; - } - } -} diff --git a/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs b/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs deleted file mode 100644 index fa76709b6..000000000 --- a/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs +++ /dev/null @@ -1,89 +0,0 @@ -using System; - -namespace Discord.Commands -{ - public interface IDependencyMap - { - /// - /// Add an instance of a service to be injected. - /// - /// The type of service. - /// The instance of a service. - void Add(T obj) where T : class; - /// - /// Tries to add an instance of a service to be injected. - /// - /// The type of service. - /// The instance of a service. - /// A bool, indicating if the service was successfully added to the DependencyMap. - bool TryAdd(T obj) where T : class; - /// - /// Add a service that will be injected by a new instance every time. - /// - /// The type of instance to inject. - void AddTransient() where T : class, new(); - /// - /// Tries to add a service that will be injected by a new instance every time. - /// - /// The type of instance to inject. - /// A bool, indicating if the service was successfully added to the DependencyMap. - bool TryAddTransient() where T : class, new(); - /// - /// Add a service that will be injected by a new instance every time. - /// - /// The type to look for when injecting. - /// The type to inject when injecting. - /// - /// map.AddTransient<IService, Service> - /// - void AddTransient() where TKey: class where TImpl : class, TKey, new(); - /// - /// Tries to add a service that will be injected by a new instance every time. - /// - /// The type to look for when injecting. - /// The type to inject when injecting. - /// A bool, indicating if the service was successfully added to the DependencyMap. - bool TryAddTransient() where TKey : class where TImpl : class, TKey, new(); - /// - /// Add a service that will be injected by a factory. - /// - /// The type to look for when injecting. - /// The factory that returns a type of this service. - void AddFactory(Func factory) where T : class; - /// - /// Tries to add a service that will be injected by a factory. - /// - /// The type to look for when injecting. - /// The factory that returns a type of this service. - /// A bool, indicating if the service was successfully added to the DependencyMap. - bool TryAddFactory(Func factory) where T : class; - - /// - /// Pull an object from the map. - /// - /// The type of service. - /// An instance of this service. - T Get() where T : class; - /// - /// Try to pull an object from the map. - /// - /// The type of service. - /// The instance of this service. - /// Whether or not this object could be found in the map. - bool TryGet(out T result) where T : class; - - /// - /// Pull an object from the map. - /// - /// The type of service. - /// An instance of this service. - object Get(Type t); - /// - /// Try to pull an object from the map. - /// - /// The type of service. - /// An instance of this service. - /// Whether or not this object could be found in the map. - bool TryGet(Type t, out object result); - } -} diff --git a/src/Discord.Net.Commands/Discord.Net.Commands.csproj b/src/Discord.Net.Commands/Discord.Net.Commands.csproj index 40f130d7b..a9cfc8e60 100644 --- a/src/Discord.Net.Commands/Discord.Net.Commands.csproj +++ b/src/Discord.Net.Commands/Discord.Net.Commands.csproj @@ -9,4 +9,7 @@ + + + \ No newline at end of file diff --git a/src/Discord.Net.Commands/EmptyServiceProvider.cs b/src/Discord.Net.Commands/EmptyServiceProvider.cs new file mode 100644 index 000000000..0bef3760e --- /dev/null +++ b/src/Discord.Net.Commands/EmptyServiceProvider.cs @@ -0,0 +1,11 @@ +using System; + +namespace Discord.Commands +{ + internal class EmptyServiceProvider : IServiceProvider + { + public static readonly EmptyServiceProvider Instance = new EmptyServiceProvider(); + + public object GetService(Type serviceType) => null; + } +} diff --git a/src/Discord.Net.Commands/Info/CommandInfo.cs b/src/Discord.Net.Commands/Info/CommandInfo.cs index 9abe6de32..5acd1f648 100644 --- a/src/Discord.Net.Commands/Info/CommandInfo.cs +++ b/src/Discord.Net.Commands/Info/CommandInfo.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Reflection; using System.Runtime.ExceptionServices; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -17,7 +18,7 @@ namespace Discord.Commands private static readonly System.Reflection.MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); private static readonly ConcurrentDictionary, object>> _arrayConverters = new ConcurrentDictionary, object>>(); - private readonly Func _action; + private readonly Func _action; public ModuleInfo Module { get; } public string Name { get; } @@ -63,21 +64,20 @@ namespace Discord.Commands _action = builder.Callback; } - public async Task CheckPreconditionsAsync(ICommandContext context, IDependencyMap map = null) + public async Task CheckPreconditionsAsync(ICommandContext context, IServiceProvider services = null) { - if (map == null) - map = DependencyMap.Empty; + services = services ?? EmptyServiceProvider.Instance; foreach (PreconditionAttribute precondition in Module.Preconditions) { - var result = await precondition.CheckPermissions(context, this, map).ConfigureAwait(false); + var result = await precondition.CheckPermissions(context, this, services).ConfigureAwait(false); if (!result.IsSuccess) return result; } foreach (PreconditionAttribute precondition in Preconditions) { - var result = await precondition.CheckPermissions(context, this, map).ConfigureAwait(false); + var result = await precondition.CheckPermissions(context, this, services).ConfigureAwait(false); if (!result.IsSuccess) return result; } @@ -96,7 +96,7 @@ namespace Discord.Commands return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); } - public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IDependencyMap map) + public Task ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) { if (!parseResult.IsSuccess) return Task.FromResult(ExecuteResult.FromError(parseResult)); @@ -117,12 +117,11 @@ namespace Discord.Commands paramList[i] = parseResult.ParamValues[i].Values.First().Value; } - return ExecuteAsync(context, argList, paramList, map); + return ExecuteAsync(context, argList, paramList, services); } - public async Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IDependencyMap map) + public async Task ExecuteAsync(ICommandContext context, IEnumerable argList, IEnumerable paramList, IServiceProvider services) { - if (map == null) - map = DependencyMap.Empty; + services = services ?? EmptyServiceProvider.Instance; try { @@ -132,7 +131,7 @@ namespace Discord.Commands { var parameter = Parameters[position]; var argument = args[position]; - var result = await parameter.CheckPreconditionsAsync(context, argument, map).ConfigureAwait(false); + var result = await parameter.CheckPreconditionsAsync(context, argument, services).ConfigureAwait(false); if (!result.IsSuccess) return ExecuteResult.FromError(result); } @@ -140,12 +139,12 @@ namespace Discord.Commands switch (RunMode) { case RunMode.Sync: //Always sync - await ExecuteAsyncInternal(context, args, map).ConfigureAwait(false); + await ExecuteAsyncInternal(context, args, services).ConfigureAwait(false); break; case RunMode.Async: //Always async var t2 = Task.Run(async () => { - await ExecuteAsyncInternal(context, args, map).ConfigureAwait(false); + await ExecuteAsyncInternal(context, args, services).ConfigureAwait(false); }); break; } @@ -157,12 +156,12 @@ namespace Discord.Commands } } - private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IDependencyMap map) + private async Task ExecuteAsyncInternal(ICommandContext context, object[] args, IServiceProvider services) { await Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false); try { - await _action(context, args, map).ConfigureAwait(false); + await _action(context, args, services).ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/Discord.Net.Commands/Info/ParameterInfo.cs b/src/Discord.Net.Commands/Info/ParameterInfo.cs index 4ef145b9e..9eea82cb2 100644 --- a/src/Discord.Net.Commands/Info/ParameterInfo.cs +++ b/src/Discord.Net.Commands/Info/ParameterInfo.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { @@ -39,14 +40,13 @@ namespace Discord.Commands _reader = builder.TypeReader; } - public async Task CheckPreconditionsAsync(ICommandContext context, object arg, IDependencyMap map = null) + public async Task CheckPreconditionsAsync(ICommandContext context, object arg, IServiceProvider services = null) { - if (map == null) - map = DependencyMap.Empty; + services = EmptyServiceProvider.Instance; foreach (var precondition in Preconditions) { - var result = await precondition.CheckPermissions(context, this, arg, map).ConfigureAwait(false); + var result = await precondition.CheckPermissions(context, this, arg, services).ConfigureAwait(false); if (!result.IsSuccess) return result; } diff --git a/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs b/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs index 5c817183b..4cca0e864 100644 --- a/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs +++ b/src/Discord.Net.Commands/Utilities/ReflectionUtils.cs @@ -2,88 +2,79 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using Microsoft.Extensions.DependencyInjection; namespace Discord.Commands { internal static class ReflectionUtils { - private static readonly TypeInfo objectTypeInfo = typeof(object).GetTypeInfo(); + private static readonly TypeInfo _objectTypeInfo = typeof(object).GetTypeInfo(); - internal static T CreateObject(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) - => CreateBuilder(typeInfo, service)(map); + internal static T CreateObject(TypeInfo typeInfo, CommandService commands, IServiceProvider services = null) + => CreateBuilder(typeInfo, commands)(services); + internal static Func CreateBuilder(TypeInfo typeInfo, CommandService commands) + { + var constructor = GetConstructor(typeInfo); + var parameters = constructor.GetParameters(); + var properties = GetProperties(typeInfo); + + return (services) => + { + var args = new object[parameters.Length]; + for (int i = 0; i < parameters.Length; i++) + args[i] = GetMember(commands, services, parameters[i].ParameterType, typeInfo); + var obj = InvokeConstructor(constructor, args, typeInfo); - private static System.Reflection.PropertyInfo[] GetProperties(TypeInfo typeInfo) + foreach(var property in properties) + property.SetValue(obj, GetMember(commands, services, property.PropertyType, typeInfo)); + return obj; + }; + } + private static T InvokeConstructor(ConstructorInfo constructor, object[] args, TypeInfo ownerType) { - var result = new List(); - while (typeInfo != objectTypeInfo) + try { - foreach (var prop in typeInfo.DeclaredProperties) - { - if (prop.SetMethod?.IsPublic == true && prop.GetCustomAttribute() == null) - result.Add(prop); - } - typeInfo = typeInfo.BaseType.GetTypeInfo(); + return (T)constructor.Invoke(args); + } + catch (Exception ex) + { + throw new Exception($"Failed to create \"{ownerType.FullName}\"", ex); } - return result.ToArray(); } - internal static Func CreateBuilder(TypeInfo typeInfo, CommandService service) + private static ConstructorInfo GetConstructor(TypeInfo ownerType) { - var constructors = typeInfo.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); + var constructors = ownerType.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); if (constructors.Length == 0) - throw new InvalidOperationException($"No constructor found for \"{typeInfo.FullName}\""); + throw new InvalidOperationException($"No constructor found for \"{ownerType.FullName}\""); else if (constructors.Length > 1) - throw new InvalidOperationException($"Multiple constructors found for \"{typeInfo.FullName}\""); - - var constructor = constructors[0]; - System.Reflection.ParameterInfo[] parameters = constructor.GetParameters(); - System.Reflection.PropertyInfo[] properties = GetProperties(typeInfo) - .Where(p => p.SetMethod?.IsPublic == true && p.GetCustomAttribute() == null) - .ToArray(); - - return (map) => + throw new InvalidOperationException($"Multiple constructors found for \"{ownerType.FullName}\""); + return constructors[0]; + } + private static System.Reflection.PropertyInfo[] GetProperties(TypeInfo ownerType) + { + var result = new List(); + while (ownerType != _objectTypeInfo) { - object[] args = new object[parameters.Length]; - - for (int i = 0; i < parameters.Length; i++) - { - var parameter = parameters[i]; - args[i] = GetMember(parameter.ParameterType, map, service, typeInfo); - } - - T obj; - try - { - obj = (T)constructor.Invoke(args); - } - catch (Exception ex) - { - throw new Exception($"Failed to create \"{typeInfo.FullName}\"", ex); - } - - foreach(var property in properties) + foreach (var prop in ownerType.DeclaredProperties) { - property.SetValue(obj, GetMember(property.PropertyType, map, service, typeInfo)); + if (prop.SetMethod?.IsPublic == true && prop.GetCustomAttribute() == null) + result.Add(prop); } - return obj; - }; + ownerType = ownerType.BaseType.GetTypeInfo(); + } + return result.ToArray(); } - - private static readonly TypeInfo _dependencyTypeInfo = typeof(IDependencyMap).GetTypeInfo(); - - internal static object GetMember(Type targetType, IDependencyMap map, CommandService service, TypeInfo baseType) + private static object GetMember(CommandService commands, IServiceProvider services, Type memberType, TypeInfo ownerType) { - object arg; - if (map == null || !map.TryGet(targetType, out arg)) - { - if (targetType == typeof(CommandService)) - arg = service; - else if (targetType == typeof(IDependencyMap) || targetType == map.GetType()) - arg = map; - else - throw new InvalidOperationException($"Failed to create \"{baseType.FullName}\", dependency \"{targetType.Name}\" was not found."); - } - return arg; + if (memberType == typeof(CommandService)) + return commands; + if (memberType == typeof(IServiceProvider) || memberType == services.GetType()) + return services; + var service = services?.GetService(memberType); + if (service != null) + return service; + throw new InvalidOperationException($"Failed to create \"{ownerType.FullName}\", dependency \"{memberType.Name}\" was not found."); } } }