You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CommandService.cs 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. using Discord.Commands.Builders;
  2. using Discord.Logging;
  3. using System;
  4. using System.Collections.Concurrent;
  5. using System.Collections.Generic;
  6. using System.Collections.Immutable;
  7. using System.Linq;
  8. using System.Reflection;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  11. using Microsoft.Extensions.DependencyInjection;
  namespace Discord.Commands
  {
      /// <summary>
      /// Registers command modules and type readers, and searches for and executes
      /// commands against message text.
      /// </summary>
      public class CommandService
      {
          /// <summary>Raised for log messages produced by this service.</summary>
          public event Func<LogMessage, Task> Log { add { _logEvent.Add(value); } remove { _logEvent.Remove(value); } }
          internal readonly AsyncEvent<Func<LogMessage, Task>> _logEvent = new AsyncEvent<Func<LogMessage, Task>>();

          // Serializes module add/remove so _moduleDefs, _typedModuleDefs and _map are updated together.
          private readonly SemaphoreSlim _moduleLock;
          // Top-level modules registered by CLR type via AddModuleAsync / AddModulesAsync.
          private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
          // Target type -> (reader runtime type -> reader instance), populated by AddTypeReader.
          private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
          // Target type -> built-in reader. Pre-filled with primitive and string readers;
          // enum and entity readers are cached here on first lookup in GetDefaultTypeReader.
          private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
          // (entity interface, open generic reader type) pairs consulted by GetDefaultTypeReader.
          private readonly ImmutableList<Tuple<Type, Type>> _entityTypeReaders; //TODO: Candidate for C#7 Tuple
          // Every loaded module, including submodules (LoadModuleInternal recurses).
          private readonly HashSet<ModuleInfo> _moduleDefs;
          private readonly CommandMap _map;

          internal readonly bool _caseSensitive, _throwOnError;
          internal readonly char _separatorChar;
          internal readonly RunMode _defaultRunMode;
          internal readonly Logger _cmdLogger;
          internal readonly LogManager _logManager;

          /// <summary>Gets all loaded modules, including submodules.</summary>
          /// <remarks>
          /// NOTE(review): this lazily enumerates the live set without taking the module lock;
          /// enumerating while modules are being added/removed may fail — confirm whether callers need a snapshot.
          /// </remarks>
          public IEnumerable<ModuleInfo> Modules => _moduleDefs.Select(x => x);

          /// <summary>Gets every command of every loaded module.</summary>
          public IEnumerable<CommandInfo> Commands => _moduleDefs.SelectMany(x => x.Commands);

          /// <summary>Gets the type readers registered through <see cref="AddTypeReader(Type, TypeReader)"/>.</summary>
          /// <remarks>
          /// NOTE(review): the lookup key is each reader's own runtime type (the inner dictionary key),
          /// not the type the reader parses — confirm this is intended before relying on it.
          /// </remarks>
          public ILookup<Type, TypeReader> TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new { y.Key, y.Value })).ToLookup(x => x.Key, x => x.Value);

          /// <summary>Creates a service using a default <see cref="CommandServiceConfig"/>.</summary>
          public CommandService() : this(new CommandServiceConfig()) { }

          /// <summary>Creates a service using the given configuration.</summary>
          /// <param name="config">Service configuration.</param>
          /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
          /// <exception cref="InvalidOperationException">The configured default run mode is <c>RunMode.Default</c>.</exception>
          public CommandService(CommandServiceConfig config)
          {
              if (config == null)
                  throw new ArgumentNullException(nameof(config));

              _caseSensitive = config.CaseSensitiveCommands;
              _throwOnError = config.ThrowOnError;
              _separatorChar = config.SeparatorChar;
              _defaultRunMode = config.DefaultRunMode;
              // "Default" means "defer to the service default", so it cannot itself be the service default.
              if (_defaultRunMode == RunMode.Default)
                  throw new InvalidOperationException("The default run mode cannot be set to Default.");

              _logManager = new LogManager(config.LogLevel);
              _logManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);
              _cmdLogger = _logManager.CreateLogger("Command");

              _moduleLock = new SemaphoreSlim(1, 1);
              _typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
              _moduleDefs = new HashSet<ModuleInfo>();
              _map = new CommandMap(this);
              _typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>();

              _defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>();
              foreach (var type in PrimitiveParsers.SupportedTypes)
                  _defaultTypeReaders[type] = PrimitiveTypeReader.Create(type);
              // Strings need no parsing: the raw input is the value.
              _defaultTypeReaders[typeof(string)] = new PrimitiveTypeReader<string>(0,
                  (string x, out string y) =>
                  {
                      y = x;
                      return true;
                  });

              var entityTypeReaders = ImmutableList.CreateBuilder<Tuple<Type, Type>>();
              entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IMessage), typeof(MessageTypeReader<>)));
              entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IChannel), typeof(ChannelTypeReader<>)));
              entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IRole), typeof(RoleTypeReader<>)));
              entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IUser), typeof(UserTypeReader<>)));
              _entityTypeReaders = entityTypeReaders.ToImmutable();
          }

          //Modules

          /// <summary>Builds a module from a builder callback and loads it.</summary>
          /// <param name="primaryAlias">Primary alias passed to the new <see cref="ModuleBuilder"/>.</param>
          /// <param name="buildFunc">Callback that configures the builder.</param>
          /// <returns>The loaded module.</returns>
          public async Task<ModuleInfo> CreateModuleAsync(string primaryAlias, Action<ModuleBuilder> buildFunc)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  var builder = new ModuleBuilder(this, null, primaryAlias);
                  buildFunc(builder);

                  var module = builder.Build(this);
                  return LoadModuleInternal(module);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          /// <summary>Builds and loads the module class <typeparamref name="T"/>.</summary>
          public Task<ModuleInfo> AddModuleAsync<T>() => AddModuleAsync(typeof(T));

          /// <summary>Builds and loads the module class <paramref name="type"/>.</summary>
          /// <returns>The loaded module.</returns>
          /// <exception cref="ArgumentException">A module of this type has already been added.</exception>
          /// <exception cref="InvalidOperationException">No module could be built from <paramref name="type"/>.</exception>
          public async Task<ModuleInfo> AddModuleAsync(Type type)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  if (_typedModuleDefs.ContainsKey(type))
                      throw new ArgumentException("This module has already been added.");

                  var typeInfo = type.GetTypeInfo();
                  var module = (await ModuleClassBuilder.Build(this, typeInfo).ConfigureAwait(false))
                      .FirstOrDefault();

                  if (module.Value == null)
                      throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?");

                  _typedModuleDefs[module.Key] = module.Value;
                  return LoadModuleInternal(module.Value);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          /// <summary>Discovers, builds and loads every module class found in <paramref name="assembly"/>.</summary>
          /// <returns>The loaded top-level modules.</returns>
          public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  var types = await ModuleClassBuilder.Search(assembly, this).ConfigureAwait(false);
                  var moduleDefs = await ModuleClassBuilder.Build(types, this).ConfigureAwait(false);

                  foreach (var info in moduleDefs)
                  {
                      // NOTE(review): unlike AddModuleAsync, an already-registered type is overwritten here
                      // and its previous ModuleInfo stays loaded — confirm Search/Build exclude known types.
                      _typedModuleDefs[info.Key] = info.Value;
                      LoadModuleInternal(info.Value);
                  }

                  return moduleDefs.Select(x => x.Value).ToImmutableArray();
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          // Adds the module, its commands and (recursively) its submodules. All callers hold _moduleLock.
          private ModuleInfo LoadModuleInternal(ModuleInfo module)
          {
              _moduleDefs.Add(module);

              foreach (var command in module.Commands)
                  _map.AddCommand(command);

              foreach (var submodule in module.Submodules)
                  LoadModuleInternal(submodule);

              return module;
          }

          /// <summary>Unloads <paramref name="module"/>, its commands and its submodules.</summary>
          /// <returns><c>true</c> if the module was loaded and has been removed; otherwise <c>false</c>.</returns>
          public async Task<bool> RemoveModuleAsync(ModuleInfo module)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  // Also drop the by-type registration; otherwise the type could never be re-added
                  // (AddModuleAsync would report it as already added) nor removed by type.
                  var typedEntry = _typedModuleDefs.FirstOrDefault(x => EqualityComparer<ModuleInfo>.Default.Equals(x.Value, module));
                  if (typedEntry.Key != null)
                      _typedModuleDefs.TryRemove(typedEntry.Key, out _);

                  return RemoveModuleInternal(module);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          /// <summary>Unloads the module that was added for type <typeparamref name="T"/>.</summary>
          public Task<bool> RemoveModuleAsync<T>() => RemoveModuleAsync(typeof(T));

          /// <summary>Unloads the module that was added for <paramref name="type"/>.</summary>
          /// <returns><c>true</c> if such a module was registered and has been removed; otherwise <c>false</c>.</returns>
          public async Task<bool> RemoveModuleAsync(Type type)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  if (!_typedModuleDefs.TryRemove(type, out var module))
                      return false;

                  return RemoveModuleInternal(module);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          // Removes the module, its commands and (recursively) its submodules. All callers hold _moduleLock.
          private bool RemoveModuleInternal(ModuleInfo module)
          {
              if (!_moduleDefs.Remove(module))
                  return false;

              foreach (var cmd in module.Commands)
                  _map.RemoveCommand(cmd);

              foreach (var submodule in module.Submodules)
                  RemoveModuleInternal(submodule);

              return true;
          }

          //Type Readers

          /// <summary>Registers <paramref name="reader"/> for parameters of type <typeparamref name="T"/>.</summary>
          public void AddTypeReader<T>(TypeReader reader) => AddTypeReader(typeof(T), reader);

          /// <summary>
          /// Registers <paramref name="reader"/> for parameters of type <paramref name="type"/>.
          /// A reader of the same runtime type previously registered for <paramref name="type"/> is replaced.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="type"/> or <paramref name="reader"/> is null.</exception>
          public void AddTypeReader(Type type, TypeReader reader)
          {
              if (reader == null)
                  throw new ArgumentNullException(nameof(reader));

              var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
              readers[reader.GetType()] = reader;
          }

          // Returns the user-registered readers for the type (keyed by reader runtime type), or null if none.
          internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
          {
              return _typeReaders.TryGetValue(type, out var definedTypeReaders) ? definedTypeReaders : null;
          }

          // Returns the built-in reader for the type, creating and caching enum/entity readers on demand;
          // returns null when no built-in reader applies.
          internal TypeReader GetDefaultTypeReader(Type type)
          {
              if (_defaultTypeReaders.TryGetValue(type, out var reader))
                  return reader;
              var typeInfo = type.GetTypeInfo();

              //Is this an enum?
              if (typeInfo.IsEnum)
              {
                  reader = EnumTypeReader.GetReader(type);
                  _defaultTypeReaders[type] = reader;
                  return reader;
              }

              //Is this an entity? (the entity interface itself, or a type implementing it)
              for (int i = 0; i < _entityTypeReaders.Count; i++)
              {
                  if (type == _entityTypeReaders[i].Item1 || typeInfo.ImplementedInterfaces.Contains(_entityTypeReaders[i].Item1))
                  {
                      // Close the open generic reader (e.g. UserTypeReader<>) over the requested type.
                      reader = Activator.CreateInstance(_entityTypeReaders[i].Item2.MakeGenericType(type)) as TypeReader;
                      _defaultTypeReaders[type] = reader;
                      return reader;
                  }
              }
              return null;
          }

          //Execution

          /// <summary>Searches for commands matching the message content starting at <paramref name="argPos"/>.</summary>
          public SearchResult Search(ICommandContext context, int argPos)
              => Search(context, context.Message.Content.Substring(argPos));

          /// <summary>
          /// Searches for commands matching <paramref name="input"/>. Matching ignores case unless the service
          /// is case-sensitive; matches are ordered by the descending average priority of their overloads.
          /// </summary>
          /// <returns>A success result carrying the original input and matches, or an UnknownCommand error.</returns>
          public SearchResult Search(ICommandContext context, string input)
          {
              string searchInput = _caseSensitive ? input : input.ToLowerInvariant();
              var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Command.Overloads.Average(y => y.Priority)).ToImmutableArray();

              if (matches.Length > 0)
                  return SearchResult.FromSuccess(input, matches);
              else
                  return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
          }

          /// <summary>Executes the command in the message content starting at <paramref name="argPos"/>.</summary>
          public Task<IResult> ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
              => ExecuteAsync(context, context.Message.Content.Substring(argPos), services, multiMatchHandling);

          /// <summary>
          /// Finds the commands matching <paramref name="input"/> and executes the best-scoring overload
          /// whose preconditions pass and whose arguments parse.
          /// </summary>
          /// <param name="context">The command context.</param>
          /// <param name="input">Text to match, starting at the command name.</param>
          /// <param name="services">Service provider for preconditions, parsing and execution; an empty provider is used when null.</param>
          /// <param name="multiMatchHandling">
          /// How ambiguous argument values are handled: with <c>Best</c> the highest-scoring value is taken;
          /// otherwise the ambiguous overload is treated as a parse failure.
          /// </param>
          /// <returns>
          /// The execution result; otherwise the search failure, or the last precondition/parse failure of the
          /// last matched command, or an UnknownCommand error if no overload could be used.
          /// </returns>
          public async Task<IResult> ExecuteAsync(ICommandContext context, string input, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
          {
              services = services ?? EmptyServiceProvider.Instance;

              var searchResult = Search(context, input);
              if (!searchResult.IsSuccess)
                  return searchResult;

              var commands = searchResult.Commands;
              for (int i = 0; i < commands.Count; i++)
              {
                  // Only failures from the last candidate command are surfaced to the caller.
                  bool isLastCommand = i == commands.Count - 1;
                  var command = commands[i].Command;
                  var overloads = command.Overloads.OrderByDescending(x => x.Priority).ToImmutableArray();

                  // Each overload must pass its OWN preconditions to be a candidate, and is parsed with its
                  // own precondition result — an overload that failed its checks must never run.
                  var rawParseResults = new List<ParseResult>(overloads.Length);
                  IResult lastPreconditionFailure = null;
                  foreach (var overload in overloads)
                  {
                      var preconditionResult = await overload.CheckPreconditionsAsync(context, services).ConfigureAwait(false);
                      if (!preconditionResult.IsSuccess)
                      {
                          lastPreconditionFailure = preconditionResult;
                          continue;
                      }

                      rawParseResults.Add(await overload.ParseAsync(context, services, searchResult, preconditionResult).ConfigureAwait(false));
                  }

                  if (rawParseResults.Count == 0)
                  {
                      if (isLastCommand && lastPreconditionFailure != null)
                          return lastPreconditionFailure;
                      continue;
                  }

                  //order by average score (failed parses score 0)
                  var parseResults = rawParseResults.OrderByDescending(
                      x => !x.IsSuccess ? 0 :
                          (x.ArgValues.Count > 0 ? x.ArgValues.Average(y => y.Values.Max(z => z.Score)) : 0) +
                          (x.ParamValues.Count > 0 ? x.ParamValues.Average(y => y.Values.Max(z => z.Score)) : 0))
                      .ToImmutableArray();

                  for (int j = 0; j < parseResults.Length; j++)
                  {
                      var parseResult = parseResults[j];
                      var overload = parseResult.Overload;

                      if (!parseResult.IsSuccess
                          && parseResult.Error == CommandError.MultipleMatches
                          && multiMatchHandling == MultiMatchHandling.Best)
                      {
                          // Collapse every ambiguous value to its highest-scoring candidate.
                          IReadOnlyList<TypeReaderValue> argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                          IReadOnlyList<TypeReaderValue> paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                          parseResult = ParseResult.FromSuccess(overload, argList, paramList);
                      }

                      if (!parseResult.IsSuccess)
                      {
                          // Previously compared against Count/Length (never reachable), so this failure was lost.
                          if (isLastCommand && j == parseResults.Length - 1)
                              return parseResult;
                          continue;
                      }

                      return await overload.ExecuteAsync(context, parseResult, services).ConfigureAwait(false);
                  }
              }

              return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");
          }
      }
  }