diff --git a/src/Discord.Net.Commands/Attributes/TypeReaderAttribute.cs b/src/Discord.Net.Commands/Attributes/TypeReaderAttribute.cs new file mode 100644 index 000000000..f4a69e653 --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/TypeReaderAttribute.cs @@ -0,0 +1,21 @@ +using System; +using System.Reflection; + +namespace Discord.Commands.Attributes +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Parameter, AllowMultiple = true)] + public class TypeReaderAttribute : Attribute + { + public Type Type { get; } + public TypeInfo OverridingTypeReader { get; } + + public TypeReaderAttribute(Type forType, Type typeReader) + { + if (!typeof(TypeReader).GetTypeInfo().IsAssignableFrom(typeReader.GetTypeInfo())) + throw new ArgumentException($"Type of argument {nameof(typeReader)} must derive from {nameof(TypeReader)}", nameof(typeReader)); + + Type = forType; + OverridingTypeReader = typeReader.GetTypeInfo(); + } + } +} diff --git a/src/Discord.Net.Commands/CommandInfo.cs b/src/Discord.Net.Commands/CommandInfo.cs index 47aae1ae2..b36758891 100644 --- a/src/Discord.Net.Commands/CommandInfo.cs +++ b/src/Discord.Net.Commands/CommandInfo.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Linq; using System.Reflection; using System.Threading.Tasks; +using Discord.Commands.Attributes; namespace Discord.Commands { @@ -30,7 +31,7 @@ namespace Discord.Commands public IReadOnlyList Parameters { get; } public IReadOnlyList Preconditions { get; } - internal CommandInfo(MethodInfo source, ModuleInfo module, CommandAttribute attribute, string groupPrefix) + internal CommandInfo(MethodInfo source, ModuleInfo module, CommandAttribute attribute, string groupPrefix, IDependencyMap dependencyMap = null) { try { @@ -74,7 +75,7 @@ namespace Discord.Commands var priorityAttr = source.GetCustomAttribute(); Priority = priorityAttr?.Priority ?? 0; - Parameters = BuildParameters(source); + Parameters = BuildParameters(source, dependencyMap); HasVarArgs = Parameters.Count > 0 ? Parameters[Parameters.Count - 1].IsMultiple : false; Preconditions = BuildPreconditions(source); _action = BuildAction(source); @@ -184,7 +185,7 @@ namespace Discord.Commands return methodInfo.GetCustomAttributes().ToImmutableArray(); } - private IReadOnlyList BuildParameters(MethodInfo methodInfo) + private IReadOnlyList BuildParameters(MethodInfo methodInfo, IDependencyMap dependencyMap = null) { var parameters = methodInfo.GetParameters(); @@ -199,7 +200,13 @@ namespace Discord.Commands if (isMultiple) type = type.GetElementType(); - var reader = Module.Service.GetTypeReader(type); + var trAttr = parameter.GetCustomAttribute(); + + var reader = (trAttr != null && trAttr.Type == type) ? + ReflectionUtils.CreateObject(trAttr.OverridingTypeReader, Module.Service, dependencyMap) : + (Module.OverridingTypeReaders.ContainsKey(type) ? + Module.OverridingTypeReaders[type] : + Module.Service.GetTypeReader(type)); var typeInfo = type.GetTypeInfo(); //Detect enums diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index ef0dba7e7..90da96cc6 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -65,7 +65,7 @@ namespace Discord.Commands } //Modules - public async Task AddModule() + public async Task AddModule(IDependencyMap dependencyMap = null) { await _moduleLock.WaitAsync().ConfigureAwait(false); try @@ -80,14 +80,14 @@ namespace Discord.Commands if (_moduleDefs.ContainsKey(typeof(T))) throw new ArgumentException($"This module has already been added."); - return AddModuleInternal(typeInfo); + return AddModuleInternal(typeInfo, dependencyMap); } finally { _moduleLock.Release(); } } - public async Task> AddModules(Assembly assembly) + public async Task> AddModules(Assembly assembly, IDependencyMap dependencyMap = null) { var moduleDefs = ImmutableArray.CreateBuilder(); await _moduleLock.WaitAsync().ConfigureAwait(false); @@ -102,7 +102,7 @@ namespace Discord.Commands { var dontAutoLoad = typeInfo.GetCustomAttribute(); if (dontAutoLoad == null && !typeInfo.IsAbstract) - moduleDefs.Add(AddModuleInternal(typeInfo)); + moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap)); } } } @@ -113,9 +113,9 @@ namespace Discord.Commands _moduleLock.Release(); } } - private ModuleInfo AddModuleInternal(TypeInfo typeInfo) + private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap = null) { - var moduleDef = new ModuleInfo(typeInfo, this); + var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap); _moduleDefs[typeInfo.AsType()] = moduleDef; foreach (var cmd in moduleDef.Commands) diff --git a/src/Discord.Net.Commands/ModuleInfo.cs b/src/Discord.Net.Commands/ModuleInfo.cs index b7471edb5..bf8e20924 100644 --- a/src/Discord.Net.Commands/ModuleInfo.cs +++ b/src/Discord.Net.Commands/ModuleInfo.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Reflection; +using Discord.Commands.Attributes; namespace Discord.Commands { @@ -19,8 +20,9 @@ namespace Discord.Commands public string Remarks { get; } public IEnumerable Commands { get; } public IReadOnlyList Preconditions { get; } + public ImmutableDictionary OverridingTypeReaders { get; } - internal ModuleInfo(TypeInfo source, CommandService service) + internal ModuleInfo(TypeInfo source, CommandService service, IDependencyMap dependencyMap = null) { Source = source; Service = service; @@ -45,19 +47,31 @@ namespace Discord.Commands if (remarksAttr != null) Remarks = remarksAttr.Text; + var typeReaders = new Dictionary(); + + var trAttrs = source.GetCustomAttributes(); + foreach (var trAttr in trAttrs) + typeReaders[trAttr.Type] = GetOverridingTypeReader(trAttr, dependencyMap); + + OverridingTypeReaders = typeReaders.ToImmutableDictionary(); + List commands = new List(); - SearchClass(source, commands, Prefix); + SearchClass(source, commands, Prefix, dependencyMap); Commands = commands; Preconditions = Source.GetCustomAttributes().ToImmutableArray(); } - private void SearchClass(TypeInfo parentType, List commands, string groupPrefix) + + private TypeReader GetOverridingTypeReader(TypeReaderAttribute trAttr, IDependencyMap dependencyMap = null) + => ReflectionUtils.CreateObject(trAttr.OverridingTypeReader, Service, dependencyMap); + + private void SearchClass(TypeInfo parentType, List commands, string groupPrefix, IDependencyMap dependencyMap = null) { foreach (var method in parentType.DeclaredMethods) { var cmdAttr = method.GetCustomAttribute(); if (cmdAttr != null) - commands.Add(new CommandInfo(method, this, cmdAttr, groupPrefix)); + commands.Add(new CommandInfo(method, this, cmdAttr, groupPrefix, dependencyMap)); } foreach (var type in parentType.DeclaredNestedTypes) { @@ -71,7 +85,7 @@ namespace Discord.Commands else nextGroupPrefix = groupAttrib.Prefix ?? type.Name.ToLowerInvariant(); - SearchClass(type, commands, nextGroupPrefix); + SearchClass(type, commands, nextGroupPrefix, dependencyMap); } } }