using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Commands
{
    /// <summary>
    /// Loads command modules, keeps the registry of type readers used to convert
    /// argument text into typed values, and searches/executes commands from message input.
    /// </summary>
    public class CommandService
    {
        // Shape of the static TryParse(string, out T) methods on the primitive types.
        private delegate bool TryParseFunc<T>(string input, out T value);

        // Held by Load/LoadAssembly/Unload so that _modules and _map are always
        // updated together.
        private readonly SemaphoreSlim _moduleLock;
        // Loaded modules, keyed by the module instance object.
        private readonly ConcurrentDictionary<object, Module> _modules;
        private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
        private readonly CommandMap _map;

        /// <summary>All currently loaded modules.</summary>
        public IEnumerable<Module> Modules => _modules.Select(x => x.Value);

        /// <summary>Every command exposed by the currently loaded modules.</summary>
        public IEnumerable<Command> Commands => _modules.SelectMany(x => x.Value.Commands);

        /// <summary>
        /// Creates a service with no modules loaded and the built-in type readers registered
        /// (string, the integral and floating-point primitives, decimal, DateTime,
        /// DateTimeOffset, and the message/channel/role/user entity interfaces).
        /// </summary>
        public CommandService()
        {
            _moduleLock = new SemaphoreSlim(1, 1);
            _modules = new ConcurrentDictionary<object, Module>();
            _map = new CommandMap();
            _typeReaders = new ConcurrentDictionary<Type, TypeReader>
            {
                [typeof(string)] = new GenericTypeReader((m, s) => Task.FromResult(TypeReaderResult.FromSuccess(s))),
                // The culture-less TryParse overloads used here parse with the current culture.
                [typeof(byte)] = CreateParseReader<byte>(byte.TryParse),
                [typeof(sbyte)] = CreateParseReader<sbyte>(sbyte.TryParse),
                [typeof(ushort)] = CreateParseReader<ushort>(ushort.TryParse),
                [typeof(short)] = CreateParseReader<short>(short.TryParse),
                [typeof(uint)] = CreateParseReader<uint>(uint.TryParse),
                [typeof(int)] = CreateParseReader<int>(int.TryParse),
                [typeof(ulong)] = CreateParseReader<ulong>(ulong.TryParse),
                [typeof(long)] = CreateParseReader<long>(long.TryParse),
                [typeof(float)] = CreateParseReader<float>(float.TryParse),
                [typeof(double)] = CreateParseReader<double>(double.TryParse),
                [typeof(decimal)] = CreateParseReader<decimal>(decimal.TryParse),
                [typeof(DateTime)] = CreateParseReader<DateTime>(DateTime.TryParse),
                [typeof(DateTimeOffset)] = CreateParseReader<DateTimeOffset>(DateTimeOffset.TryParse),
                [typeof(IMessage)] = new MessageTypeReader(),
                [typeof(IChannel)] = new ChannelTypeReader(),
                [typeof(IGuildChannel)] = new ChannelTypeReader(),
                [typeof(ITextChannel)] = new ChannelTypeReader(),
                [typeof(IVoiceChannel)] = new ChannelTypeReader(),
                [typeof(IRole)] = new RoleTypeReader(),
                [typeof(IUser)] = new UserTypeReader(),
                [typeof(IGuildUser)] = new UserTypeReader()
            };
        }

        // Wraps a TryParse method in a reader that reports CommandError.ParseFailed with
        // "Failed to parse <CLR type name>" (e.g. "Failed to parse Int32") on failure.
        private static TypeReader CreateParseReader<T>(TryParseFunc<T> tryParse)
        {
            string errorMessage = "Failed to parse " + typeof(T).Name;
            return new GenericTypeReader((m, s) =>
            {
                T value;
                if (tryParse(s, out value))
                    return Task.FromResult(TypeReaderResult.FromSuccess(value));
                return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, errorMessage));
            });
        }

        /// <summary>
        /// Registers <paramref name="reader"/> for arguments of type <typeparamref name="T"/>,
        /// replacing any reader already registered for that type.
        /// </summary>
        public void AddTypeReader<T>(TypeReader reader)
        {
            _typeReaders[typeof(T)] = reader;
        }

        /// <summary>
        /// Registers <paramref name="reader"/> for arguments of type <paramref name="type"/>,
        /// replacing any reader already registered for that type.
        /// </summary>
        public void AddTypeReader(Type type, TypeReader reader)
        {
            _typeReaders[type] = reader;
        }

        /// <summary>Returns the reader registered for <paramref name="type"/>, or null if none is.</summary>
        internal TypeReader GetTypeReader(Type type)
        {
            TypeReader reader;
            return _typeReaders.TryGetValue(type, out reader) ? reader : null;
        }

        /// <summary>
        /// Loads an already-constructed module instance and registers its commands.
        /// </summary>
        /// <param name="moduleInstance">An instance of a type marked with <see cref="ModuleAttribute"/>.</param>
        /// <returns>The loaded module.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="moduleInstance"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The instance is already loaded, or its type is not marked with <see cref="ModuleAttribute"/>.
        /// </exception>
        public async Task<Module> Load(object moduleInstance)
        {
            if (moduleInstance == null)
                throw new ArgumentNullException(nameof(moduleInstance));

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                if (_modules.ContainsKey(moduleInstance))
                    throw new ArgumentException("This module has already been loaded.");

                var typeInfo = moduleInstance.GetType().GetTypeInfo();
                var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
                // Only attributed types are modules; LoadInternal needs the attribute instance.
                if (moduleAttr == null)
                    throw new ArgumentException("Modules must be marked with ModuleAttribute.");

                return LoadInternal(moduleInstance, moduleAttr, typeInfo);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Builds the Module wrapper, records it and adds its commands to the map.
        // Callers must hold _moduleLock.
        private Module LoadInternal(object moduleInstance, ModuleAttribute moduleAttr, TypeInfo typeInfo)
        {
            var loadedModule = new Module(this, moduleInstance, moduleAttr, typeInfo);
            _modules[moduleInstance] = loadedModule;

            foreach (var cmd in loadedModule.Commands)
                _map.AddCommand(cmd);

            return loadedModule;
        }

        /// <summary>
        /// Instantiates (via ReflectionUtils.CreateObject) and loads every exported type in
        /// <paramref name="assembly"/> that is marked with <see cref="ModuleAttribute"/>.
        /// </summary>
        /// <returns>The modules that were loaded, in the order they were found.</returns>
        public async Task<ImmutableArray<Module>> LoadAssembly(Assembly assembly)
        {
            var modules = ImmutableArray.CreateBuilder<Module>();
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                foreach (var type in assembly.ExportedTypes)
                {
                    var typeInfo = type.GetTypeInfo();
                    var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();
                    if (moduleAttr == null)
                        continue;

                    var moduleInstance = ReflectionUtils.CreateObject(typeInfo);
                    modules.Add(LoadInternal(moduleInstance, moduleAttr, typeInfo));
                }
                return modules.ToImmutable();
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Unloads <paramref name="module"/> and removes its commands.</summary>
        /// <returns>True if the module was loaded and has been removed; otherwise false.</returns>
        public async Task<bool> Unload(Module module)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                return UnloadInternal(module.Instance);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Unloads the module backed by <paramref name="moduleInstance"/> and removes its commands.</summary>
        /// <returns>True if the module was loaded and has been removed; otherwise false.</returns>
        public async Task<bool> Unload(object moduleInstance)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                return UnloadInternal(moduleInstance);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Removes the module and its commands from the map. Callers must hold _moduleLock.
        private bool UnloadInternal(object module)
        {
            Module unloadedModule;
            if (!_modules.TryRemove(module, out unloadedModule))
                return false;

            foreach (var cmd in unloadedModule.Commands)
                _map.RemoveCommand(cmd);
            return true;
        }

        /// <summary>Searches for commands matching the message text starting at <paramref name="argPos"/>.</summary>
        public SearchResult Search(IMessage message, int argPos) => Search(message, message.RawText.Substring(argPos));

        /// <summary>Finds every command whose path matches <paramref name="input"/>.</summary>
        /// <returns>
        /// A successful result carrying the matches, or a CommandError.UnknownCommand error
        /// when nothing matches.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        public SearchResult Search(IMessage message, string input)
        {
            if (input == null)
                throw new ArgumentNullException(nameof(input));

            // The input is handed to the command map unmodified; any case-insensitive
            // matching would have to be provided by CommandMap itself.
            var matches = _map.GetCommands(input).ToImmutableArray();
            if (matches.Length > 0)
                return SearchResult.FromSuccess(input, matches);
            return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
        }

        /// <summary>Executes the command contained in the message text starting at <paramref name="argPos"/>.</summary>
        public Task<IResult> Execute(IMessage message, int argPos) => Execute(message, message.RawText.Substring(argPos));

        /// <summary>
        /// Searches for <paramref name="input"/> and executes the first matching overload whose
        /// arguments parse successfully.
        /// </summary>
        /// <returns>
        /// The failed search result, the executed command's result, or a
        /// CommandError.ParseFailed error when no overload's arguments parse.
        /// </returns>
        public async Task<IResult> Execute(IMessage message, string input)
        {
            var searchResult = Search(message, input);
            if (!searchResult.IsSuccess)
                return searchResult;

            var commands = searchResult.Commands;
            // Overloads are tried from the last match back to the first.
            for (int i = commands.Count - 1; i >= 0; i--)
            {
                var parseResult = await commands[i].Parse(message, searchResult).ConfigureAwait(false);
                if (!parseResult.IsSuccess)
                    continue;

                return await commands[i].Execute(message, parseResult).ConfigureAwait(false);
            }

            return ParseResult.FromError(CommandError.ParseFailed, "This input does not match any overload.");
        }
    }
}