diff --git a/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs b/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs index 59e6a6aca..57f525389 100644 --- a/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs @@ -6,13 +6,16 @@ namespace Discord.Commands public class ModuleAttribute : Attribute { public string Prefix { get; } - public ModuleAttribute() + public bool Autoload { get; } + public ModuleAttribute(bool autoload = true) { Prefix = null; + Autoload = autoload; } - public ModuleAttribute(string prefix) + public ModuleAttribute(string prefix, bool autoload = true) { Prefix = prefix; + Autoload = autoload; } } } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index f762ae366..d0dfeaeb9 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -164,7 +164,7 @@ namespace Discord.Commands return loadedModule; } - public async Task> LoadAssembly(Assembly assembly) + public async Task> LoadAssembly(Assembly assembly, IDependencyMap dependencyMap = null) { var modules = ImmutableArray.CreateBuilder(); await _moduleLock.WaitAsync().ConfigureAwait(false); @@ -174,9 +174,9 @@ namespace Discord.Commands { var typeInfo = type.GetTypeInfo(); var moduleAttr = typeInfo.GetCustomAttribute(); - if (moduleAttr != null) + if (moduleAttr != null && moduleAttr.Autoload) { - var moduleInstance = ReflectionUtils.CreateObject(typeInfo); + var moduleInstance = ReflectionUtils.CreateObject(typeInfo, this, dependencyMap); modules.Add(LoadInternal(moduleInstance, moduleAttr, typeInfo)); } } diff --git a/src/Discord.Net.Commands/Dependencies/DependencyMap.cs b/src/Discord.Net.Commands/Dependencies/DependencyMap.cs new file mode 100644 index 000000000..4495a906b --- /dev/null +++ b/src/Discord.Net.Commands/Dependencies/DependencyMap.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Reflection; + +namespace Discord.Commands +{ + public class DependencyMap : IDependencyMap + { + private Dictionary map; + + public DependencyMap() + { + map = new Dictionary(); + } + + public T Get() where T : class + { + var t = typeof(T); + if (!map.ContainsKey(t)) + throw new KeyNotFoundException($"The dependency map does not contain \"{t.FullName}\""); + return map[t] as T; + } + + public void Add(T obj) + { + var t = typeof(T); + if (map.ContainsKey(t)) + throw new InvalidOperationException($"The dependency map already contains \"{t.FullName}\""); + map.Add(t, obj); + } + } +} diff --git a/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs b/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs new file mode 100644 index 000000000..fb2710795 --- /dev/null +++ b/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + public interface IDependencyMap + { + T Get() where T : class; + void Add(T obj); + } +} diff --git a/src/Discord.Net.Commands/Module.cs b/src/Discord.Net.Commands/Module.cs index ea6e29c28..b884832bc 100644 --- a/src/Discord.Net.Commands/Module.cs +++ b/src/Discord.Net.Commands/Module.cs @@ -43,7 +43,7 @@ namespace Discord.Commands nextGroupPrefix = groupPrefix + groupAttrib.Prefix ?? type.Name; else nextGroupPrefix = groupPrefix; - SearchClass(ReflectionUtils.CreateObject(type), commands, type, nextGroupPrefix); + SearchClass(ReflectionUtils.CreateObject(type, Service), commands, type, nextGroupPrefix); } } } diff --git a/src/Discord.Net.Commands/ReflectionUtils.cs b/src/Discord.Net.Commands/ReflectionUtils.cs index 28672a06f..8562092f4 100644 --- a/src/Discord.Net.Commands/ReflectionUtils.cs +++ b/src/Discord.Net.Commands/ReflectionUtils.cs @@ -6,18 +6,33 @@ namespace Discord.Commands { internal class ReflectionUtils { - internal static object CreateObject(TypeInfo typeInfo) + internal static object CreateObject(TypeInfo typeInfo, CommandService commands, IDependencyMap map = null) { - var constructor = typeInfo.DeclaredConstructors.Where(x => x.GetParameters().Length == 0).FirstOrDefault(); - if (constructor == null) - throw new InvalidOperationException($"Failed to find a valid constructor for \"{typeInfo.FullName}\""); + if (typeInfo.DeclaredConstructors.Count() > 1) + throw new InvalidOperationException($"Found too many constructors for \"{typeInfo.FullName}\""); + var constructor = typeInfo.DeclaredConstructors.FirstOrDefault(); try { - return constructor.Invoke(null); + if (constructor.GetParameters().Length == 0) + return constructor.Invoke(null); + else if (constructor.GetParameters().Length > 1) + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (Found too many parameters)"); + var parameter = constructor.GetParameters().FirstOrDefault(); + if (parameter == null) + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (No valid parameters)"); + if (parameter.ParameterType == typeof(CommandService)) + return constructor.Invoke(new object[1] { commands }); + else if (parameter.ParameterType == typeof(IDependencyMap)) + { + if (map == null) throw new InvalidOperationException($"The constructor for \"{typeInfo.FullName}\" requires a Dependency Map."); + return constructor.Invoke(new object[1] { map }); + } + else + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (Invalid Parameter Type: \"{parameter.ParameterType.FullName}\")"); } - catch (Exception ex) + catch { - throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\"", ex); + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (Error invoking constructor)"); } } }