You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CommandService.cs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Collections.Immutable;
  5. using System.Globalization;
  6. using System.Linq;
  7. using System.Reflection;
  8. using System.Threading;
  9. using System.Threading.Tasks;
  namespace Discord.Commands
  {
      /// <summary>
      /// Loads command modules, keeps the registry of <see cref="TypeReader"/>s used to turn raw
      /// argument text into typed values, and searches for / executes commands from message input.
      /// </summary>
      public class CommandService
      {
          // Shape of the primitive "TryParse(string, out T)" overloads (byte.TryParse, int.TryParse, ...).
          private delegate bool TryParseDelegate<T>(string input, out T result);

          // Serializes Load / LoadAssembly / Unload so that _modules and _map are updated together.
          private readonly SemaphoreSlim _moduleLock;
          // Loaded modules, keyed by the module instance they were created from.
          private readonly ConcurrentDictionary<object, Module> _modules;
          // Parameter type -> reader that converts argument text into that type.
          private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
          private readonly CommandMap _map;

          /// <summary>All currently loaded modules.</summary>
          public IEnumerable<Module> Modules => _modules.Select(x => x.Value);

          /// <summary>Every command belonging to every currently loaded module.</summary>
          public IEnumerable<Command> Commands => _modules.SelectMany(x => x.Value.Commands);

          /// <summary>
          /// Creates an empty service with built-in type readers for <see cref="string"/>, the
          /// numeric primitives, <see cref="DateTime"/>, <see cref="DateTimeOffset"/> and the
          /// Discord message/channel/role/user interfaces.
          /// </summary>
          public CommandService()
          {
              _moduleLock = new SemaphoreSlim(1, 1);
              _modules = new ConcurrentDictionary<object, Module>();
              _map = new CommandMap();
              _typeReaders = new ConcurrentDictionary<Type, TypeReader>
              {
                  [typeof(string)] = new GenericTypeReader((m, s) => Task.FromResult(TypeReaderResult.FromSuccess(s))),
                  [typeof(byte)] = CreatePrimitiveReader<byte>(byte.TryParse),
                  [typeof(sbyte)] = CreatePrimitiveReader<sbyte>(sbyte.TryParse),
                  [typeof(ushort)] = CreatePrimitiveReader<ushort>(ushort.TryParse),
                  [typeof(short)] = CreatePrimitiveReader<short>(short.TryParse),
                  [typeof(uint)] = CreatePrimitiveReader<uint>(uint.TryParse),
                  [typeof(int)] = CreatePrimitiveReader<int>(int.TryParse),
                  [typeof(ulong)] = CreatePrimitiveReader<ulong>(ulong.TryParse),
                  [typeof(long)] = CreatePrimitiveReader<long>(long.TryParse),
                  [typeof(float)] = CreatePrimitiveReader<float>(float.TryParse),
                  [typeof(double)] = CreatePrimitiveReader<double>(double.TryParse),
                  [typeof(decimal)] = CreatePrimitiveReader<decimal>(decimal.TryParse),
                  [typeof(DateTime)] = CreatePrimitiveReader<DateTime>(DateTime.TryParse),
                  [typeof(DateTimeOffset)] = CreatePrimitiveReader<DateTimeOffset>(DateTimeOffset.TryParse),
                  [typeof(IMessage)] = new MessageTypeReader(),
                  [typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
                  [typeof(IGuildChannel)] = new ChannelTypeReader<IGuildChannel>(),
                  [typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>(),
                  [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),
                  [typeof(IRole)] = new RoleTypeReader(),
                  [typeof(IUser)] = new UserTypeReader<IUser>(),
                  [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>()
              };
          }

          /// <summary>
          /// Wraps a primitive TryParse overload in a <see cref="GenericTypeReader"/>.
          /// On failure the reader reports <see cref="CommandError.ParseFailed"/> with the message
          /// "Failed to parse {CLR type name}" (e.g. "Failed to parse Int32").
          /// </summary>
          /// <remarks>
          /// The TryParse(string, out T) overloads parse using the current culture.
          /// </remarks>
          private static TypeReader CreatePrimitiveReader<T>(TryParseDelegate<T> tryParse)
          {
              string errorMessage = "Failed to parse " + typeof(T).Name;
              return new GenericTypeReader((m, s) =>
              {
                  T value;
                  if (tryParse(s, out value))
                      return Task.FromResult(TypeReaderResult.FromSuccess(value));
                  return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, errorMessage));
              });
          }

          /// <summary>Registers (or replaces) the reader used for parameters of type <typeparamref name="T"/>.</summary>
          public void AddTypeReader<T>(TypeReader reader)
          {
              _typeReaders[typeof(T)] = reader;
          }

          /// <summary>Registers (or replaces) the reader used for parameters of type <paramref name="type"/>.</summary>
          public void AddTypeReader(Type type, TypeReader reader)
          {
              _typeReaders[type] = reader;
          }

          /// <summary>Returns the reader registered for <paramref name="type"/>, or null if none is registered.</summary>
          internal TypeReader GetTypeReader(Type type)
          {
              TypeReader reader;
              if (_typeReaders.TryGetValue(type, out reader))
                  return reader;
              return null;
          }

          /// <summary>
          /// Loads a single module instance and adds its commands to the command map.
          /// </summary>
          /// <param name="moduleInstance">An instance of a class marked with <see cref="ModuleAttribute"/>.</param>
          /// <returns>The loaded <see cref="Module"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="moduleInstance"/> is null.</exception>
          /// <exception cref="ArgumentException">
          /// The instance is already loaded, or its type is not marked with <see cref="ModuleAttribute"/>.
          /// </exception>
          public async Task<Module> Load(object moduleInstance)
          {
              if (moduleInstance == null)
                  throw new ArgumentNullException(nameof(moduleInstance));

              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  if (_modules.ContainsKey(moduleInstance))
                      throw new ArgumentException("This module has already been loaded.");

                  var typeInfo = moduleInstance.GetType().GetTypeInfo();
                  var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
                  // Only types carrying ModuleAttribute may be loaded; LoadInternal receives the attribute.
                  if (moduleAttr == null)
                      throw new ArgumentException("Modules must be marked with ModuleAttribute.");

                  return LoadInternal(moduleInstance, moduleAttr, typeInfo);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          // Builds the Module, records it under its instance, and adds each of its commands to the map.
          // Callers must hold _moduleLock.
          private Module LoadInternal(object moduleInstance, ModuleAttribute moduleAttr, TypeInfo typeInfo)
          {
              var loadedModule = new Module(this, moduleInstance, moduleAttr, typeInfo);
              _modules[moduleInstance] = loadedModule;
              foreach (var cmd in loadedModule.Commands)
                  _map.AddCommand(cmd);
              return loadedModule;
          }

          /// <summary>
          /// Creates an instance of every exported type in <paramref name="assembly"/> marked with
          /// <see cref="ModuleAttribute"/> (via ReflectionUtils.CreateObject) and loads each one.
          /// </summary>
          /// <returns>The modules loaded by this call.</returns>
          public async Task<IEnumerable<Module>> LoadAssembly(Assembly assembly)
          {
              var modules = ImmutableArray.CreateBuilder<Module>();
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  foreach (var type in assembly.ExportedTypes)
                  {
                      var typeInfo = type.GetTypeInfo();
                      var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
                      if (moduleAttr != null)
                      {
                          var moduleInstance = ReflectionUtils.CreateObject(typeInfo);
                          modules.Add(LoadInternal(moduleInstance, moduleAttr, typeInfo));
                      }
                  }
                  return modules.ToImmutable();
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          /// <summary>Unloads <paramref name="module"/> and removes its commands from the command map.</summary>
          /// <returns>true if the module was loaded and has been removed; otherwise false.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="module"/> is null.</exception>
          public async Task<bool> Unload(Module module)
          {
              if (module == null)
                  throw new ArgumentNullException(nameof(module));

              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  return UnloadInternal(module.Instance);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          /// <summary>Unloads the module created from <paramref name="moduleInstance"/>, if any.</summary>
          /// <returns>true if a module was found and removed; otherwise false.</returns>
          public async Task<bool> Unload(object moduleInstance)
          {
              await _moduleLock.WaitAsync().ConfigureAwait(false);
              try
              {
                  return UnloadInternal(moduleInstance);
              }
              finally
              {
                  _moduleLock.Release();
              }
          }

          // Removes the module keyed by `module` (an instance) and its commands. Callers must hold _moduleLock.
          private bool UnloadInternal(object module)
          {
              Module unloadedModule;
              if (_modules.TryRemove(module, out unloadedModule))
              {
                  foreach (var cmd in unloadedModule.Commands)
                      _map.RemoveCommand(cmd);
                  return true;
              }
              else
                  return false;
          }

          /// <summary>Searches using the message text starting at character index <paramref name="argPos"/>.</summary>
          public SearchResult Search(IMessage message, int argPos) => Search(message, message.RawText.Substring(argPos));

          /// <summary>
          /// Looks up the commands in the command map that match <paramref name="input"/>.
          /// </summary>
          /// <returns>
          /// A successful result carrying the matches, or <see cref="CommandError.UnknownCommand"/> if none match.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
          public SearchResult Search(IMessage message, string input)
          {
              if (input == null)
                  throw new ArgumentNullException(nameof(input));

              // NOTE(review): the lookup uses `input` as-is, with no case normalization here;
              // confirm whether CommandMap is expected to match case-insensitively.
              var matches = _map.GetCommands(input).ToImmutableArray();
              if (matches.Length > 0)
                  return SearchResult.FromSuccess(input, matches);
              else
                  return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
          }

          /// <summary>Executes using the message text starting at character index <paramref name="argPos"/>.</summary>
          public Task<IResult> Execute(IMessage message, int argPos) => Execute(message, message.RawText.Substring(argPos));

          /// <summary>
          /// Finds the commands matching <paramref name="input"/> and tries them from the last match to
          /// the first. The first candidate whose arguments parse successfully is executed, and its
          /// result is returned.
          /// </summary>
          /// <returns>
          /// The failed search result, the executed command's result, or a
          /// <see cref="CommandError.ParseFailed"/> result if no candidate's arguments parse.
          /// </returns>
          public async Task<IResult> Execute(IMessage message, string input)
          {
              var searchResult = Search(message, input);
              if (!searchResult.IsSuccess)
                  return searchResult;

              var commands = searchResult.Commands;
              for (int i = commands.Count - 1; i >= 0; i--)
              {
                  var parseResult = await commands[i].Parse(message, searchResult).ConfigureAwait(false);
                  if (!parseResult.IsSuccess)
                      continue;

                  return await commands[i].Execute(message, parseResult).ConfigureAwait(false);
              }

              return ParseResult.FromError(CommandError.ParseFailed, "This input does not match any overload.");
          }
      }
  }