diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 1775cc1fe..aaa96fb8e 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -183,7 +183,9 @@ namespace Discord.Commands if (attribute is SummaryAttribute) builder.Summary = (attribute as SummaryAttribute).Text; else if (attribute is OverrideTypeReaderAttribute) - builder.TypeReader = service.GetTypeReader((attribute as OverrideTypeReaderAttribute).TypeReader); + { + builder.TypeReader = GetTypeReader(service, paramType, (attribute as OverrideTypeReaderAttribute).TypeReader); + } else if (attribute is ParameterPreconditionAttribute) builder.AddPrecondition(attribute as ParameterPreconditionAttribute); else if (attribute is ParamArrayAttribute) @@ -200,23 +202,47 @@ namespace Discord.Commands } } - var reader = service.GetTypeReader(paramType); - if (reader == null) + if (builder.TypeReader == null) { - var paramTypeInfo = paramType.GetTypeInfo(); - if (paramTypeInfo.IsEnum) + var readers = service.GetTypeReaders(paramType); + var reader = readers?.FirstOrDefault(); + + if (reader == null) { - reader = EnumTypeReader.GetReader(paramType); - service.AddTypeReader(paramType, reader); + var paramTypeInfo = paramType.GetTypeInfo(); + if (paramTypeInfo.IsEnum) + { + reader = EnumTypeReader.GetReader(paramType); + service.AddTypeReader(paramType, reader); + } + else + { + throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + } } - else + + builder.ParameterType = paramType; + builder.TypeReader = reader; + } + } + + private static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType) + { + var readers = service.GetTypeReaders(paramType); + if (readers != null) + { + var reader = readers.FirstOrDefault(x => x.GetType() == typeReaderType); + if (reader != default(TypeReader)) { - throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + return reader; } } - builder.ParameterType = paramType; - builder.TypeReader = reader; + //could not find any registered type reader: try to create one + var typeReader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, DependencyMap.Empty); + service.AddTypeReader(paramType, typeReader); + + return typeReader; } private static bool IsValidModuleDefinition(TypeInfo typeInfo) diff --git a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs index 6b941a1c7..89f89b3cf 100644 --- a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Reflection; using System.Collections.Generic; @@ -41,7 +42,11 @@ namespace Discord.Commands.Builders internal void SetType(Type type) { - TypeReader = Command.Module.Service.GetTypeReader(type); + var readers = Command.Module.Service.GetTypeReaders(type); + if (readers == null) + throw new InvalidOperationException($"{type} does not have a TypeReader registered for it"); + + TypeReader = readers.FirstOrDefault(); if (type.GetTypeInfo().IsValueType) DefaultValue = Activator.CreateInstance(type); diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index b6659fea3..285a35432 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -14,8 +14,8 @@ namespace Discord.Commands public class CommandService { private readonly SemaphoreSlim _moduleLock; - private readonly ConcurrentDictionary _typedModuleDefs; - private readonly ConcurrentDictionary _typeReaders; + private readonly ConcurrentDictionary _typedModuleDefs; + private readonly ConcurrentDictionary> _typeReaders; private readonly ConcurrentBag _moduleDefs; private readonly CommandMap _map; @@ -24,6 +24,7 @@ namespace Discord.Commands public IEnumerable Modules => _moduleDefs.Select(x => x); public IEnumerable Commands => _moduleDefs.SelectMany(x => x.Commands); + public ILookup TypeReaders => _typeReaders.SelectMany(x => x.Value, (a, value) => new {a.Key, value}).ToLookup(x => x.Key, x => x.value); public CommandService() : this(new CommandServiceConfig()) { } public CommandService(CommandServiceConfig config) @@ -32,41 +33,41 @@ namespace Discord.Commands _typedModuleDefs = new ConcurrentDictionary(); _moduleDefs = new ConcurrentBag(); _map = new CommandMap(); - _typeReaders = new ConcurrentDictionary + _typeReaders = new ConcurrentDictionary> { - [typeof(bool)] = new SimpleTypeReader(), - [typeof(char)] = new SimpleTypeReader(), - [typeof(string)] = new SimpleTypeReader(), - [typeof(byte)] = new SimpleTypeReader(), - [typeof(sbyte)] = new SimpleTypeReader(), - [typeof(ushort)] = new SimpleTypeReader(), - [typeof(short)] = new SimpleTypeReader(), - [typeof(uint)] = new SimpleTypeReader(), - [typeof(int)] = new SimpleTypeReader(), - [typeof(ulong)] = new SimpleTypeReader(), - [typeof(long)] = new SimpleTypeReader(), - [typeof(float)] = new SimpleTypeReader(), - [typeof(double)] = new SimpleTypeReader(), - [typeof(decimal)] = new SimpleTypeReader(), - [typeof(DateTime)] = new SimpleTypeReader(), - [typeof(DateTimeOffset)] = new SimpleTypeReader(), - [typeof(TimeSpan)] = new SimpleTypeReader(), - [typeof(IMessage)] = new MessageTypeReader(), - [typeof(IUserMessage)] = new MessageTypeReader(), - [typeof(IChannel)] = new ChannelTypeReader(), - [typeof(IDMChannel)] = new ChannelTypeReader(), - [typeof(IGroupChannel)] = new ChannelTypeReader(), - [typeof(IGuildChannel)] = new ChannelTypeReader(), - [typeof(IMessageChannel)] = new ChannelTypeReader(), - [typeof(IPrivateChannel)] = new ChannelTypeReader(), - [typeof(ITextChannel)] = new ChannelTypeReader(), - [typeof(IVoiceChannel)] = new ChannelTypeReader(), - - [typeof(IRole)] = new RoleTypeReader(), - - [typeof(IUser)] = new UserTypeReader(), - [typeof(IGroupUser)] = new UserTypeReader(), - [typeof(IGuildUser)] = new UserTypeReader(), + [typeof(bool)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(char)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(string)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(byte)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(sbyte)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(ushort)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(short)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(uint)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(int)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(ulong)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(long)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(float)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(double)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(decimal)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(DateTime)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(DateTimeOffset)] = new ConcurrentBag{new SimpleTypeReader()}, + + [typeof(IMessage)] = new ConcurrentBag{new MessageTypeReader()}, + [typeof(IUserMessage)] = new ConcurrentBag{new MessageTypeReader()}, + [typeof(IChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IDMChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IGroupChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IGuildChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IMessageChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IPrivateChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(ITextChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IVoiceChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + + [typeof(IRole)] = new ConcurrentBag{new RoleTypeReader()}, + + [typeof(IUser)] = new ConcurrentBag{new UserTypeReader()}, + [typeof(IGroupUser)] = new ConcurrentBag{new UserTypeReader()}, + [typeof(IGuildUser)] = new ConcurrentBag{new UserTypeReader()}, }; _caseSensitive = config.CaseSensitiveCommands; _defaultRunMode = config.DefaultRunMode; @@ -196,17 +197,19 @@ namespace Discord.Commands //Type Readers public void AddTypeReader(TypeReader reader) { - _typeReaders[typeof(T)] = reader; + var readers = _typeReaders.GetOrAdd(typeof(T), x => new ConcurrentBag()); + readers.Add(reader); } public void AddTypeReader(Type type, TypeReader reader) { - _typeReaders[type] = reader; + var readers = _typeReaders.GetOrAdd(type, x=> new ConcurrentBag()); + readers.Add(reader); } - internal TypeReader GetTypeReader(Type type) + internal IEnumerable GetTypeReaders(Type type) { - TypeReader reader; - if (_typeReaders.TryGetValue(type, out reader)) - return reader; + ConcurrentBag definedTypeReaders; + if (_typeReaders.TryGetValue(type, out definedTypeReaders)) + return definedTypeReaders; return null; }