@@ -11,21 +11,25 @@ namespace Discord.Commands
{
{
public class CommandService
public class CommandService
{
{
private static readonly TypeInfo _moduleTypeInfo = typeof(ModuleBase).GetTypeInfo();
private readonly SemaphoreSlim _moduleLock;
private readonly SemaphoreSlim _moduleLock;
private readonly ConcurrentDictionary<Type, Module> _modules;
private readonly ConcurrentDictionary<Type, ModuleInfo > _moduleDef s;
private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
private readonly CommandMap _map;
private readonly CommandMap _map;
public IEnumerable<Module> Modules => _modules.Select(x => x.Value);
public IEnumerable<Command> Commands => _modules.SelectMany(x => x.Value.Commands);
public IEnumerable<ModuleInfo > Modules => _moduleDef s.Select(x => x.Value);
public IEnumerable<CommandInfo > Commands => _moduleDef s.SelectMany(x => x.Value.Commands);
public CommandService()
public CommandService()
{
{
_moduleLock = new SemaphoreSlim(1, 1);
_moduleLock = new SemaphoreSlim(1, 1);
_modules = new ConcurrentDictionary<Type, Module>();
_moduleDef s = new ConcurrentDictionary<Type, ModuleInfo >();
_map = new CommandMap();
_map = new CommandMap();
_typeReaders = new ConcurrentDictionary<Type, TypeReader>
_typeReaders = new ConcurrentDictionary<Type, TypeReader>
{
{
[typeof(bool)] = new SimpleTypeReader<bool>(),
[typeof(char)] = new SimpleTypeReader<char>(),
[typeof(string)] = new SimpleTypeReader<string>(),
[typeof(string)] = new SimpleTypeReader<string>(),
[typeof(byte)] = new SimpleTypeReader<byte>(),
[typeof(byte)] = new SimpleTypeReader<byte>(),
[typeof(sbyte)] = new SimpleTypeReader<sbyte>(),
[typeof(sbyte)] = new SimpleTypeReader<sbyte>(),
@@ -43,7 +47,6 @@ namespace Discord.Commands
[typeof(IMessage)] = new MessageTypeReader<IMessage>(),
[typeof(IMessage)] = new MessageTypeReader<IMessage>(),
[typeof(IUserMessage)] = new MessageTypeReader<IUserMessage>(),
[typeof(IUserMessage)] = new MessageTypeReader<IUserMessage>(),
//[typeof(ISystemMessage)] = new MessageTypeReader<ISystemMessage>(),
[typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
[typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
[typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
[typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
[typeof(IGroupChannel)] = new ChannelTypeReader<IGroupChannel>(),
[typeof(IGroupChannel)] = new ChannelTypeReader<IGroupChannel>(),
@@ -53,120 +56,99 @@ namespace Discord.Commands
[typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>(),
[typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>(),
[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),
[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),
//[typeof(IGuild)] = new GuildTypeReader<IGuild>(),
[typeof(IRole)] = new RoleTypeReader<IRole>(),
[typeof(IRole)] = new RoleTypeReader<IRole>(),
//[typeof(IInvite)] = new InviteTypeReader<IInvite>(),
//[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(),
[typeof(IUser)] = new UserTypeReader<IUser>(),
[typeof(IUser)] = new UserTypeReader<IUser>(),
[typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
[typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
};
};
}
}
public void AddTypeReader<T>(TypeReader reader)
{
_typeReaders[typeof(T)] = reader;
}
public void AddTypeReader(Type type, TypeReader reader)
{
_typeReaders[type] = reader;
}
internal TypeReader GetTypeReader(Type type)
{
TypeReader reader;
if (_typeReaders.TryGetValue(type, out reader))
return reader;
return null;
}
public async Task<Module> Load(object moduleInstance)
//Modules
public async Task<ModuleInfo> AddModule<T>(IDependencyMap dependencyMap = null)
{
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
try
{
{
if (_modules.ContainsKey(moduleIns tance.GetT ype()))
throw new ArgumentException($"This module has already been lo aded.");
if (_moduleDefs.ContainsKey(typeof(T)))
throw new ArgumentException($"This module has already been added.");
var typeInfo = moduleInstance.GetType().GetTypeInfo();
var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
if (moduleAttr == null)
throw new ArgumentException($"Modules must be marked with ModuleAttribute.");
var typeInfo = typeof(T).GetTypeInfo();
if (!_moduleTypeInfo.IsAssignableFrom(typeInfo))
throw new ArgumentException($"Modules must inherit ModuleBase.");
return LoadInternal(moduleInstance, moduleAttr, typeInfo, null );
return AddModuleInternal(typeInfo, dependencyMap);
}
}
finally
finally
{
{
_moduleLock.Release();
_moduleLock.Release();
}
}
}
}
private Module LoadInternal(object moduleInstance, ModuleAttribute moduleAttr, TypeInfo typeInfo, IDependencyMap dependencyMap)
{
if (_modules.ContainsKey(moduleInstance.GetType()))
return _modules[moduleInstance.GetType()];
var loadedModule = new Module(typeInfo, this, moduleInstance, moduleAttr, dependencyMap);
_modules[moduleInstance.GetType()] = loadedModule;
foreach (var cmd in loadedModule.Commands)
_map.AddCommand(cmd);
return loadedModule;
}
public async Task<IEnumerable<Module>> LoadAssembly(Assembly assembly, IDependencyMap dependencyMap = null)
public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly, IDependencyMap dependencyMap = null)
{
{
var modules = ImmutableArray.CreateBuilder<Module>();
var moduleDefs = ImmutableArray.CreateBuilder<ModuleInfo>();
await _moduleLock.WaitAsync().ConfigureAwait(false);
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
try
{
{
foreach (var type in assembly.ExportedTypes)
foreach (var type in assembly.ExportedTypes)
{
{
var typeInfo = type.GetTypeInfo();
var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
if (moduleAttr != null && moduleAttr.AutoLoad)
if (!_moduleDefs.ContainsKey(type))
{
{
var moduleInstance = ReflectionUtils.CreateObject(typeInfo, this, dependencyMap);
modules.Add(LoadInternal(moduleInstance, moduleAttr, typeInfo, dependencyMap));
var typeInfo = type.GetTypeInfo();
if (_moduleTypeInfo.IsAssignableFrom(typeInfo))
{
var dontAutoLoad = typeInfo.GetCustomAttribute<DontAutoLoadAttribute>();
if (dontAutoLoad == null)
moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap));
}
}
}
}
}
return modules.ToImmutable();
return moduleDef s.ToImmutable();
}
}
finally
finally
{
{
_moduleLock.Release();
_moduleLock.Release();
}
}
}
}
private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap)
{
var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap);
_moduleDefs[typeInfo.BaseType] = moduleDef;
foreach (var cmd in moduleDef.Commands)
_map.AddCommand(cmd);
return moduleDef;
}
public async Task<bool> Unload(Module module)
public async Task<bool> RemoveModule(ModuleInfo module)
{
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
try
{
{
return UnloadInternal(module.Instance);
return RemoveModuleInternal(module.Source.BaseTyp e);
}
}
finally
finally
{
{
_moduleLock.Release();
_moduleLock.Release();
}
}
}
}
public async Task<bool> Unload(object moduleInstance )
public async Task<bool> RemoveModule<T>( )
{
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
try
{
{
return UnloadInternal(moduleInstance );
return RemoveModuleInternal(typeof(T) );
}
}
finally
finally
{
{
_moduleLock.Release();
_moduleLock.Release();
}
}
}
}
private bool UnloadInternal(object modul e)
private bool RemoveModuleInternal(Type typ e)
{
{
Module unloadedModule;
if (_modules.TryRemove(module.Ge tT ype() , out unloadedModule))
ModuleInfo unloadedModule;
if (_moduleDef s.TryRemove(type, out unloadedModule))
{
{
foreach (var cmd in unloadedModule.Commands)
foreach (var cmd in unloadedModule.Commands)
_map.RemoveCommand(cmd);
_map.RemoveCommand(cmd);
@@ -176,8 +158,26 @@ namespace Discord.Commands
return false;
return false;
}
}
public SearchResult Search(IUserMessage message, int argPos) => Search(message, message.Content.Substring(argPos));
public SearchResult Search(IUserMessage message, string input)
//Type Readers
public void AddTypeReader<T>(TypeReader reader)
{
_typeReaders[typeof(T)] = reader;
}
public void AddTypeReader(Type type, TypeReader reader)
{
_typeReaders[type] = reader;
}
internal TypeReader GetTypeReader(Type type)
{
TypeReader reader;
if (_typeReaders.TryGetValue(type, out reader))
return reader;
return null;
}
//Execution
public SearchResult Search(CommandContext context, int argPos) => Search(context, context.Message.Content.Substring(argPos));
public SearchResult Search(CommandContext context, string input)
{
{
string lowerInput = input.ToLowerInvariant();
string lowerInput = input.ToLowerInvariant();
var matches = _map.GetCommands(input).OrderByDescending(x => x.Priority).ToImmutableArray();
var matches = _map.GetCommands(input).OrderByDescending(x => x.Priority).ToImmutableArray();
@@ -188,18 +188,18 @@ namespace Discord.Commands
return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
}
}
public Task<IResult> Execute(IUserMessage message , int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
=> Execute(message, m essage.Content.Substring(argPos), multiMatchHandling);
public async Task<IResult> Execute(IUserMessage message , string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
public Task<IResult> Execute(CommandContext context , int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
=> Execute(context, context.M essage.Content.Substring(argPos), multiMatchHandling);
public async Task<IResult> Execute(CommandContext context , string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
{
{
var searchResult = Search(message , input);
var searchResult = Search(context , input);
if (!searchResult.IsSuccess)
if (!searchResult.IsSuccess)
return searchResult;
return searchResult;
var commands = searchResult.Commands;
var commands = searchResult.Commands;
for (int i = commands.Count - 1; i >= 0; i--)
for (int i = commands.Count - 1; i >= 0; i--)
{
{
var preconditionResult = await commands[i].CheckPreconditions(messag e);
var preconditionResult = await commands[i].CheckPreconditions(context).ConfigureAwait(fals e);
if (!preconditionResult.IsSuccess)
if (!preconditionResult.IsSuccess)
{
{
if (commands.Count == 1)
if (commands.Count == 1)
@@ -208,17 +208,17 @@ namespace Discord.Commands
continue;
continue;
}
}
var parseResult = await commands[i].Parse(message, searchResult, preconditionResult );
var parseResult = await commands[i].Parse(context, searchResult, preconditionResult).ConfigureAwait(false );
if (!parseResult.IsSuccess)
if (!parseResult.IsSuccess)
{
{
if (parseResult.Error == CommandError.MultipleMatches)
if (parseResult.Error == CommandError.MultipleMatches)
{
{
TypeReaderValue[] argList, paramList;
IReadOnlyList<TypeReaderValue> argList, paramList;
switch (multiMatchHandling)
switch (multiMatchHandling)
{
{
case MultiMatchHandling.Best:
case MultiMatchHandling.Best:
argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutable Array();
paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutable Array();
parseResult = ParseResult.FromSuccess(argList, paramList);
parseResult = ParseResult.FromSuccess(argList, paramList);
break;
break;
}
}
@@ -233,7 +233,7 @@ namespace Discord.Commands
}
}
}
}
return await commands[i].Execute(message, parseResult );
return await commands[i].Execute(context, parseResult).ConfigureAwait(false );
}
}
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");