From 254e874c999336179dd878dd733cbcb53820cc03 Mon Sep 17 00:00:00 2001 From: FiniteReality Date: Mon, 21 Nov 2016 18:46:21 +0000 Subject: [PATCH] Fix OverrideTypeReader This commit also adds a TypeReaders property to CommandService, so it is possible to see all of the registered TypeReaders. This makes it possible for users to implement their own parsing instead of using the built-in parsing. --- .../Builders/ModuleClassBuilder.cs | 48 +++++++--- .../Builders/ParameterBuilder.cs | 7 +- src/Discord.Net.Commands/CommandService.cs | 87 ++++++++++--------- 3 files changed, 88 insertions(+), 54 deletions(-) diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 1775cc1fe..aaa96fb8e 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -183,7 +183,9 @@ namespace Discord.Commands if (attribute is SummaryAttribute) builder.Summary = (attribute as SummaryAttribute).Text; else if (attribute is OverrideTypeReaderAttribute) - builder.TypeReader = service.GetTypeReader((attribute as OverrideTypeReaderAttribute).TypeReader); + { + builder.TypeReader = GetTypeReader(service, paramType, (attribute as OverrideTypeReaderAttribute).TypeReader); + } else if (attribute is ParameterPreconditionAttribute) builder.AddPrecondition(attribute as ParameterPreconditionAttribute); else if (attribute is ParamArrayAttribute) @@ -200,23 +202,47 @@ namespace Discord.Commands } } - var reader = service.GetTypeReader(paramType); - if (reader == null) + if (builder.TypeReader == null) { - var paramTypeInfo = paramType.GetTypeInfo(); - if (paramTypeInfo.IsEnum) + var readers = service.GetTypeReaders(paramType); + var reader = readers?.FirstOrDefault(); + + if (reader == null) { - reader = EnumTypeReader.GetReader(paramType); - service.AddTypeReader(paramType, reader); + var paramTypeInfo = paramType.GetTypeInfo(); + if (paramTypeInfo.IsEnum) + { + reader = EnumTypeReader.GetReader(paramType); + service.AddTypeReader(paramType, reader); + } + else + { + throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + } } - else + + builder.ParameterType = paramType; + builder.TypeReader = reader; + } + } + + private static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType) + { + var readers = service.GetTypeReaders(paramType); + if (readers != null) + { + var reader = readers.FirstOrDefault(x => x.GetType() == typeReaderType); + if (reader != default(TypeReader)) { - throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + return reader; } } - builder.ParameterType = paramType; - builder.TypeReader = reader; + //could not find any registered type reader: try to create one + var typeReader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, DependencyMap.Empty); + service.AddTypeReader(paramType, typeReader); + + return typeReader; } private static bool IsValidModuleDefinition(TypeInfo typeInfo) diff --git a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs index 6b941a1c7..89f89b3cf 100644 --- a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Reflection; using System.Collections.Generic; @@ -41,7 +42,11 @@ namespace Discord.Commands.Builders internal void SetType(Type type) { - TypeReader = Command.Module.Service.GetTypeReader(type); + var readers = Command.Module.Service.GetTypeReaders(type); + if (readers == null) + throw new InvalidOperationException($"{type} does not have a TypeReader registered for it"); + + TypeReader = readers.FirstOrDefault(); if (type.GetTypeInfo().IsValueType) DefaultValue = Activator.CreateInstance(type); diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index b6659fea3..285a35432 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -14,8 +14,8 @@ namespace Discord.Commands public class CommandService { private readonly SemaphoreSlim _moduleLock; - private readonly ConcurrentDictionary _typedModuleDefs; - private readonly ConcurrentDictionary _typeReaders; + private readonly ConcurrentDictionary _typedModuleDefs; + private readonly ConcurrentDictionary> _typeReaders; private readonly ConcurrentBag _moduleDefs; private readonly CommandMap _map; @@ -24,6 +24,7 @@ namespace Discord.Commands public IEnumerable Modules => _moduleDefs.Select(x => x); public IEnumerable Commands => _moduleDefs.SelectMany(x => x.Commands); + public ILookup TypeReaders => _typeReaders.SelectMany(x => x.Value, (a, value) => new {a.Key, value}).ToLookup(x => x.Key, x => x.value); public CommandService() : this(new CommandServiceConfig()) { } public CommandService(CommandServiceConfig config) @@ -32,41 +33,41 @@ namespace Discord.Commands _typedModuleDefs = new ConcurrentDictionary(); _moduleDefs = new ConcurrentBag(); _map = new CommandMap(); - _typeReaders = new ConcurrentDictionary + _typeReaders = new ConcurrentDictionary> { - [typeof(bool)] = new SimpleTypeReader(), - [typeof(char)] = new SimpleTypeReader(), - [typeof(string)] = new SimpleTypeReader(), - [typeof(byte)] = new SimpleTypeReader(), - [typeof(sbyte)] = new SimpleTypeReader(), - [typeof(ushort)] = new SimpleTypeReader(), - [typeof(short)] = new SimpleTypeReader(), - [typeof(uint)] = new SimpleTypeReader(), - [typeof(int)] = new SimpleTypeReader(), - [typeof(ulong)] = new SimpleTypeReader(), - [typeof(long)] = new SimpleTypeReader(), - [typeof(float)] = new SimpleTypeReader(), - [typeof(double)] = new SimpleTypeReader(), - [typeof(decimal)] = new SimpleTypeReader(), - [typeof(DateTime)] = new SimpleTypeReader(), - [typeof(DateTimeOffset)] = new SimpleTypeReader(), - [typeof(TimeSpan)] = new SimpleTypeReader(), - [typeof(IMessage)] = new MessageTypeReader(), - [typeof(IUserMessage)] = new MessageTypeReader(), - [typeof(IChannel)] = new ChannelTypeReader(), - [typeof(IDMChannel)] = new ChannelTypeReader(), - [typeof(IGroupChannel)] = new ChannelTypeReader(), - [typeof(IGuildChannel)] = new ChannelTypeReader(), - [typeof(IMessageChannel)] = new ChannelTypeReader(), - [typeof(IPrivateChannel)] = new ChannelTypeReader(), - [typeof(ITextChannel)] = new ChannelTypeReader(), - [typeof(IVoiceChannel)] = new ChannelTypeReader(), - - [typeof(IRole)] = new RoleTypeReader(), - - [typeof(IUser)] = new UserTypeReader(), - [typeof(IGroupUser)] = new UserTypeReader(), - [typeof(IGuildUser)] = new UserTypeReader(), + [typeof(bool)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(char)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(string)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(byte)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(sbyte)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(ushort)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(short)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(uint)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(int)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(ulong)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(long)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(float)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(double)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(decimal)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(DateTime)] = new ConcurrentBag{new SimpleTypeReader()}, + [typeof(DateTimeOffset)] = new ConcurrentBag{new SimpleTypeReader()}, + + [typeof(IMessage)] = new ConcurrentBag{new MessageTypeReader()}, + [typeof(IUserMessage)] = new ConcurrentBag{new MessageTypeReader()}, + [typeof(IChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IDMChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IGroupChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IGuildChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IMessageChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IPrivateChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(ITextChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + [typeof(IVoiceChannel)] = new ConcurrentBag{new ChannelTypeReader()}, + + [typeof(IRole)] = new ConcurrentBag{new RoleTypeReader()}, + + [typeof(IUser)] = new ConcurrentBag{new UserTypeReader()}, + [typeof(IGroupUser)] = new ConcurrentBag{new UserTypeReader()}, + [typeof(IGuildUser)] = new ConcurrentBag{new UserTypeReader()}, }; _caseSensitive = config.CaseSensitiveCommands; _defaultRunMode = config.DefaultRunMode; @@ -196,17 +197,19 @@ namespace Discord.Commands //Type Readers public void AddTypeReader(TypeReader reader) { - _typeReaders[typeof(T)] = reader; + var readers = _typeReaders.GetOrAdd(typeof(T), x => new ConcurrentBag()); + readers.Add(reader); } public void AddTypeReader(Type type, TypeReader reader) { - _typeReaders[type] = reader; + var readers = _typeReaders.GetOrAdd(type, x=> new ConcurrentBag()); + readers.Add(reader); } - internal TypeReader GetTypeReader(Type type) + internal IEnumerable GetTypeReaders(Type type) { - TypeReader reader; - if (_typeReaders.TryGetValue(type, out reader)) - return reader; + ConcurrentBag definedTypeReaders; + if (_typeReaders.TryGetValue(type, out definedTypeReaders)) + return definedTypeReaders; return null; }