You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CommandService.cs 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. using Discord.Commands.Builders;
  2. using Discord.Logging;
  3. using System;
  4. using System.Collections.Concurrent;
  5. using System.Collections.Generic;
  6. using System.Collections.Immutable;
  7. using System.Linq;
  8. using System.Reflection;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  11. namespace Discord.Commands
  12. {
  13. public class CommandService
  14. {
  15. public event Func<LogMessage, Task> Log { add { _logEvent.Add(value); } remove { _logEvent.Remove(value); } }
  16. internal readonly AsyncEvent<Func<LogMessage, Task>> _logEvent = new AsyncEvent<Func<LogMessage, Task>>();
  17. public event Func<CommandInfo, ICommandContext, IResult, Task> CommandExecuted { add { _commandExecutedEvent.Add(value); } remove { _commandExecutedEvent.Remove(value); } }
  18. internal readonly AsyncEvent<Func<CommandInfo, ICommandContext, IResult, Task>> _commandExecutedEvent = new AsyncEvent<Func<CommandInfo, ICommandContext, IResult, Task>>();
  19. private readonly SemaphoreSlim _moduleLock;
  20. private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
  21. private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
  22. private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
  23. private readonly ImmutableList<Tuple<Type, Type>> _entityTypeReaders; //TODO: Candidate for C#7 Tuple
  24. private readonly HashSet<ModuleInfo> _moduleDefs;
  25. private readonly CommandMap _map;
  26. internal readonly bool _caseSensitive, _throwOnError, _ignoreExtraArgs;
  27. internal readonly char _separatorChar;
  28. internal readonly RunMode _defaultRunMode;
  29. internal readonly Logger _cmdLogger;
  30. internal readonly LogManager _logManager;
  31. internal readonly IReadOnlyDictionary<char, char> _quotationMarkAliasMap;
  32. public IEnumerable<ModuleInfo> Modules => _moduleDefs.Select(x => x);
  33. public IEnumerable<CommandInfo> Commands => _moduleDefs.SelectMany(x => x.Commands);
  34. public ILookup<Type, TypeReader> TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new { y.Key, y.Value })).ToLookup(x => x.Key, x => x.Value);
  35. public CommandService() : this(new CommandServiceConfig()) { }
  36. public CommandService(CommandServiceConfig config)
  37. {
  38. _caseSensitive = config.CaseSensitiveCommands;
  39. _throwOnError = config.ThrowOnError;
  40. _ignoreExtraArgs = config.IgnoreExtraArgs;
  41. _separatorChar = config.SeparatorChar;
  42. _defaultRunMode = config.DefaultRunMode;
  43. _quotationMarkAliasMap = config.QuotationMarkAliasMap?.ToImmutableDictionary();
  44. if (_defaultRunMode == RunMode.Default)
  45. throw new InvalidOperationException("The default run mode cannot be set to Default.");
  46. _logManager = new LogManager(config.LogLevel);
  47. _logManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);
  48. _cmdLogger = _logManager.CreateLogger("Command");
  49. _moduleLock = new SemaphoreSlim(1, 1);
  50. _typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
  51. _moduleDefs = new HashSet<ModuleInfo>();
  52. _map = new CommandMap(this);
  53. _typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>();
  54. _defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>();
  55. foreach (var type in PrimitiveParsers.SupportedTypes)
  56. {
  57. _defaultTypeReaders[type] = PrimitiveTypeReader.Create(type);
  58. _defaultTypeReaders[typeof(Nullable<>).MakeGenericType(type)] = NullableTypeReader.Create(type, _defaultTypeReaders[type]);
  59. }
  60. _defaultTypeReaders[typeof(string)] =
  61. new PrimitiveTypeReader<string>((string x, out string y) => { y = x; return true; }, 0);
  62. var entityTypeReaders = ImmutableList.CreateBuilder<Tuple<Type, Type>>();
  63. entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IMessage), typeof(MessageTypeReader<>)));
  64. entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IChannel), typeof(ChannelTypeReader<>)));
  65. entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IRole), typeof(RoleTypeReader<>)));
  66. entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IUser), typeof(UserTypeReader<>)));
  67. _entityTypeReaders = entityTypeReaders.ToImmutable();
  68. }
  69. //Modules
  70. public async Task<ModuleInfo> CreateModuleAsync(string primaryAlias, Action<ModuleBuilder> buildFunc)
  71. {
  72. await _moduleLock.WaitAsync().ConfigureAwait(false);
  73. try
  74. {
  75. var builder = new ModuleBuilder(this, null, primaryAlias);
  76. buildFunc(builder);
  77. var module = builder.Build(this);
  78. return LoadModuleInternal(module);
  79. }
  80. finally
  81. {
  82. _moduleLock.Release();
  83. }
  84. }
  85. public Task<ModuleInfo> AddModuleAsync<T>() => AddModuleAsync(typeof(T));
  86. public async Task<ModuleInfo> AddModuleAsync(Type type)
  87. {
  88. await _moduleLock.WaitAsync().ConfigureAwait(false);
  89. try
  90. {
  91. var typeInfo = type.GetTypeInfo();
  92. if (_typedModuleDefs.ContainsKey(type))
  93. throw new ArgumentException($"This module has already been added.");
  94. var module = (await ModuleClassBuilder.BuildAsync(this, typeInfo).ConfigureAwait(false)).FirstOrDefault();
  95. if (module.Value == default(ModuleInfo))
  96. throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?");
  97. _typedModuleDefs[module.Key] = module.Value;
  98. return LoadModuleInternal(module.Value);
  99. }
  100. finally
  101. {
  102. _moduleLock.Release();
  103. }
  104. }
  105. public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly)
  106. {
  107. await _moduleLock.WaitAsync().ConfigureAwait(false);
  108. try
  109. {
  110. var types = await ModuleClassBuilder.SearchAsync(assembly, this).ConfigureAwait(false);
  111. var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this).ConfigureAwait(false);
  112. foreach (var info in moduleDefs)
  113. {
  114. _typedModuleDefs[info.Key] = info.Value;
  115. LoadModuleInternal(info.Value);
  116. }
  117. return moduleDefs.Select(x => x.Value).ToImmutableArray();
  118. }
  119. finally
  120. {
  121. _moduleLock.Release();
  122. }
  123. }
  124. private ModuleInfo LoadModuleInternal(ModuleInfo module)
  125. {
  126. _moduleDefs.Add(module);
  127. foreach (var command in module.Commands)
  128. _map.AddCommand(command);
  129. foreach (var submodule in module.Submodules)
  130. LoadModuleInternal(submodule);
  131. return module;
  132. }
  133. public async Task<bool> RemoveModuleAsync(ModuleInfo module)
  134. {
  135. await _moduleLock.WaitAsync().ConfigureAwait(false);
  136. try
  137. {
  138. return RemoveModuleInternal(module);
  139. }
  140. finally
  141. {
  142. _moduleLock.Release();
  143. }
  144. }
  145. public Task<bool> RemoveModuleAsync<T>() => RemoveModuleAsync(typeof(T));
  146. public async Task<bool> RemoveModuleAsync(Type type)
  147. {
  148. await _moduleLock.WaitAsync().ConfigureAwait(false);
  149. try
  150. {
  151. if (!_typedModuleDefs.TryRemove(type, out var module))
  152. return false;
  153. return RemoveModuleInternal(module);
  154. }
  155. finally
  156. {
  157. _moduleLock.Release();
  158. }
  159. }
  160. private bool RemoveModuleInternal(ModuleInfo module)
  161. {
  162. if (!_moduleDefs.Remove(module))
  163. return false;
  164. foreach (var cmd in module.Commands)
  165. _map.RemoveCommand(cmd);
  166. foreach (var submodule in module.Submodules)
  167. {
  168. RemoveModuleInternal(submodule);
  169. }
  170. return true;
  171. }
  172. //Type Readers
  173. /// <summary>
  174. /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
  175. /// If <typeparamref name="T"/> is a <see cref="ValueType"/>, a <see cref="NullableTypeReader{T}"/> will also be added.
  176. /// </summary>
  177. /// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
  178. /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
  179. public void AddTypeReader<T>(TypeReader reader)
  180. => AddTypeReader(typeof(T), reader);
  181. /// <summary>
  182. /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
  183. /// If <paramref name="type"/> is a <see cref="ValueType"/>, a <see cref="NullableTypeReader{T}"/> for the value type will also be added.
  184. /// </summary>
  185. /// <param name="type">A <see cref="Type"/> instance for the type to be read.</param>
  186. /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
  187. public void AddTypeReader(Type type, TypeReader reader)
  188. {
  189. var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
  190. readers[reader.GetType()] = reader;
  191. if (type.GetTypeInfo().IsValueType)
  192. AddNullableTypeReader(type, reader);
  193. }
  194. internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader)
  195. {
  196. var readers = _typeReaders.GetOrAdd(typeof(Nullable<>).MakeGenericType(valueType), x => new ConcurrentDictionary<Type, TypeReader>());
  197. var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader);
  198. readers[nullableReader.GetType()] = nullableReader;
  199. }
  200. internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
  201. {
  202. if (_typeReaders.TryGetValue(type, out var definedTypeReaders))
  203. return definedTypeReaders;
  204. return null;
  205. }
  206. internal TypeReader GetDefaultTypeReader(Type type)
  207. {
  208. if (_defaultTypeReaders.TryGetValue(type, out var reader))
  209. return reader;
  210. var typeInfo = type.GetTypeInfo();
  211. //Is this an enum?
  212. if (typeInfo.IsEnum)
  213. {
  214. reader = EnumTypeReader.GetReader(type);
  215. _defaultTypeReaders[type] = reader;
  216. return reader;
  217. }
  218. //Is this an entity?
  219. for (int i = 0; i < _entityTypeReaders.Count; i++)
  220. {
  221. if (type == _entityTypeReaders[i].Item1 || typeInfo.ImplementedInterfaces.Contains(_entityTypeReaders[i].Item1))
  222. {
  223. reader = Activator.CreateInstance(_entityTypeReaders[i].Item2.MakeGenericType(type)) as TypeReader;
  224. _defaultTypeReaders[type] = reader;
  225. return reader;
  226. }
  227. }
  228. return null;
  229. }
  230. //Execution
  231. public SearchResult Search(ICommandContext context, int argPos)
  232. => Search(context, context.Message.Content.Substring(argPos));
  233. public SearchResult Search(ICommandContext context, string input)
  234. {
  235. string searchInput = _caseSensitive ? input : input.ToLowerInvariant();
  236. var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Command.Priority).ToImmutableArray();
  237. if (matches.Length > 0)
  238. return SearchResult.FromSuccess(input, matches);
  239. else
  240. return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
  241. }
  242. public Task<IResult> ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
  243. => ExecuteAsync(context, context.Message.Content.Substring(argPos), services, multiMatchHandling);
  244. public async Task<IResult> ExecuteAsync(ICommandContext context, string input, IServiceProvider services = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
  245. {
  246. services = services ?? EmptyServiceProvider.Instance;
  247. var searchResult = Search(context, input);
  248. if (!searchResult.IsSuccess)
  249. return searchResult;
  250. var commands = searchResult.Commands;
  251. var preconditionResults = new Dictionary<CommandMatch, PreconditionResult>();
  252. foreach (var match in commands)
  253. {
  254. preconditionResults[match] = await match.Command.CheckPreconditionsAsync(context, services).ConfigureAwait(false);
  255. }
  256. var successfulPreconditions = preconditionResults
  257. .Where(x => x.Value.IsSuccess)
  258. .ToArray();
  259. if (successfulPreconditions.Length == 0)
  260. {
  261. //All preconditions failed, return the one from the highest priority command
  262. var bestCandidate = preconditionResults
  263. .OrderByDescending(x => x.Key.Command.Priority)
  264. .FirstOrDefault(x => !x.Value.IsSuccess);
  265. return bestCandidate.Value;
  266. }
  267. //If we get this far, at least one precondition was successful.
  268. var parseResultsDict = new Dictionary<CommandMatch, ParseResult>();
  269. foreach (var pair in successfulPreconditions)
  270. {
  271. var parseResult = await pair.Key.ParseAsync(context, searchResult, pair.Value, services).ConfigureAwait(false);
  272. if (parseResult.Error == CommandError.MultipleMatches)
  273. {
  274. IReadOnlyList<TypeReaderValue> argList, paramList;
  275. switch (multiMatchHandling)
  276. {
  277. case MultiMatchHandling.Best:
  278. argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
  279. paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
  280. parseResult = ParseResult.FromSuccess(argList, paramList);
  281. break;
  282. }
  283. }
  284. parseResultsDict[pair.Key] = parseResult;
  285. }
  286. // Calculates the 'score' of a command given a parse result
  287. float CalculateScore(CommandMatch match, ParseResult parseResult)
  288. {
  289. float argValuesScore = 0, paramValuesScore = 0;
  290. if (match.Command.Parameters.Count > 0)
  291. {
  292. var argValuesSum = parseResult.ArgValues?.Sum(x => x.Values.OrderByDescending(y => y.Score).FirstOrDefault().Score) ?? 0;
  293. var paramValuesSum = parseResult.ParamValues?.Sum(x => x.Values.OrderByDescending(y => y.Score).FirstOrDefault().Score) ?? 0;
  294. argValuesScore = argValuesSum / match.Command.Parameters.Count;
  295. paramValuesScore = paramValuesSum / match.Command.Parameters.Count;
  296. }
  297. var totalArgsScore = (argValuesScore + paramValuesScore) / 2;
  298. return match.Command.Priority + totalArgsScore * 0.99f;
  299. }
  300. //Order the parse results by their score so that we choose the most likely result to execute
  301. var parseResults = parseResultsDict
  302. .OrderByDescending(x => CalculateScore(x.Key, x.Value));
  303. var successfulParses = parseResults
  304. .Where(x => x.Value.IsSuccess)
  305. .ToArray();
  306. if (successfulParses.Length == 0)
  307. {
  308. //All parses failed, return the one from the highest priority command, using score as a tie breaker
  309. var bestMatch = parseResults
  310. .FirstOrDefault(x => !x.Value.IsSuccess);
  311. return bestMatch.Value;
  312. }
  313. //If we get this far, at least one parse was successful. Execute the most likely overload.
  314. var chosenOverload = successfulParses[0];
  315. return await chosenOverload.Key.ExecuteAsync(context, chosenOverload.Value, services).ConfigureAwait(false);
  316. }
  317. }
  318. }