diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index aec8dcbe3..2d909f978 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -275,8 +275,8 @@ namespace Discord.Commands if (builder.TypeReader == null) { - builder.TypeReader = service.GetDefaultTypeReader(paramType) - ?? service.GetTypeReaders(paramType)?.FirstOrDefault().Value; + builder.TypeReader = service.GetTypeReaders(paramType)?.FirstOrDefault().Value + ?? service.GetDefaultTypeReader(paramType); } } @@ -290,9 +290,13 @@ namespace Discord.Commands return reader; } + var overrideTypeReader = service.GetOverrideTypeReader(paramType); + if (overrideTypeReader != null) + return overrideTypeReader; + //We dont have a cached type reader, create one reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); - service.AddTypeReader(paramType, reader, false); + service.AddOverrideTypeReader(paramType, reader); return reader; } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index d5c060fe4..b22337559 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -49,6 +49,7 @@ namespace Discord.Commands private readonly ConcurrentDictionary _typedModuleDefs; private readonly ConcurrentDictionary> _typeReaders; private readonly ConcurrentDictionary _defaultTypeReaders; + private readonly ConcurrentDictionary _overrideTypeReaders; private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; private readonly HashSet _moduleDefs; private readonly CommandMap _map; @@ -109,6 +110,7 @@ namespace Discord.Commands _moduleDefs = new HashSet(); _map = new CommandMap(this); _typeReaders = new ConcurrentDictionary>(); + _overrideTypeReaders = new ConcurrentDictionary(); _defaultTypeReaders = new ConcurrentDictionary(); foreach (var type in PrimitiveParsers.SupportedTypes) @@ -348,10 +350,11 @@ namespace Discord.Commands /// An instance of the to be added. public void AddTypeReader(Type type, TypeReader reader) { - if (_defaultTypeReaders.ContainsKey(type)) - _ = _cmdLogger.WarningAsync($"The default TypeReader for {type.FullName} was replaced by {reader.GetType().FullName}." + - "To suppress this message, use AddTypeReader(reader, true)."); - AddTypeReader(type, reader, true); + var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); + readers[reader.GetType()] = reader; + + if (type.GetTypeInfo().IsValueType) + AddNullableTypeReader(type, reader); } /// /// Adds a custom to this for the supplied object @@ -365,8 +368,9 @@ namespace Discord.Commands /// Defines whether the should replace the default one for /// if it exists. /// + [Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")] public void AddTypeReader(TypeReader reader, bool replaceDefault) - => AddTypeReader(typeof(T), reader, replaceDefault); + => AddTypeReader(typeof(T), reader); /// /// Adds a custom to this for the supplied object /// type. @@ -379,27 +383,10 @@ namespace Discord.Commands /// Defines whether the should replace the default one for if /// it exists. /// + [Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")] public void AddTypeReader(Type type, TypeReader reader, bool replaceDefault) - { - if (replaceDefault && HasDefaultTypeReader(type)) - { - _defaultTypeReaders.AddOrUpdate(type, reader, (k, v) => reader); - if (type.GetTypeInfo().IsValueType) - { - var nullableType = typeof(Nullable<>).MakeGenericType(type); - var nullableReader = NullableTypeReader.Create(type, reader); - _defaultTypeReaders.AddOrUpdate(nullableType, nullableReader, (k, v) => nullableReader); - } - } - else - { - var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); - readers[reader.GetType()] = reader; + => AddTypeReader(type, reader); - if (type.GetTypeInfo().IsValueType) - AddNullableTypeReader(type, reader); - } - } internal bool HasDefaultTypeReader(Type type) { if (_defaultTypeReaders.ContainsKey(type)) @@ -408,7 +395,7 @@ namespace Discord.Commands var typeInfo = type.GetTypeInfo(); if (typeInfo.IsEnum) return true; - return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.TypeReaderType)); + return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.EntityType)); } internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader) { @@ -416,11 +403,43 @@ namespace Discord.Commands var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); readers[nullableReader.GetType()] = nullableReader; } + internal void AddOverrideTypeReader(Type valueType, TypeReader valueTypeReader) + { + _overrideTypeReaders[valueType] = valueTypeReader; + } + internal TypeReader GetOverrideTypeReader(Type type) + { + if (_overrideTypeReaders.TryGetValue(type, out var definedTypeReader)) + return definedTypeReader; + return null; + } internal IDictionary GetTypeReaders(Type type) { if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) return definedTypeReaders; - return null; + + var entityReaders = _typeReaders.Where(x => x.Key.IsAssignableFrom(type)); + + int assignableTo = -1; + KeyValuePair>? typeReader = null; + foreach (var entityReader in entityReaders) + { + int assignables = entityReaders.Sum(x => !x.Equals(entityReader) && x.Key.IsAssignableFrom(entityReader.Key) ? 1 : 0); + if (assignableTo == -1) + { + // First time + assignableTo = assignables; + typeReader = entityReader; + } + // Try to get the "higher" interface. IMessageChannel is assignable to IChannel, but not the inverse + else if (assignables > assignableTo) + { + assignableTo = assignables; + typeReader = entityReader; + } + } + + return typeReader?.Value; } internal TypeReader GetDefaultTypeReader(Type type) { @@ -511,7 +530,7 @@ namespace Discord.Commands await _commandExecutedEvent.InvokeAsync(Optional.Create(), context, searchResult).ConfigureAwait(false); return searchResult; } - + var commands = searchResult.Commands; var preconditionResults = new Dictionary();