diff --git a/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs b/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs new file mode 100644 index 000000000..37f685c95 --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/OverrideTypeReaderAttribute.cs @@ -0,0 +1,22 @@ +using System; + +using System.Reflection; + +namespace Discord.Commands +{ + [AttributeUsage(AttributeTargets.Parameter)] + public class OverrideTypeReaderAttribute : Attribute + { + private static readonly TypeInfo _typeReaderTypeInfo = typeof(TypeReader).GetTypeInfo(); + + public Type TypeReader { get; } + + public OverrideTypeReaderAttribute(Type overridenTypeReader) + { + if (!_typeReaderTypeInfo.IsAssignableFrom(overridenTypeReader.GetTypeInfo())) + throw new ArgumentException($"{nameof(overridenTypeReader)} must inherit from {nameof(TypeReader)}"); + + TypeReader = overridenTypeReader; + } + } +} \ No newline at end of file diff --git a/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs new file mode 100644 index 000000000..3bf8d177a --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/ParameterPreconditionAttribute.cs @@ -0,0 +1,11 @@ +using System; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = true, Inherited = true)] + public abstract class ParameterPreconditionAttribute : Attribute + { + public abstract Task CheckPermissions(CommandContext context, ParameterInfo parameter, object value, IDependencyMap map); + } +} \ No newline at end of file diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index aaec43161..e8dc60de8 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -182,6 +182,10 @@ namespace Discord.Commands // TODO: C#7 type switch if (attribute is SummaryAttribute) builder.Summary = (attribute as SummaryAttribute).Text; + else if (attribute is OverrideTypeReaderAttribute) + builder.TypeReader = GetTypeReader(service, paramType, (attribute as OverrideTypeReaderAttribute).TypeReader); + else if (attribute is ParameterPreconditionAttribute) + builder.AddPrecondition(attribute as ParameterPreconditionAttribute); else if (attribute is ParamArrayAttribute) { builder.IsMultiple = true; @@ -196,23 +200,42 @@ namespace Discord.Commands } } - var reader = service.GetTypeReader(paramType); - if (reader == null) + if (builder.TypeReader == null) { - var paramTypeInfo = paramType.GetTypeInfo(); - if (paramTypeInfo.IsEnum) - { - reader = EnumTypeReader.GetReader(paramType); - service.AddTypeReader(paramType, reader); - } - else + var reader = service.GetDefaultTypeReader(paramType); + + if (reader == null) { - throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + var paramTypeInfo = paramType.GetTypeInfo(); + if (paramTypeInfo.IsEnum) + { + reader = EnumTypeReader.GetReader(paramType); + service.AddTypeReader(paramType, reader); + } + else + { + throw new InvalidOperationException($"{paramType.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + } } + + builder.ParameterType = paramType; + builder.TypeReader = reader; } + } + + private static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType) + { + var readers = service.GetTypeReaders(paramType); + TypeReader reader = null; + if (readers != null) + if (readers.TryGetValue(typeReaderType, out reader)) + return reader; + + //could not find any registered type reader: try to create one + reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, DependencyMap.Empty); + service.AddTypeReader(paramType, reader); - builder.ParameterType = paramType; - builder.TypeReader = reader; + return reader; } private static bool IsValidModuleDefinition(TypeInfo typeInfo) diff --git a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs index 801a10080..d4cf598ec 100644 --- a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs @@ -1,10 +1,15 @@ using System; +using System.Linq; using System.Reflection; +using System.Collections.Generic; + namespace Discord.Commands.Builders { public class ParameterBuilder { + private readonly List _preconditions; + public CommandBuilder Command { get; } public string Name { get; internal set; } public Type ParameterType { get; internal set; } @@ -16,16 +21,20 @@ namespace Discord.Commands.Builders public object DefaultValue { get; set; } public string Summary { get; set; } + public IReadOnlyList Preconditions => _preconditions; + //Automatic internal ParameterBuilder(CommandBuilder command) { + _preconditions = new List(); + Command = command; } //User-defined internal ParameterBuilder(CommandBuilder command, string name, Type type) : this(command) { - Preconditions.NotNull(name, nameof(name)); + Discord.Preconditions.NotNull(name, nameof(name)); Name = name; SetType(type); @@ -33,7 +42,11 @@ namespace Discord.Commands.Builders internal void SetType(Type type) { - TypeReader = Command.Module.Service.GetTypeReader(type); + var readers = Command.Module.Service.GetTypeReaders(type); + if (readers == null) + throw new InvalidOperationException($"{type} does not have a TypeReader registered for it"); + + TypeReader = readers.FirstOrDefault().Value; if (type.GetTypeInfo().IsValueType) DefaultValue = Activator.CreateInstance(type); @@ -49,7 +62,7 @@ namespace Discord.Commands.Builders } public ParameterBuilder WithDefault(object defaultValue) { - DefaultValue = defaultValue; + DefaultValue = defaultValue; return this; } public ParameterBuilder WithIsOptional(bool isOptional) @@ -68,6 +81,12 @@ namespace Discord.Commands.Builders return this; } + public ParameterBuilder AddPrecondition(ParameterPreconditionAttribute precondition) + { + _preconditions.Add(precondition); + return this; + } + internal ParameterInfo Build(CommandInfo info) { if (TypeReader == null) diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index d5f4bee7e..9c2440b7c 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -15,7 +15,8 @@ namespace Discord.Commands { private readonly SemaphoreSlim _moduleLock; private readonly ConcurrentDictionary _typedModuleDefs; - private readonly ConcurrentDictionary _typeReaders; + private readonly ConcurrentDictionary> _typeReaders; + private readonly ConcurrentDictionary _defaultTypeReaders; private readonly ConcurrentBag _moduleDefs; private readonly CommandMap _map; @@ -24,6 +25,7 @@ namespace Discord.Commands public IEnumerable Modules => _moduleDefs.Select(x => x); public IEnumerable Commands => _moduleDefs.SelectMany(x => x.Commands); + public ILookup TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new {y.Key, y.Value})).ToLookup(x => x.Key, x => x.Value); public CommandService() : this(new CommandServiceConfig()) { } public CommandService(CommandServiceConfig config) @@ -32,7 +34,9 @@ namespace Discord.Commands _typedModuleDefs = new ConcurrentDictionary(); _moduleDefs = new ConcurrentBag(); _map = new CommandMap(); - _typeReaders = new ConcurrentDictionary + _typeReaders = new ConcurrentDictionary>(); + + _defaultTypeReaders = new ConcurrentDictionary { [typeof(bool)] = new SimpleTypeReader(), [typeof(char)] = new SimpleTypeReader(), @@ -50,7 +54,7 @@ namespace Discord.Commands [typeof(decimal)] = new SimpleTypeReader(), [typeof(DateTime)] = new SimpleTypeReader(), [typeof(DateTimeOffset)] = new SimpleTypeReader(), - [typeof(TimeSpan)] = new SimpleTypeReader(), + [typeof(IMessage)] = new MessageTypeReader(), [typeof(IUserMessage)] = new MessageTypeReader(), [typeof(IChannel)] = new ChannelTypeReader(), @@ -196,16 +200,25 @@ namespace Discord.Commands //Type Readers public void AddTypeReader(TypeReader reader) { - _typeReaders[typeof(T)] = reader; + var readers = _typeReaders.GetOrAdd(typeof(T), x => new ConcurrentDictionary()); + readers[reader.GetType()] = reader; } public void AddTypeReader(Type type, TypeReader reader) { - _typeReaders[type] = reader; + var readers = _typeReaders.GetOrAdd(type, x=> new ConcurrentDictionary()); + readers[reader.GetType()] = reader; + } + internal IDictionary GetTypeReaders(Type type) + { + ConcurrentDictionary definedTypeReaders; + if (_typeReaders.TryGetValue(type, out definedTypeReaders)) + return definedTypeReaders; + return null; } - internal TypeReader GetTypeReader(Type type) + internal TypeReader GetDefaultTypeReader(Type type) { TypeReader reader; - if (_typeReaders.TryGetValue(type, out reader)) + if (_defaultTypeReaders.TryGetValue(type, out reader)) return reader; return null; } diff --git a/src/Discord.Net.Commands/Info/CommandInfo.cs b/src/Discord.Net.Commands/Info/CommandInfo.cs index 571c47e13..a6ac50005 100644 --- a/src/Discord.Net.Commands/Info/CommandInfo.cs +++ b/src/Discord.Net.Commands/Info/CommandInfo.cs @@ -137,7 +137,15 @@ namespace Discord.Commands try { - var args = GenerateArgs(argList, paramList); + object[] args = GenerateArgs(argList, paramList); + + foreach (var parameter in Parameters) + { + var result = await parameter.CheckPreconditionsAsync(context, args, map).ConfigureAwait(false); + if (!result.IsSuccess) + return ExecuteResult.FromError(result); + } + switch (RunMode) { case RunMode.Sync: //Always sync diff --git a/src/Discord.Net.Commands/Info/ParameterInfo.cs b/src/Discord.Net.Commands/Info/ParameterInfo.cs index 18c5e653c..f8a97647a 100644 --- a/src/Discord.Net.Commands/Info/ParameterInfo.cs +++ b/src/Discord.Net.Commands/Info/ParameterInfo.cs @@ -1,5 +1,7 @@ using System; using System.Linq; +using System.Collections.Generic; +using System.Collections.Immutable; using System.Threading.Tasks; using Discord.Commands.Builders; @@ -10,6 +12,17 @@ namespace Discord.Commands { private readonly TypeReader _reader; + public CommandInfo Command { get; } + public string Name { get; } + public string Summary { get; } + public bool IsOptional { get; } + public bool IsRemainder { get; } + public bool IsMultiple { get; } + public Type Type { get; } + public object DefaultValue { get; } + + public IReadOnlyList Preconditions { get; } + internal ParameterInfo(ParameterBuilder builder, CommandInfo command, CommandService service) { Command = command; @@ -23,17 +36,30 @@ namespace Discord.Commands Type = builder.ParameterType; DefaultValue = builder.DefaultValue; + Preconditions = builder.Preconditions.ToImmutableArray(); + _reader = builder.TypeReader; } - public CommandInfo Command { get; } - public string Name { get; } - public string Summary { get; } - public bool IsOptional { get; } - public bool IsRemainder { get; } - public bool IsMultiple { get; } - public Type Type { get; } - public object DefaultValue { get; } + public async Task CheckPreconditionsAsync(CommandContext context, object[] args, IDependencyMap map = null) + { + if (map == null) + map = DependencyMap.Empty; + + int position = 0; + for(position = 0; position < Command.Parameters.Count; position++) + if (Command.Parameters[position] == this) + break; + + foreach (var precondition in Preconditions) + { + var result = await precondition.CheckPermissions(context, this, args[position], map).ConfigureAwait(false); + if (!result.IsSuccess) + return result; + } + + return PreconditionResult.FromSuccess(); + } public async Task Parse(CommandContext context, string input) {