diff --git a/src/Discord.Net.Commands/Command.cs b/src/Discord.Net.Commands/Command.cs index f46fafb27..cf35afd32 100644 --- a/src/Discord.Net.Commands/Command.cs +++ b/src/Discord.Net.Commands/Command.cs @@ -1,7 +1,9 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; +using System.Linq; using System.Reflection; using System.Threading.Tasks; @@ -10,6 +12,9 @@ namespace Discord.Commands [DebuggerDisplay(@"{DebuggerDisplay,nq}")] public class Command { + private static readonly MethodInfo _convertParamsMethod = typeof(Command).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); + private static readonly ConcurrentDictionary, object>> _arrayConverters = new ConcurrentDictionary, object>>(); + private readonly object _instance; private readonly Func, Task> _action; @@ -19,6 +24,7 @@ namespace Discord.Commands public string Description { get; } public string Summary { get; } public string Text { get; } + public bool HasVarArgs { get; } public IReadOnlyList Parameters { get; } public IReadOnlyList Preconditions { get; } @@ -42,8 +48,9 @@ namespace Discord.Commands var summary = source.GetCustomAttribute(); if (summary != null) Summary = summary.Text; - + Parameters = BuildParameters(source); + HasVarArgs = Parameters.Count > 0 ? Parameters[Parameters.Count - 1].IsMultiple : false; Preconditions = BuildPreconditions(source); _action = BuildAction(source); } @@ -76,14 +83,38 @@ namespace Discord.Commands return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false); } - public async Task Execute(IMessage msg, ParseResult parseResult) + public Task Execute(IMessage msg, ParseResult parseResult) { if (!parseResult.IsSuccess) - return ExecuteResult.FromError(parseResult); + return Task.FromResult(ExecuteResult.FromError(parseResult)); + var argList = new object[parseResult.ArgValues.Count]; + for (int i = 0; i < parseResult.ArgValues.Count; i++) + { + if (!parseResult.ArgValues[i].IsSuccess) + return Task.FromResult(ExecuteResult.FromError(parseResult.ArgValues[i])); + argList[i] = parseResult.ArgValues[i].Values.First().Value; + } + + object[] paramList = null; + if (parseResult.ParamValues != null) + { + paramList = new object[parseResult.ParamValues.Count]; + for (int i = 0; i < parseResult.ParamValues.Count; i++) + { + if (!parseResult.ParamValues[i].IsSuccess) + return Task.FromResult(ExecuteResult.FromError(parseResult.ParamValues[i])); + paramList[i] = parseResult.ParamValues[i].Values.First().Value; + } + } + + return Execute(msg, argList, paramList); + } + public async Task Execute(IMessage msg, IEnumerable argList, IEnumerable paramList) + { try { - await _action.Invoke(msg, parseResult.Values);//Note: This code may need context + await _action.Invoke(msg, GenerateArgs(argList, paramList)).ConfigureAwait(false);//Note: This code may need context return ExecuteResult.FromSuccess(); } catch (Exception ex) @@ -108,7 +139,7 @@ namespace Discord.Commands { var parameter = parameters[i]; var type = parameter.ParameterType; - + //Detect 'params' bool isMultiple = parameter.GetCustomAttribute() != null; if (isMultiple) @@ -156,6 +187,39 @@ namespace Discord.Commands }; } + private object[] GenerateArgs(IEnumerable argList, IEnumerable paramsList) + { + int argCount = Parameters.Count; + var array = new object[Parameters.Count]; + if (HasVarArgs) + argCount--; + + int i = 0; + foreach (var arg in argList) + { + if (i == argCount) + throw new InvalidOperationException("Command was invoked with too many parameters"); + array[i++] = arg; + } + if (i < argCount) + throw new InvalidOperationException("Command was invoked with too few parameters"); + + if (HasVarArgs) + { + var func = _arrayConverters.GetOrAdd(Parameters[Parameters.Count - 1].ElementType, t => + { + var method = _convertParamsMethod.MakeGenericMethod(t); + return (Func, object>)method.CreateDelegate(typeof(Func, object>)); + }); + array[i] = func(paramsList); + } + + return array; + } + + private static T[] ConvertParamsList(IEnumerable paramsList) + => paramsList.Cast().ToArray(); + public override string ToString() => Name; private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})"; } diff --git a/src/Discord.Net.Commands/CommandError.cs b/src/Discord.Net.Commands/CommandError.cs index 1046e7f4f..41b4822ad 100644 --- a/src/Discord.Net.Commands/CommandError.cs +++ b/src/Discord.Net.Commands/CommandError.cs @@ -3,14 +3,14 @@ public enum CommandError { //Search - UnknownCommand, + UnknownCommand = 1, //Parse ParseFailed, BadArgCount, //Parse (Type Reader) - CastFailed, + //CastFailed, ObjectNotFound, MultipleMatches, diff --git a/src/Discord.Net.Commands/CommandParameter.cs b/src/Discord.Net.Commands/CommandParameter.cs index 860ab8190..1e358e8bf 100644 --- a/src/Discord.Net.Commands/CommandParameter.cs +++ b/src/Discord.Net.Commands/CommandParameter.cs @@ -17,7 +17,7 @@ namespace Discord.Commands public bool IsRemainder { get; } public bool IsMultiple { get; } public Type ElementType { get; } - internal object DefaultValue { get; } + public object DefaultValue { get; } public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue) { diff --git a/src/Discord.Net.Commands/CommandParser.cs b/src/Discord.Net.Commands/CommandParser.cs index 1a4368a31..30b72647e 100644 --- a/src/Discord.Net.Commands/CommandParser.cs +++ b/src/Discord.Net.Commands/CommandParser.cs @@ -1,8 +1,5 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; + using System.Collections.Immutable; -using System.Reflection; using System.Text; using System.Threading.Tasks; @@ -16,9 +13,6 @@ namespace Discord.Commands Parameter, QuotedParameter } - - private static readonly MethodInfo _convertArrayMethod = typeof(CommandParser).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); - private static readonly ConcurrentDictionary, object>> _arrayConverters = new ConcurrentDictionary, object>>(); public static async Task ParseArgs(Command command, IMessage context, string input, int startPos) { @@ -27,9 +21,10 @@ namespace Discord.Commands int endPos = input.Length; var curPart = ParserPart.None; int lastArgEndPos = int.MinValue; - var argList = ImmutableArray.CreateBuilder(); - List paramsList = null; // TODO: could we use a better type? + var argList = ImmutableArray.CreateBuilder(); + ImmutableArray.Builder paramList = null; bool isEscaping = false; + bool hasMultipleMatches = false; char c; for (int curPos = startPos; curPos <= endPos; curPos++) @@ -117,30 +112,28 @@ namespace Discord.Commands var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false); if (!typeReaderResult.IsSuccess) - return ParseResult.FromError(typeReaderResult); + { + if (typeReaderResult.Error == CommandError.MultipleMatches) + hasMultipleMatches = true; + else + return ParseResult.FromError(typeReaderResult); + } if (curParam.IsMultiple) { - if (paramsList == null) - paramsList = new List(); - paramsList.Add(typeReaderResult.Value); + if (paramList == null) + paramList = ImmutableArray.CreateBuilder(); + paramList.Add(typeReaderResult); if (curPos == endPos) { - var func = _arrayConverters.GetOrAdd(curParam.ElementType, t => - { - var method = _convertArrayMethod.MakeGenericMethod(t); - return (Func, object>)method.CreateDelegate(typeof(Func, object>)); - }); - argList.Add(func.Invoke(paramsList)); - curParam = null; curPart = ParserPart.None; } } else { - argList.Add(typeReaderResult.Value); + argList.Add(typeReaderResult); curParam = null; curPart = ParserPart.None; @@ -154,34 +147,24 @@ namespace Discord.Commands var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false); if (!typeReaderResult.IsSuccess) return ParseResult.FromError(typeReaderResult); - argList.Add(typeReaderResult.Value); + argList.Add(typeReaderResult); } if (isEscaping) return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape."); if (curPart == ParserPart.QuotedParameter) return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete"); - - if (argList.Count < command.Parameters.Count) + + //Add missing optionals + for (int i = paramList.Count; i < command.Parameters.Count; i++) { - for (int i = argList.Count; i < command.Parameters.Count; i++) - { - var param = command.Parameters[i]; - if (!param.IsOptional) - return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters."); - argList.Add(param.DefaultValue); - } + var param = command.Parameters[i]; + if (!param.IsOptional) + return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters."); + argList.Add(TypeReaderResult.FromSuccess(param.DefaultValue)); } - - return ParseResult.FromSuccess(argList.ToImmutable()); - } - - private static T[] ConvertParamsList(List paramsList) - { - var array = new T[paramsList.Count]; - for (int i = 0; i < array.Length; i++) - array[i] = (T)paramsList[i]; - return array; + + return ParseResult.FromSuccess(argList.ToImmutable(), paramList?.ToImmutable()); } } } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index c2d99f21a..56b7386d8 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -40,16 +40,8 @@ namespace Discord.Commands [typeof(decimal)] = new SimpleTypeReader(), [typeof(DateTime)] = new SimpleTypeReader(), [typeof(DateTimeOffset)] = new SimpleTypeReader(), - - //TODO: Do we want to support any other interfaces? - - //[typeof(IMentionable)] = new GeneralTypeReader(), - //[typeof(ISnowflakeEntity)] = new GeneralTypeReader(), - //[typeof(IEntity)] = new GeneralTypeReader(), - + [typeof(IMessage)] = new MessageTypeReader(), - //[typeof(IAttachment)] = new xxx(), - //[typeof(IEmbed)] = new xxx(), [typeof(IChannel)] = new ChannelTypeReader(), [typeof(IDMChannel)] = new ChannelTypeReader(), @@ -61,10 +53,8 @@ namespace Discord.Commands [typeof(IVoiceChannel)] = new ChannelTypeReader(), //[typeof(IGuild)] = new GuildTypeReader(), - //[typeof(IUserGuild)] = new GuildTypeReader(), - //[typeof(IGuildIntegration)] = new xxx(), - [typeof(IRole)] = new RoleTypeReader(), + [typeof(IRole)] = new RoleTypeReader(), //[typeof(IInvite)] = new InviteTypeReader(), //[typeof(IInviteMetadata)] = new InviteTypeReader(), @@ -72,10 +62,6 @@ namespace Discord.Commands [typeof(IUser)] = new UserTypeReader(), [typeof(IGroupUser)] = new UserTypeReader(), [typeof(IGuildUser)] = new UserTypeReader(), - //[typeof(ISelfUser)] = new UserTypeReader(), - //[typeof(IPresence)] = new UserTypeReader(), - //[typeof(IVoiceState)] = new UserTypeReader(), - //[typeof(IConnection)] = new xxx(), }; } @@ -201,8 +187,9 @@ namespace Discord.Commands return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); } - public Task Execute(IMessage message, int argPos) => Execute(message, message.Content.Substring(argPos)); - public async Task Execute(IMessage message, string input) + public Task Execute(IMessage message, int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) + => Execute(message, message.Content.Substring(argPos), multiMatchHandling); + public async Task Execute(IMessage message, string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) { var searchResult = Search(message, input); if (!searchResult.IsSuccess) @@ -223,14 +210,29 @@ namespace Discord.Commands var parseResult = await commands[i].Parse(message, searchResult, preconditionResult); if (!parseResult.IsSuccess) { - if (commands.Count == 1) - return parseResult; - else - continue; + if (parseResult.Error == CommandError.MultipleMatches) + { + TypeReaderValue[] argList, paramList; + switch (multiMatchHandling) + { + case MultiMatchHandling.Best: + argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray(); + paramList = parseResult.ParamValues?.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray(); + parseResult = ParseResult.FromSuccess(argList, paramList); + break; + } + } + + if (!parseResult.IsSuccess) + { + if (commands.Count == 1) + return parseResult; + else + continue; + } } - var executeResult = await commands[i].Execute(message, parseResult); - return executeResult; + return await commands[i].Execute(message, parseResult); } return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); diff --git a/src/Discord.Net.Commands/MultiMatchHandling.cs b/src/Discord.Net.Commands/MultiMatchHandling.cs new file mode 100644 index 000000000..89dcf1c06 --- /dev/null +++ b/src/Discord.Net.Commands/MultiMatchHandling.cs @@ -0,0 +1,8 @@ +namespace Discord.Commands +{ + public enum MultiMatchHandling + { + Exception, + Best + } +} diff --git a/src/Discord.Net.Commands/Readers/ChannelTypeReader.cs b/src/Discord.Net.Commands/Readers/ChannelTypeReader.cs index 4a1350fee..4b9c32a61 100644 --- a/src/Discord.Net.Commands/Readers/ChannelTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/ChannelTypeReader.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Threading.Tasks; @@ -9,40 +11,37 @@ namespace Discord.Commands { public override async Task Read(IMessage context, string input) { - IGuildChannel guildChannel = context.Channel as IGuildChannel; - IChannel result = null; + var guild = (context.Channel as IGuildChannel)?.Guild; - if (guildChannel != null) + if (guild != null) { - //By Id + var results = new Dictionary(); + var channels = await guild.GetChannelsAsync().ConfigureAwait(false); ulong id; - if (MentionUtils.TryParseChannel(input, out id) || ulong.TryParse(input, out id)) - { - var channel = await guildChannel.Guild.GetChannelAsync(id).ConfigureAwait(false); - if (channel != null) - result = channel; - } - - //By Name - if (result == null) - { - var channels = await guildChannel.Guild.GetChannelsAsync().ConfigureAwait(false); - var filteredChannels = channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); - if (filteredChannels.Length > 1) - return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple channels found."); - else if (filteredChannels.Length == 1) - result = filteredChannels[0]; - } + + //By Mention (1.0) + if (MentionUtils.TryParseChannel(input, out id)) + AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); + + //By Id (0.9) + if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) + AddResult(results, await guild.GetChannelAsync(id).ConfigureAwait(false) as T, 0.90f); + + //By Name (0.7-0.8) + foreach (var channel in channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase))) + AddResult(results, channel as T, channel.Name == input ? 0.80f : 0.70f); + + if (results.Count > 0) + return TypeReaderResult.FromSuccess(results.Values); } - if (result == null) - return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); + return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); + } - T castResult = result as T; - if (castResult == null) - return TypeReaderResult.FromError(CommandError.CastFailed, $"Channel is not a {typeof(T).Name}."); - else - return TypeReaderResult.FromSuccess(castResult); + private void AddResult(Dictionary results, T channel, float score) + { + if (channel != null && !results.ContainsKey(channel.Id)) + results.Add(channel.Id, new TypeReaderValue(channel, score)); } } } diff --git a/src/Discord.Net.Commands/Readers/EnumTypeReader.cs b/src/Discord.Net.Commands/Readers/EnumTypeReader.cs index 8e5313118..031b007e0 100644 --- a/src/Discord.Net.Commands/Readers/EnumTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/EnumTypeReader.cs @@ -52,14 +52,14 @@ namespace Discord.Commands if (_enumsByValue.TryGetValue(baseValue, out enumValue)) return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); else - return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); + return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}")); } else { if (_enumsByName.TryGetValue(input.ToLower(), out enumValue)) return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); else - return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); + return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}")); } } } diff --git a/src/Discord.Net.Commands/Readers/MessageTypeReader.cs b/src/Discord.Net.Commands/Readers/MessageTypeReader.cs index 50ec7000a..4ec25fe56 100644 --- a/src/Discord.Net.Commands/Readers/MessageTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/MessageTypeReader.cs @@ -7,18 +7,17 @@ namespace Discord.Commands { public override Task Read(IMessage context, string input) { - //By Id ulong id; + + //By Id (1.0) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) { var msg = context.Channel.GetCachedMessage(id); - if (msg == null) - return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found.")); - else + if (msg != null) return Task.FromResult(TypeReaderResult.FromSuccess(msg)); } - return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Message Id.")); + return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found.")); } } } diff --git a/src/Discord.Net.Commands/Readers/RoleTypeReader.cs b/src/Discord.Net.Commands/Readers/RoleTypeReader.cs index 10aee6b1c..5be7ddee0 100644 --- a/src/Discord.Net.Commands/Readers/RoleTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/RoleTypeReader.cs @@ -1,36 +1,46 @@ using System; +using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Threading.Tasks; namespace Discord.Commands { - internal class RoleTypeReader : TypeReader + internal class RoleTypeReader : TypeReader + where T : class, IRole { public override Task Read(IMessage context, string input) { - IGuildChannel guildChannel = context.Channel as IGuildChannel; + var guild = (context.Channel as IGuildChannel)?.Guild; + ulong id; - if (guildChannel != null) + if (guild != null) { - //By Id - ulong id; - if (MentionUtils.TryParseRole(input, out id) || ulong.TryParse(input, out id)) - { - var channel = guildChannel.Guild.GetRole(id); - if (channel != null) - return Task.FromResult(TypeReaderResult.FromSuccess(channel)); - } + var results = new Dictionary(); + var roles = guild.Roles; - //By Name - var roles = guildChannel.Guild.Roles; - var filteredRoles = roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); - if (filteredRoles.Length > 1) - return Task.FromResult(TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple roles found.")); - else if (filteredRoles.Length == 1) - return Task.FromResult(TypeReaderResult.FromSuccess(filteredRoles[0])); + //By Mention (1.0) + if (MentionUtils.TryParseRole(input, out id)) + AddResult(results, guild.GetRole(id) as T, 1.00f); + + //By Id (0.9) + if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) + AddResult(results, guild.GetRole(id) as T, 0.90f); + + //By Name (0.7-0.8) + foreach (var role in roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase))) + AddResult(results, role as T, role.Name == input ? 0.80f : 0.70f); + + if (results.Count > 0) + return Task.FromResult(TypeReaderResult.FromSuccess(results)); } - return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found.")); } + + private void AddResult(Dictionary results, T role, float score) + { + if (role != null && !results.ContainsKey(role.Id)) + results.Add(role.Id, new TypeReaderValue(role, score)); + } } } diff --git a/src/Discord.Net.Commands/Readers/SimpleTypeReader.cs b/src/Discord.Net.Commands/Readers/SimpleTypeReader.cs index a3822084b..615ec5014 100644 --- a/src/Discord.Net.Commands/Readers/SimpleTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/SimpleTypeReader.cs @@ -16,8 +16,7 @@ namespace Discord.Commands T value; if (_tryParse(input, out value)) return Task.FromResult(TypeReaderResult.FromSuccess(value)); - else - return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}")); + return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}")); } } } diff --git a/src/Discord.Net.Commands/Readers/UserTypeReader.cs b/src/Discord.Net.Commands/Readers/UserTypeReader.cs index e4bd9ffd1..0e881667d 100644 --- a/src/Discord.Net.Commands/Readers/UserTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/UserTypeReader.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Threading.Tasks; @@ -9,54 +11,78 @@ namespace Discord.Commands { public override async Task Read(IMessage context, string input) { - IUser result = null; - - //By Id + var results = new Dictionary(); + var guild = (context.Channel as IGuildChannel)?.Guild; + IReadOnlyCollection channelUsers = await context.Channel.GetUsersAsync().ConfigureAwait(false); + IReadOnlyCollection guildUsers = null; ulong id; - if (MentionUtils.TryParseUser(input, out id) || ulong.TryParse(input, out id)) + + if (guild != null) + guildUsers = await guild.GetUsersAsync().ConfigureAwait(false); + + //By Mention (1.0) + if (MentionUtils.TryParseUser(input, out id)) { - var user = await context.Channel.GetUserAsync(id).ConfigureAwait(false); - if (user != null) - result = user; + if (guild != null) + AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); + else + AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); } - //By Username + Discriminator - if (result == null) + //By Id (0.9) + if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) { - int index = input.LastIndexOf('#'); - if (index >= 0) + if (guild != null) + AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f); + else + AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f); + } + + //By Username + Discriminator (0.7-0.85) + int index = input.LastIndexOf('#'); + if (index >= 0) + { + string username = input.Substring(0, index); + ushort discriminator; + if (ushort.TryParse(input.Substring(index + 1), out discriminator)) { - string username = input.Substring(0, index); - ushort discriminator; - if (ushort.TryParse(input.Substring(index + 1), out discriminator)) - { - var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); - result = users.Where(x => - x.DiscriminatorValue == discriminator && - string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); - } + var channelUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator && + string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); + AddResult(results, channelUser as T, channelUser.Username == username ? 0.85f : 0.75f); + + var guildUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator && + string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); + AddResult(results, guildUser as T, guildUser.Username == username ? 0.80f : 0.70f); } } - //By Username - if (result == null) + //By Username (0.5-0.6) { - var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); - var filteredUsers = users.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)).ToArray(); - if (filteredUsers.Length > 1) - return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple users found."); - else if (filteredUsers.Length == 1) - result = filteredUsers[0]; + foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) + AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f); + + foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) + AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); } - if (result == null) - return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); + //By Nickname (0.5-0.6) + { + foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) + AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f); + + foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) + AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); + } - T castResult = result as T; - if (castResult == null) - return TypeReaderResult.FromError(CommandError.CastFailed, $"User is not a {typeof(T).Name}."); - else - return TypeReaderResult.FromSuccess(castResult); + if (results.Count > 0) + return TypeReaderResult.FromSuccess(results.Values.ToArray()); + return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); + } + + private void AddResult(Dictionary results, T user, float score) + { + if (user != null && !results.ContainsKey(user.Id)) + results.Add(user.Id, new TypeReaderValue(user, score)); } } } diff --git a/src/Discord.Net.Commands/Results/ParseResult.cs b/src/Discord.Net.Commands/Results/ParseResult.cs index 5c19083be..22fdc4259 100644 --- a/src/Discord.Net.Commands/Results/ParseResult.cs +++ b/src/Discord.Net.Commands/Results/ParseResult.cs @@ -6,28 +6,53 @@ namespace Discord.Commands [DebuggerDisplay(@"{DebuggerDisplay,nq}")] public struct ParseResult : IResult { - public IReadOnlyList Values { get; } + public IReadOnlyList ArgValues { get; } + public IReadOnlyList ParamValues { get; } public CommandError? Error { get; } public string ErrorReason { get; } public bool IsSuccess => !Error.HasValue; - private ParseResult(IReadOnlyList values, CommandError? error, string errorReason) + private ParseResult(IReadOnlyList argValues, IReadOnlyList paramValue, CommandError? error, string errorReason) { - Values = values; + ArgValues = argValues; + ParamValues = paramValue; Error = error; ErrorReason = errorReason; } - public static ParseResult FromSuccess(IReadOnlyList values) - => new ParseResult(values, null, null); + public static ParseResult FromSuccess(IReadOnlyList argValues, IReadOnlyList paramValues) + { + for (int i = 0; i < argValues.Count; i++) + { + if (argValues[i].Values.Count > 1) + return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found."); + } + for (int i = 0; i < paramValues.Count; i++) + { + if (paramValues[i].Values.Count > 1) + return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found."); + } + return new ParseResult(argValues, paramValues, null, null); + } + public static ParseResult FromSuccess(IReadOnlyList argValues, IReadOnlyList paramValues) + { + var argList = new TypeReaderResult[argValues.Count]; + for (int i = 0; i < argValues.Count; i++) + argList[i] = TypeReaderResult.FromSuccess(argValues[i]); + var paramList = new TypeReaderResult[paramValues.Count]; + for (int i = 0; i < paramValues.Count; i++) + paramList[i] = TypeReaderResult.FromSuccess(paramValues[i]); + return new ParseResult(argList, paramList, null, null); + } + public static ParseResult FromError(CommandError error, string reason) - => new ParseResult(null, error, reason); + => new ParseResult(null, null, error, reason); public static ParseResult FromError(IResult result) - => new ParseResult(null, result.Error, result.ErrorReason); + => new ParseResult(null, null, result.Error, result.ErrorReason); public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; - private string DebuggerDisplay => IsSuccess ? $"Success ({Values.Count} Values)" : $"{Error}: {ErrorReason}"; + private string DebuggerDisplay => IsSuccess ? $"Success ({ArgValues.Count}{(ParamValues != null ? $" +{ParamValues.Count} Values" : "")})" : $"{Error}: {ErrorReason}"; } } diff --git a/src/Discord.Net.Commands/Results/TypeReaderResult.cs b/src/Discord.Net.Commands/Results/TypeReaderResult.cs index 13ed3cb08..20a9c4a22 100644 --- a/src/Discord.Net.Commands/Results/TypeReaderResult.cs +++ b/src/Discord.Net.Commands/Results/TypeReaderResult.cs @@ -1,32 +1,56 @@ -using System.Diagnostics; +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; namespace Discord.Commands { [DebuggerDisplay(@"{DebuggerDisplay,nq}")] - public struct TypeReaderResult : IResult + public struct TypeReaderValue { public object Value { get; } + public float Score { get; } + + public TypeReaderValue(object value, float score) + { + Value = value; + Score = score; + } + + public override string ToString() => Value?.ToString(); + private string DebuggerDisplay => $"[{Value}, {Math.Round(Score, 2)}]"; + } + + [DebuggerDisplay(@"{DebuggerDisplay,nq}")] + public struct TypeReaderResult : IResult + { + public IReadOnlyCollection Values { get; } public CommandError? Error { get; } public string ErrorReason { get; } public bool IsSuccess => !Error.HasValue; - private TypeReaderResult(object value, CommandError? error, string errorReason) + private TypeReaderResult(IReadOnlyCollection values, CommandError? error, string errorReason) { - Value = value; + Values = values; Error = error; ErrorReason = errorReason; } public static TypeReaderResult FromSuccess(object value) - => new TypeReaderResult(value, null, null); + => new TypeReaderResult(ImmutableArray.Create(new TypeReaderValue(value, 1.0f)), null, null); + public static TypeReaderResult FromSuccess(TypeReaderValue value) + => new TypeReaderResult(ImmutableArray.Create(value), null, null); + public static TypeReaderResult FromSuccess(IReadOnlyCollection values) + => new TypeReaderResult(values, null, null); public static TypeReaderResult FromError(CommandError error, string reason) => new TypeReaderResult(null, error, reason); public static TypeReaderResult FromError(IResult result) => new TypeReaderResult(null, result.Error, result.ErrorReason); public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; - private string DebuggerDisplay => IsSuccess ? $"Success ({Value})" : $"{Error}: {ErrorReason}"; + private string DebuggerDisplay => IsSuccess ? $"Success ({string.Join(", ", Values)})" : $"{Error}: {ErrorReason}"; } }