diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index 2d909f978..c5eb1d4d3 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -275,28 +275,23 @@ namespace Discord.Commands if (builder.TypeReader == null) { - builder.TypeReader = service.GetTypeReaders(paramType)?.FirstOrDefault().Value + builder.TypeReader = service.GetTypeReaders(paramType, false)?.FirstOrDefault().Value ?? service.GetDefaultTypeReader(paramType); } } internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) { - var readers = service.GetTypeReaders(paramType); - TypeReader reader = null; + var readers = service.GetTypeReaders(paramType, true); if (readers != null) - { - if (readers.TryGetValue(typeReaderType, out reader)) - return reader; - } - - var overrideTypeReader = service.GetOverrideTypeReader(paramType); - if (overrideTypeReader != null) - return overrideTypeReader; + foreach (var kvp in readers) + if (kvp.Key == typeReaderType) + return kvp.Value; //We dont have a cached type reader, create one - reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); - service.AddOverrideTypeReader(paramType, reader); + TypeReader reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); + reader.IsOverride = true; + service.AddTypeReader(paramType, reader); return reader; } diff --git a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs index 4ad5bfac0..f2a3ee70c 100644 --- a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs @@ -60,7 +60,7 @@ namespace Discord.Commands.Builders if (type.GetTypeInfo().GetCustomAttribute() != null) { IsRemainder = true; - var reader = commands.GetTypeReaders(type)?.FirstOrDefault().Value; + var reader = commands.GetTypeReaders(type, false)?.FirstOrDefault().Value; if (reader == null) { Type readerType; @@ -80,8 +80,7 @@ namespace Discord.Commands.Builders return reader; } - - var readers = commands.GetTypeReaders(type); + var readers = commands.GetTypeReaders(type, false); if (readers != null) return readers.FirstOrDefault().Value; else diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 80f9e5ce3..e61f14a5b 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -50,7 +50,6 @@ namespace Discord.Commands private readonly ConcurrentDictionary> _typeReaders; private readonly ConcurrentDictionary> _userEntityTypeReaders; private readonly ConcurrentDictionary _defaultTypeReaders; - private readonly ConcurrentDictionary _overrideTypeReaders; private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; private readonly HashSet _moduleDefs; private readonly CommandMap _map; @@ -121,7 +120,6 @@ namespace Discord.Commands _map = new CommandMap(this); _typeReaders = new ConcurrentDictionary>(); _userEntityTypeReaders = new ConcurrentDictionary>(); - _overrideTypeReaders = new ConcurrentDictionary(); _defaultTypeReaders = new ConcurrentDictionary(); foreach (var type in PrimitiveParsers.SupportedTypes) @@ -449,20 +447,10 @@ namespace Discord.Commands var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); readers[nullableReader.GetType()] = nullableReader; } - internal void AddOverrideTypeReader(Type valueType, TypeReader valueTypeReader) - { - _overrideTypeReaders[valueType] = valueTypeReader; - } - internal TypeReader GetOverrideTypeReader(Type type) - { - if (_overrideTypeReaders.TryGetValue(type, out var definedTypeReader)) - return definedTypeReader; - return null; - } - internal IDictionary GetTypeReaders(Type type) + internal IEnumerable> GetTypeReaders(Type type, bool includeOverride) { if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) - return definedTypeReaders; + return includeOverride ? definedTypeReaders : definedTypeReaders.Where(x => !x.Value.IsOverride); var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type)); @@ -490,7 +478,7 @@ namespace Discord.Commands var entityTypeReaderType = entityReaders.Value.Value.First(); TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader; AddTypeReader(type, reader); - return GetTypeReaders(type); + return GetTypeReaders(type, false); } return null; } diff --git a/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs b/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs index 0adf61046..b584d29d5 100644 --- a/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs @@ -136,8 +136,8 @@ namespace Discord.Commands var overridden = prop.GetCustomAttribute(); var reader = (overridden != null) ? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services) - : (_commands.GetDefaultTypeReader(elemType) - ?? _commands.GetTypeReaders(elemType).FirstOrDefault().Value); + : (_commands.GetTypeReaders(elemType, false)?.FirstOrDefault().Value + ?? _commands.GetDefaultTypeReader(elemType)); if (reader != null) { diff --git a/src/Discord.Net.Commands/Readers/TypeReader.cs b/src/Discord.Net.Commands/Readers/TypeReader.cs index af780993d..a071d9b53 100644 --- a/src/Discord.Net.Commands/Readers/TypeReader.cs +++ b/src/Discord.Net.Commands/Readers/TypeReader.cs @@ -8,6 +8,7 @@ namespace Discord.Commands /// public abstract class TypeReader { + internal bool IsOverride { get; set; } = false; /// /// Attempts to parse the into the desired type. ///