From 3343b1b5d07e119840f1649fd9c85f16bb2d5182 Mon Sep 17 00:00:00 2001 From: FiniteReality Date: Sat, 24 Dec 2016 18:59:04 +0000 Subject: [PATCH] Add command suggestions for incorrect commands Example Usage: ```cs var result = await service.ExecuteAsync(context, 0, maxDifferences: 3); if (result is SearchResult) { var sResult = (SearchResult)result; var commands = sResult.Commands.Select(x => x.Alias); await message.Channel.SendMessageAsync( $"Invalid command - Did you mean:\n{string.Join("\n", commands)}"); } ``` --- src/Discord.Net.Commands/CommandService.cs | 18 ++-- src/Discord.Net.Commands/Map/CommandMap.cs | 7 +- .../Map/CommandMapNode.cs | 93 +++++++++++++++++++ .../Results/SearchResult.cs | 4 +- 4 files changed, 111 insertions(+), 11 deletions(-) diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 0d27bd178..3202c6b60 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -222,26 +222,28 @@ namespace Discord.Commands } //Execution - public SearchResult Search(ICommandContext context, int argPos) - => Search(context, context.Message.Content.Substring(argPos)); - public SearchResult Search(ICommandContext context, string input) + public SearchResult Search(ICommandContext context, int argPos, int maxDifferences = 5) + => Search(context, context.Message.Content.Substring(argPos), maxDifferences); + public SearchResult Search(ICommandContext context, string input, int maxDifferences = 5) { string searchInput = _caseSensitive ? input : input.ToLowerInvariant(); var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Command.Priority).ToImmutableArray(); - + if (matches.Length > 0) return SearchResult.FromSuccess(input, matches); + else if (maxDifferences > 0) + return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.", _map.GetPartialMatches(searchInput, maxDifferences).ToImmutableArray()); else return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); } - public Task ExecuteAsync(ICommandContext context, int argPos, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) - => ExecuteAsync(context, context.Message.Content.Substring(argPos), dependencyMap, multiMatchHandling); - public async Task ExecuteAsync(ICommandContext context, string input, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + public Task ExecuteAsync(ICommandContext context, int argPos, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception, int maxDifferences = 5) + => ExecuteAsync(context, context.Message.Content.Substring(argPos), dependencyMap, multiMatchHandling, maxDifferences); + public async Task ExecuteAsync(ICommandContext context, string input, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception, int maxDifferences = 5) { dependencyMap = dependencyMap ?? DependencyMap.Empty; - var searchResult = Search(context, input); + var searchResult = Search(context, input, maxDifferences); if (!searchResult.IsSuccess) return searchResult; diff --git a/src/Discord.Net.Commands/Map/CommandMap.cs b/src/Discord.Net.Commands/Map/CommandMap.cs index bcff800d3..ac8a23b17 100644 --- a/src/Discord.Net.Commands/Map/CommandMap.cs +++ b/src/Discord.Net.Commands/Map/CommandMap.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; namespace Discord.Commands { @@ -29,5 +29,10 @@ namespace Discord.Commands { return _root.GetCommands(_service, text, 0, text != ""); } + + public IEnumerable GetPartialMatches(string text, int maxDifferences) + { + return _root.GetPartialMatches(_service, text, maxDifferences, 0, text != ""); + } } } diff --git a/src/Discord.Net.Commands/Map/CommandMapNode.cs b/src/Discord.Net.Commands/Map/CommandMapNode.cs index 863409207..467fc9b4c 100644 --- a/src/Discord.Net.Commands/Map/CommandMapNode.cs +++ b/src/Discord.Net.Commands/Map/CommandMapNode.cs @@ -49,6 +49,7 @@ namespace Discord.Commands } } } + public void RemoveCommand(CommandService service, string text, int index, CommandInfo command) { int nextSegment = NextSegment(text, index, service._separatorChar); @@ -113,6 +114,52 @@ namespace Discord.Commands } } + internal IEnumerable GetPartialMatches(CommandService service, string text, int maxDifference, int index, bool visitChildren = true) + { + var commands = _commands; + for (int i = 0; i < commands.Length; i++) + yield return new CommandMatch(_commands[i], _name); + + if (visitChildren) + { + string name; + CommandMapNode nextNode; + + //Search for next segment + int nextSegment = NextSegment(text, index, service._separatorChar); + if (nextSegment == -1) + name = text.Substring(index); + else + name = text.Substring(index, nextSegment - index); + + foreach (var key in _nodes.Keys) + { + if (LevenshteinDistance(name, key) < maxDifference) + { + if (_nodes.TryGetValue(key, out nextNode)) + foreach (var cmd in nextNode.GetPartialMatches(service, nextSegment == -1 ? "" : text, maxDifference, nextSegment + 1, true)) + yield return cmd; + } + } + + //Check if this is the last command segment before args + nextSegment = NextSegment(text, index, _whitespaceChars, service._separatorChar); + if (nextSegment != -1) + { + name = text.Substring(index, nextSegment - index); + foreach (var key in _nodes.Keys) + { + if (LevenshteinDistance(name, key) < maxDifference) + { + if (_nodes.TryGetValue(key, out nextNode)) + foreach (var cmd in nextNode.GetPartialMatches(service, nextSegment == -1 ? "" : text, maxDifference, nextSegment + 1, false)) + yield return cmd; + } + } + } + } + } + private static int NextSegment(string text, int startIndex, char separator) { return text.IndexOf(separator, startIndex); @@ -131,5 +178,51 @@ namespace Discord.Commands } return (lowest != int.MaxValue) ? lowest : -1; } + + private static int LevenshteinDistance(string source, string target) + { + var sourceLength = source.Length; + var targetLength = target.Length; + + if (sourceLength == 0) + return targetLength; + if (targetLength == 0) + return sourceLength; + + var matrix = new int[sourceLength + 1, targetLength + 1]; + for (int row = 0; row <= sourceLength; matrix[row, 0] = row++) + { } + for (int col = 0; col <= targetLength; matrix[0, col] = col++) + { } + + for (int i = 1; i <= sourceLength; i++) + { + char sourceChr = source[i - 1]; + for (int j = 1; j <= targetLength; j++) + { + char targetChr = target[j - 1]; + + int cost = sourceChr == targetChr ? 0 : 1; + + int above = matrix[i - 1, j] + 1; + int left = matrix[i, j - 1] + 1; + int diagonal = matrix[i - 1, j - 1] + cost; + + int minimum = int.MaxValue; + + if (above < left) + minimum = above; + else + minimum = left; + + if (diagonal < minimum) + minimum = diagonal; + + matrix[i, j] = minimum; + } + } + + return matrix[sourceLength, targetLength]; + } } } diff --git a/src/Discord.Net.Commands/Results/SearchResult.cs b/src/Discord.Net.Commands/Results/SearchResult.cs index 87d900d4d..aabb52b69 100644 --- a/src/Discord.Net.Commands/Results/SearchResult.cs +++ b/src/Discord.Net.Commands/Results/SearchResult.cs @@ -24,8 +24,8 @@ namespace Discord.Commands public static SearchResult FromSuccess(string text, IReadOnlyList commands) => new SearchResult(text, commands, null, null); - public static SearchResult FromError(CommandError error, string reason) - => new SearchResult(null, null, error, reason); + public static SearchResult FromError(CommandError error, string reason, IReadOnlyList suggestions = null) + => new SearchResult(null, suggestions, error, reason); public static SearchResult FromError(IResult result) => new SearchResult(null, null, result.Error, result.ErrorReason);