diff --git a/src/Discord.Net.Commands/CommandMatch.cs b/src/Discord.Net.Commands/CommandMatch.cs new file mode 100644 index 000000000..2e13e0f46 --- /dev/null +++ b/src/Discord.Net.Commands/CommandMatch.cs @@ -0,0 +1,26 @@ +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + public struct CommandMatch + { + public CommandInfo Command { get; } + public string Alias { get; } + + public CommandMatch(CommandInfo command, string alias) + { + Command = command; + Alias = alias; + } + + public Task CheckPreconditionsAsync(CommandContext context, IDependencyMap map = null) + => Command.CheckPreconditionsAsync(context, map); + public Task ParseAsync(CommandContext context, SearchResult searchResult, PreconditionResult? preconditionResult = null) + => Command.ParseAsync(context, Alias.Length, searchResult, preconditionResult); + public Task ExecuteAsync(CommandContext context, IEnumerable argList, IEnumerable paramList, IDependencyMap map) + => Command.ExecuteAsync(context, argList, paramList, map); + public Task ExecuteAsync(CommandContext context, ParseResult parseResult, IDependencyMap map) + => Command.ExecuteAsync(context, parseResult, map); + } +} diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 010a0ee8a..7152b87da 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -21,6 +21,7 @@ namespace Discord.Commands private readonly CommandMap _map; internal readonly bool _caseSensitive; + internal readonly char _separatorChar; internal readonly RunMode _defaultRunMode; public IEnumerable Modules => _moduleDefs.Select(x => x); @@ -30,10 +31,14 @@ namespace Discord.Commands public CommandService() : this(new CommandServiceConfig()) { } public CommandService(CommandServiceConfig config) { + _caseSensitive = config.CaseSensitiveCommands; + _separatorChar = config.SeparatorChar; + _defaultRunMode = config.DefaultRunMode; + _moduleLock = new SemaphoreSlim(1, 1); _typedModuleDefs = new ConcurrentDictionary(); _moduleDefs = new ConcurrentBag(); - _map = new CommandMap(); + _map = new CommandMap(this); _typeReaders = new ConcurrentDictionary>(); _defaultTypeReaders = new ConcurrentDictionary @@ -57,9 +62,6 @@ namespace Discord.Commands }; foreach (var type in PrimitiveParsers.SupportedTypes) _defaultTypeReaders[type] = PrimitiveTypeReader.Create(type); - - _caseSensitive = config.CaseSensitiveCommands; - _defaultRunMode = config.DefaultRunMode; } //Modules @@ -214,7 +216,7 @@ namespace Discord.Commands public SearchResult Search(CommandContext context, string input) { string searchInput = _caseSensitive ? input : input.ToLowerInvariant(); - var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Priority).ToImmutableArray(); + var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Command.Priority).ToImmutableArray(); if (matches.Length > 0) return SearchResult.FromSuccess(input, matches); @@ -269,7 +271,7 @@ namespace Discord.Commands } } - return await commands[i].Execute(context, parseResult, dependencyMap).ConfigureAwait(false); + return await commands[i].ExecuteAsync(context, parseResult, dependencyMap).ConfigureAwait(false); } return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); diff --git a/src/Discord.Net.Commands/CommandServiceConfig.cs b/src/Discord.Net.Commands/CommandServiceConfig.cs index 8377d4e60..037e315c7 100644 --- a/src/Discord.Net.Commands/CommandServiceConfig.cs +++ b/src/Discord.Net.Commands/CommandServiceConfig.cs @@ -4,6 +4,8 @@ { /// The default RunMode commands should have, if one is not specified on the Command attribute or builder. public RunMode DefaultRunMode { get; set; } = RunMode.Sync; + + public char SeparatorChar { get; set; } = ' '; /// Should commands be case-sensitive? public bool CaseSensitiveCommands { get; set; } = false; } diff --git a/src/Discord.Net.Commands/Info/CommandInfo.cs b/src/Discord.Net.Commands/Info/CommandInfo.cs index c24d35f6a..8920b25da 100644 --- a/src/Discord.Net.Commands/Info/CommandInfo.cs +++ b/src/Discord.Net.Commands/Info/CommandInfo.cs @@ -44,7 +44,12 @@ namespace Discord.Commands // both command and module provide aliases if (module.Aliases.Count > 0 && builder.Aliases.Count > 0) - Aliases = module.Aliases.Permutate(builder.Aliases, (first, second) => second != null ? first + " " + second : first).Select(x => service._caseSensitive ? x : x.ToLowerInvariant()).ToImmutableArray(); + { + Aliases = module.Aliases + .Permutate(builder.Aliases, (first, second) => second != null ? first + service._separatorChar + second : first) + .Select(x => service._caseSensitive ? x : x.ToLowerInvariant()) + .ToImmutableArray(); + } // only module provides aliases else if (module.Aliases.Count > 0) Aliases = module.Aliases.Select(x => service._caseSensitive ? x : x.ToLowerInvariant()).ToImmutableArray(); @@ -84,33 +89,19 @@ namespace Discord.Commands return PreconditionResult.FromSuccess(); } - - public async Task ParseAsync(CommandContext context, SearchResult searchResult, PreconditionResult? preconditionResult = null) + + public async Task ParseAsync(CommandContext context, int startIndex, SearchResult searchResult, PreconditionResult? preconditionResult = null) { if (!searchResult.IsSuccess) return ParseResult.FromError(searchResult); if (preconditionResult != null && !preconditionResult.Value.IsSuccess) return ParseResult.FromError(preconditionResult.Value); - string input = searchResult.Text; - var matchingAliases = Aliases.Where(alias => input.StartsWith(alias)).ToArray(); - - string matchingAlias = null; - foreach (string alias in matchingAliases) - { - if (alias.Length > matchingAlias.Length) - matchingAlias = alias; - } - - if (matchingAlias == null) - return ParseResult.FromError(CommandError.ParseFailed, "Unable to find matching alias"); - - input = input.Substring(matchingAlias.Length); - + string input = searchResult.Text.Substring(startIndex); return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); } - public Task Execute(CommandContext context, ParseResult parseResult, IDependencyMap map) + public Task ExecuteAsync(CommandContext context, ParseResult parseResult, IDependencyMap map) { if (!parseResult.IsSuccess) return Task.FromResult(ExecuteResult.FromError(parseResult)); diff --git a/src/Discord.Net.Commands/Map/CommandMap.cs b/src/Discord.Net.Commands/Map/CommandMap.cs index 3a5239878..bcff800d3 100644 --- a/src/Discord.Net.Commands/Map/CommandMap.cs +++ b/src/Discord.Net.Commands/Map/CommandMap.cs @@ -4,36 +4,30 @@ namespace Discord.Commands { internal class CommandMap { + private readonly CommandService _service; private readonly CommandMapNode _root; private static readonly string[] _blankAliases = new[] { "" }; - public CommandMap() + public CommandMap(CommandService service) { + _service = service; _root = new CommandMapNode(""); } public void AddCommand(CommandInfo command) { - foreach (string text in GetAliases(command)) - _root.AddCommand(text, 0, command); + foreach (string text in command.Aliases) + _root.AddCommand(_service, text, 0, command); } public void RemoveCommand(CommandInfo command) { - foreach (string text in GetAliases(command)) - _root.RemoveCommand(text, 0, command); + foreach (string text in command.Aliases) + _root.RemoveCommand(_service, text, 0, command); } - public IEnumerable GetCommands(string text) + public IEnumerable GetCommands(string text) { - return _root.GetCommands(text, 0); - } - - private IReadOnlyList GetAliases(CommandInfo command) - { - var aliases = command.Aliases; - if (aliases.Count == 0) - return _blankAliases; - return aliases; + return _root.GetCommands(_service, text, 0, text != ""); } } } diff --git a/src/Discord.Net.Commands/Map/CommandMapNode.cs b/src/Discord.Net.Commands/Map/CommandMapNode.cs index a86c0643d..863409207 100644 --- a/src/Discord.Net.Commands/Map/CommandMapNode.cs +++ b/src/Discord.Net.Commands/Map/CommandMapNode.cs @@ -7,7 +7,7 @@ namespace Discord.Commands { internal class CommandMapNode { - private static readonly char[] _whitespaceChars = new char[] { ' ', '\r', '\n' }; + private static readonly char[] _whitespaceChars = new[] { ' ', '\r', '\n' }; private readonly ConcurrentDictionary _nodes; private readonly string _name; @@ -23,9 +23,9 @@ namespace Discord.Commands _commands = ImmutableArray.Create(); } - public void AddCommand(string text, int index, CommandInfo command) + public void AddCommand(CommandService service, string text, int index, CommandInfo command) { - int nextSpace = NextWhitespace(text, index); + int nextSegment = NextSegment(text, index, service._separatorChar); string name; lock (_lockObj) @@ -38,19 +38,20 @@ namespace Discord.Commands } else { - if (nextSpace == -1) + if (nextSegment == -1) name = text.Substring(index); else - name = text.Substring(index, nextSpace - index); + name = text.Substring(index, nextSegment - index); - var nextNode = _nodes.GetOrAdd(name, x => new CommandMapNode(x)); - nextNode.AddCommand(nextSpace == -1 ? "" : text, nextSpace + 1, command); + string fullName = _name == "" ? name : _name + service._separatorChar + name; + var nextNode = _nodes.GetOrAdd(name, x => new CommandMapNode(fullName)); + nextNode.AddCommand(service, nextSegment == -1 ? "" : text, nextSegment + 1, command); } } } - public void RemoveCommand(string text, int index, CommandInfo command) + public void RemoveCommand(CommandService service, string text, int index, CommandInfo command) { - int nextSpace = NextWhitespace(text, index); + int nextSegment = NextSegment(text, index, service._separatorChar); string name; lock (_lockObj) @@ -59,15 +60,15 @@ namespace Discord.Commands _commands = _commands.Remove(command); else { - if (nextSpace == -1) + if (nextSegment == -1) name = text.Substring(index); else - name = text.Substring(index, nextSpace - index); + name = text.Substring(index, nextSegment - index); CommandMapNode nextNode; if (_nodes.TryGetValue(name, out nextNode)) { - nextNode.RemoveCommand(nextSpace == -1 ? "" : text, nextSpace + 1, command); + nextNode.RemoveCommand(service, nextSegment == -1 ? "" : text, nextSegment + 1, command); if (nextNode.IsEmpty) _nodes.TryRemove(name, out nextNode); } @@ -75,39 +76,58 @@ namespace Discord.Commands } } - public IEnumerable GetCommands(string text, int index) + public IEnumerable GetCommands(CommandService service, string text, int index, bool visitChildren = true) { - int nextSpace = NextWhitespace(text, index); - string name; - var commands = _commands; for (int i = 0; i < commands.Length; i++) - yield return _commands[i]; + yield return new CommandMatch(_commands[i], _name); - if (text != "") + if (visitChildren) { - if (nextSpace == -1) + string name; + CommandMapNode nextNode; + + //Search for next segment + int nextSegment = NextSegment(text, index, service._separatorChar); + if (nextSegment == -1) name = text.Substring(index); else - name = text.Substring(index, nextSpace - index); - - CommandMapNode nextNode; + name = text.Substring(index, nextSegment - index); if (_nodes.TryGetValue(name, out nextNode)) { - foreach (var cmd in nextNode.GetCommands(nextSpace == -1 ? "" : text, nextSpace + 1)) + foreach (var cmd in nextNode.GetCommands(service, nextSegment == -1 ? "" : text, nextSegment + 1, true)) yield return cmd; } + + //Check if this is the last command segment before args + nextSegment = NextSegment(text, index, _whitespaceChars, service._separatorChar); + if (nextSegment != -1) + { + name = text.Substring(index, nextSegment - index); + if (_nodes.TryGetValue(name, out nextNode)) + { + foreach (var cmd in nextNode.GetCommands(service, nextSegment == -1 ? "" : text, nextSegment + 1, false)) + yield return cmd; + } + } } } - private static int NextWhitespace(string text, int startIndex) + private static int NextSegment(string text, int startIndex, char separator) + { + return text.IndexOf(separator, startIndex); + } + private static int NextSegment(string text, int startIndex, char[] separators, char except) { int lowest = int.MaxValue; - for (int i = 0; i < _whitespaceChars.Length; i++) + for (int i = 0; i < separators.Length; i++) { - int index = text.IndexOf(_whitespaceChars[i], startIndex); - if (index != -1 && index < lowest) - lowest = index; + if (separators[i] != except) + { + int index = text.IndexOf(separators[i], startIndex); + if (index != -1 && index < lowest) + lowest = index; + } } return (lowest != int.MaxValue) ? lowest : -1; } diff --git a/src/Discord.Net.Commands/Results/SearchResult.cs b/src/Discord.Net.Commands/Results/SearchResult.cs index 17942b61a..87d900d4d 100644 --- a/src/Discord.Net.Commands/Results/SearchResult.cs +++ b/src/Discord.Net.Commands/Results/SearchResult.cs @@ -7,14 +7,14 @@ namespace Discord.Commands public struct SearchResult : IResult { public string Text { get; } - public IReadOnlyList Commands { get; } + public IReadOnlyList Commands { get; } public CommandError? Error { get; } public string ErrorReason { get; } public bool IsSuccess => !Error.HasValue; - private SearchResult(string text, IReadOnlyList commands, CommandError? error, string errorReason) + private SearchResult(string text, IReadOnlyList commands, CommandError? error, string errorReason) { Text = text; Commands = commands; @@ -22,7 +22,7 @@ namespace Discord.Commands ErrorReason = errorReason; } - public static SearchResult FromSuccess(string text, IReadOnlyList commands) + public static SearchResult FromSuccess(string text, IReadOnlyList commands) => new SearchResult(text, commands, null, null); public static SearchResult FromError(CommandError error, string reason) => new SearchResult(null, null, error, reason);