Browse Source

Added new parameter scoring, support multiple matches

tags/1.0-rc
RogueException 8 years ago
parent
commit
324664917d
14 changed files with 317 additions and 178 deletions
  1. +69
    -5
      src/Discord.Net.Commands/Command.cs
  2. +2
    -2
      src/Discord.Net.Commands/CommandError.cs
  3. +1
    -1
      src/Discord.Net.Commands/CommandParameter.cs
  4. +24
    -41
      src/Discord.Net.Commands/CommandParser.cs
  5. +26
    -24
      src/Discord.Net.Commands/CommandService.cs
  6. +8
    -0
      src/Discord.Net.Commands/MultiMatchHandling.cs
  7. +27
    -28
      src/Discord.Net.Commands/Readers/ChannelTypeReader.cs
  8. +2
    -2
      src/Discord.Net.Commands/Readers/EnumTypeReader.cs
  9. +4
    -5
      src/Discord.Net.Commands/Readers/MessageTypeReader.cs
  10. +29
    -19
      src/Discord.Net.Commands/Readers/RoleTypeReader.cs
  11. +1
    -2
      src/Discord.Net.Commands/Readers/SimpleTypeReader.cs
  12. +61
    -35
      src/Discord.Net.Commands/Readers/UserTypeReader.cs
  13. +33
    -8
      src/Discord.Net.Commands/Results/ParseResult.cs
  14. +30
    -6
      src/Discord.Net.Commands/Results/TypeReaderResult.cs

+ 69
- 5
src/Discord.Net.Commands/Command.cs View File

@@ -1,7 +1,9 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
using System.Linq;
using System.Reflection; using System.Reflection;
using System.Threading.Tasks; using System.Threading.Tasks;


@@ -10,6 +12,9 @@ namespace Discord.Commands
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public class Command public class Command
{ {
private static readonly MethodInfo _convertParamsMethod = typeof(Command).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>();

private readonly object _instance; private readonly object _instance;
private readonly Func<IMessage, IReadOnlyList<object>, Task> _action; private readonly Func<IMessage, IReadOnlyList<object>, Task> _action;


@@ -19,6 +24,7 @@ namespace Discord.Commands
public string Description { get; } public string Description { get; }
public string Summary { get; } public string Summary { get; }
public string Text { get; } public string Text { get; }
public bool HasVarArgs { get; }
public IReadOnlyList<CommandParameter> Parameters { get; } public IReadOnlyList<CommandParameter> Parameters { get; }
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } public IReadOnlyList<PreconditionAttribute> Preconditions { get; }


@@ -42,8 +48,9 @@ namespace Discord.Commands
var summary = source.GetCustomAttribute<SummaryAttribute>(); var summary = source.GetCustomAttribute<SummaryAttribute>();
if (summary != null) if (summary != null)
Summary = summary.Text; Summary = summary.Text;
Parameters = BuildParameters(source); Parameters = BuildParameters(source);
HasVarArgs = Parameters.Count > 0 ? Parameters[Parameters.Count - 1].IsMultiple : false;
Preconditions = BuildPreconditions(source); Preconditions = BuildPreconditions(source);
_action = BuildAction(source); _action = BuildAction(source);
} }
@@ -76,14 +83,38 @@ namespace Discord.Commands


return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false); return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false);
} }
public async Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult)
public Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult)
{ {
if (!parseResult.IsSuccess) if (!parseResult.IsSuccess)
return ExecuteResult.FromError(parseResult);
return Task.FromResult(ExecuteResult.FromError(parseResult));


var argList = new object[parseResult.ArgValues.Count];
for (int i = 0; i < parseResult.ArgValues.Count; i++)
{
if (!parseResult.ArgValues[i].IsSuccess)
return Task.FromResult(ExecuteResult.FromError(parseResult.ArgValues[i]));
argList[i] = parseResult.ArgValues[i].Values.First().Value;
}

object[] paramList = null;
if (parseResult.ParamValues != null)
{
paramList = new object[parseResult.ParamValues.Count];
for (int i = 0; i < parseResult.ParamValues.Count; i++)
{
if (!parseResult.ParamValues[i].IsSuccess)
return Task.FromResult(ExecuteResult.FromError(parseResult.ParamValues[i]));
paramList[i] = parseResult.ParamValues[i].Values.First().Value;
}
}

return Execute(msg, argList, paramList);
}
public async Task<ExecuteResult> Execute(IMessage msg, IEnumerable<object> argList, IEnumerable<object> paramList)
{
try try
{ {
await _action.Invoke(msg, parseResult.Values);//Note: This code may need context
await _action.Invoke(msg, GenerateArgs(argList, paramList)).ConfigureAwait(false);//Note: This code may need context
return ExecuteResult.FromSuccess(); return ExecuteResult.FromSuccess();
} }
catch (Exception ex) catch (Exception ex)
@@ -108,7 +139,7 @@ namespace Discord.Commands
{ {
var parameter = parameters[i]; var parameter = parameters[i];
var type = parameter.ParameterType; var type = parameter.ParameterType;
//Detect 'params' //Detect 'params'
bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null; bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null;
if (isMultiple) if (isMultiple)
@@ -156,6 +187,39 @@ namespace Discord.Commands
}; };
} }


private object[] GenerateArgs(IEnumerable<object> argList, IEnumerable<object> paramsList)
{
int argCount = Parameters.Count;
var array = new object[Parameters.Count];
if (HasVarArgs)
argCount--;

int i = 0;
foreach (var arg in argList)
{
if (i == argCount)
throw new InvalidOperationException("Command was invoked with too many parameters");
array[i++] = arg;
}
if (i < argCount)
throw new InvalidOperationException("Command was invoked with too few parameters");

if (HasVarArgs)
{
var func = _arrayConverters.GetOrAdd(Parameters[Parameters.Count - 1].ElementType, t =>
{
var method = _convertParamsMethod.MakeGenericMethod(t);
return (Func<IEnumerable<object>, object>)method.CreateDelegate(typeof(Func<IEnumerable<object>, object>));
});
array[i] = func(paramsList);
}

return array;
}

private static T[] ConvertParamsList<T>(IEnumerable<object> paramsList)
=> paramsList.Cast<T>().ToArray();

public override string ToString() => Name; public override string ToString() => Name;
private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})"; private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})";
} }


+ 2
- 2
src/Discord.Net.Commands/CommandError.cs View File

@@ -3,14 +3,14 @@
public enum CommandError public enum CommandError
{ {
//Search //Search
UnknownCommand,
UnknownCommand = 1,


//Parse //Parse
ParseFailed, ParseFailed,
BadArgCount, BadArgCount,


//Parse (Type Reader) //Parse (Type Reader)
CastFailed,
//CastFailed,
ObjectNotFound, ObjectNotFound,
MultipleMatches, MultipleMatches,




+ 1
- 1
src/Discord.Net.Commands/CommandParameter.cs View File

@@ -17,7 +17,7 @@ namespace Discord.Commands
public bool IsRemainder { get; } public bool IsRemainder { get; }
public bool IsMultiple { get; } public bool IsMultiple { get; }
public Type ElementType { get; } public Type ElementType { get; }
internal object DefaultValue { get; }
public object DefaultValue { get; }


public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue) public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue)
{ {


+ 24
- 41
src/Discord.Net.Commands/CommandParser.cs View File

@@ -1,8 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

using System.Collections.Immutable; using System.Collections.Immutable;
using System.Reflection;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;


@@ -16,9 +13,6 @@ namespace Discord.Commands
Parameter, Parameter,
QuotedParameter QuotedParameter
} }

private static readonly MethodInfo _convertArrayMethod = typeof(CommandParser).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<List<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<List<object>, object>>();
public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos) public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos)
{ {
@@ -27,9 +21,10 @@ namespace Discord.Commands
int endPos = input.Length; int endPos = input.Length;
var curPart = ParserPart.None; var curPart = ParserPart.None;
int lastArgEndPos = int.MinValue; int lastArgEndPos = int.MinValue;
var argList = ImmutableArray.CreateBuilder<object>();
List<object> paramsList = null; // TODO: could we use a better type?
var argList = ImmutableArray.CreateBuilder<TypeReaderResult>();
ImmutableArray<TypeReaderResult>.Builder paramList = null;
bool isEscaping = false; bool isEscaping = false;
bool hasMultipleMatches = false;
char c; char c;


for (int curPos = startPos; curPos <= endPos; curPos++) for (int curPos = startPos; curPos <= endPos; curPos++)
@@ -117,30 +112,28 @@ namespace Discord.Commands


var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false); var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess) if (!typeReaderResult.IsSuccess)
return ParseResult.FromError(typeReaderResult);
{
if (typeReaderResult.Error == CommandError.MultipleMatches)
hasMultipleMatches = true;
else
return ParseResult.FromError(typeReaderResult);
}


if (curParam.IsMultiple) if (curParam.IsMultiple)
{ {
if (paramsList == null)
paramsList = new List<object>();
paramsList.Add(typeReaderResult.Value);
if (paramList == null)
paramList = ImmutableArray.CreateBuilder<TypeReaderResult>();
paramList.Add(typeReaderResult);


if (curPos == endPos) if (curPos == endPos)
{ {
var func = _arrayConverters.GetOrAdd(curParam.ElementType, t =>
{
var method = _convertArrayMethod.MakeGenericMethod(t);
return (Func<List<object>, object>)method.CreateDelegate(typeof(Func<List<object>, object>));
});
argList.Add(func.Invoke(paramsList));

curParam = null; curParam = null;
curPart = ParserPart.None; curPart = ParserPart.None;
} }
} }
else else
{ {
argList.Add(typeReaderResult.Value);
argList.Add(typeReaderResult);


curParam = null; curParam = null;
curPart = ParserPart.None; curPart = ParserPart.None;
@@ -154,34 +147,24 @@ namespace Discord.Commands
var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false); var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess) if (!typeReaderResult.IsSuccess)
return ParseResult.FromError(typeReaderResult); return ParseResult.FromError(typeReaderResult);
argList.Add(typeReaderResult.Value);
argList.Add(typeReaderResult);
} }


if (isEscaping) if (isEscaping)
return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape."); return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape.");
if (curPart == ParserPart.QuotedParameter) if (curPart == ParserPart.QuotedParameter)
return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete"); return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete");

if (argList.Count < command.Parameters.Count)
//Add missing optionals
for (int i = paramList.Count; i < command.Parameters.Count; i++)
{ {
for (int i = argList.Count; i < command.Parameters.Count; i++)
{
var param = command.Parameters[i];
if (!param.IsOptional)
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters.");
argList.Add(param.DefaultValue);
}
var param = command.Parameters[i];
if (!param.IsOptional)
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters.");
argList.Add(TypeReaderResult.FromSuccess(param.DefaultValue));
} }

return ParseResult.FromSuccess(argList.ToImmutable());
}

private static T[] ConvertParamsList<T>(List<object> paramsList)
{
var array = new T[paramsList.Count];
for (int i = 0; i < array.Length; i++)
array[i] = (T)paramsList[i];
return array;
return ParseResult.FromSuccess(argList.ToImmutable(), paramList?.ToImmutable());
} }
} }
} }

+ 26
- 24
src/Discord.Net.Commands/CommandService.cs View File

@@ -40,16 +40,8 @@ namespace Discord.Commands
[typeof(decimal)] = new SimpleTypeReader<decimal>(), [typeof(decimal)] = new SimpleTypeReader<decimal>(),
[typeof(DateTime)] = new SimpleTypeReader<DateTime>(), [typeof(DateTime)] = new SimpleTypeReader<DateTime>(),
[typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(), [typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(),

//TODO: Do we want to support any other interfaces?

//[typeof(IMentionable)] = new GeneralTypeReader(),
//[typeof(ISnowflakeEntity)] = new GeneralTypeReader(),
//[typeof(IEntity<ulong>)] = new GeneralTypeReader(),

[typeof(IMessage)] = new MessageTypeReader(), [typeof(IMessage)] = new MessageTypeReader(),
//[typeof(IAttachment)] = new xxx(),
//[typeof(IEmbed)] = new xxx(),


[typeof(IChannel)] = new ChannelTypeReader<IChannel>(), [typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
[typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(), [typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
@@ -61,10 +53,8 @@ namespace Discord.Commands
[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(), [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),


//[typeof(IGuild)] = new GuildTypeReader<IGuild>(), //[typeof(IGuild)] = new GuildTypeReader<IGuild>(),
//[typeof(IUserGuild)] = new GuildTypeReader<IUserGuild>(),
//[typeof(IGuildIntegration)] = new xxx(),


[typeof(IRole)] = new RoleTypeReader(),
[typeof(IRole)] = new RoleTypeReader<IRole>(),


//[typeof(IInvite)] = new InviteTypeReader<IInvite>(), //[typeof(IInvite)] = new InviteTypeReader<IInvite>(),
//[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(), //[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(),
@@ -72,10 +62,6 @@ namespace Discord.Commands
[typeof(IUser)] = new UserTypeReader<IUser>(), [typeof(IUser)] = new UserTypeReader<IUser>(),
[typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(), [typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(), [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
//[typeof(ISelfUser)] = new UserTypeReader<ISelfUser>(),
//[typeof(IPresence)] = new UserTypeReader<IPresence>(),
//[typeof(IVoiceState)] = new UserTypeReader<IVoiceState>(),
//[typeof(IConnection)] = new xxx(),
}; };
} }


@@ -201,8 +187,9 @@ namespace Discord.Commands
return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
} }


public Task<IResult> Execute(IMessage message, int argPos) => Execute(message, message.Content.Substring(argPos));
public async Task<IResult> Execute(IMessage message, string input)
public Task<IResult> Execute(IMessage message, int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
=> Execute(message, message.Content.Substring(argPos), multiMatchHandling);
public async Task<IResult> Execute(IMessage message, string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
{ {
var searchResult = Search(message, input); var searchResult = Search(message, input);
if (!searchResult.IsSuccess) if (!searchResult.IsSuccess)
@@ -223,14 +210,29 @@ namespace Discord.Commands
var parseResult = await commands[i].Parse(message, searchResult, preconditionResult); var parseResult = await commands[i].Parse(message, searchResult, preconditionResult);
if (!parseResult.IsSuccess) if (!parseResult.IsSuccess)
{ {
if (commands.Count == 1)
return parseResult;
else
continue;
if (parseResult.Error == CommandError.MultipleMatches)
{
TypeReaderValue[] argList, paramList;
switch (multiMatchHandling)
{
case MultiMatchHandling.Best:
argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
paramList = parseResult.ParamValues?.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
parseResult = ParseResult.FromSuccess(argList, paramList);
break;
}
}

if (!parseResult.IsSuccess)
{
if (commands.Count == 1)
return parseResult;
else
continue;
}
} }


var executeResult = await commands[i].Execute(message, parseResult);
return executeResult;
return await commands[i].Execute(message, parseResult);
} }
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");


+ 8
- 0
src/Discord.Net.Commands/MultiMatchHandling.cs View File

@@ -0,0 +1,8 @@
namespace Discord.Commands
{
public enum MultiMatchHandling
{
Exception,
Best
}
}

+ 27
- 28
src/Discord.Net.Commands/Readers/ChannelTypeReader.cs View File

@@ -1,4 +1,6 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;


@@ -9,40 +11,37 @@ namespace Discord.Commands
{ {
public override async Task<TypeReaderResult> Read(IMessage context, string input) public override async Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IGuildChannel guildChannel = context.Channel as IGuildChannel;
IChannel result = null;
var guild = (context.Channel as IGuildChannel)?.Guild;


if (guildChannel != null)
if (guild != null)
{ {
//By Id
var results = new Dictionary<ulong, TypeReaderValue>();
var channels = await guild.GetChannelsAsync().ConfigureAwait(false);
ulong id; ulong id;
if (MentionUtils.TryParseChannel(input, out id) || ulong.TryParse(input, out id))
{
var channel = await guildChannel.Guild.GetChannelAsync(id).ConfigureAwait(false);
if (channel != null)
result = channel;
}

//By Name
if (result == null)
{
var channels = await guildChannel.Guild.GetChannelsAsync().ConfigureAwait(false);
var filteredChannels = channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray();
if (filteredChannels.Length > 1)
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple channels found.");
else if (filteredChannels.Length == 1)
result = filteredChannels[0];
}

//By Mention (1.0)
if (MentionUtils.TryParseChannel(input, out id))
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);

//By Id (0.9)
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
AddResult(results, await guild.GetChannelAsync(id).ConfigureAwait(false) as T, 0.90f);

//By Name (0.7-0.8)
foreach (var channel in channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channel as T, channel.Name == input ? 0.80f : 0.70f);

if (results.Count > 0)
return TypeReaderResult.FromSuccess(results.Values);
} }


if (result == null)
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found.");
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found.");
}


T castResult = result as T;
if (castResult == null)
return TypeReaderResult.FromError(CommandError.CastFailed, $"Channel is not a {typeof(T).Name}.");
else
return TypeReaderResult.FromSuccess(castResult);
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T channel, float score)
{
if (channel != null && !results.ContainsKey(channel.Id))
results.Add(channel.Id, new TypeReaderValue(channel, score));
} }
} }
} }

+ 2
- 2
src/Discord.Net.Commands/Readers/EnumTypeReader.cs View File

@@ -52,14 +52,14 @@ namespace Discord.Commands
if (_enumsByValue.TryGetValue(baseValue, out enumValue)) if (_enumsByValue.TryGetValue(baseValue, out enumValue))
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); return Task.FromResult(TypeReaderResult.FromSuccess(enumValue));
else else
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}"));
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}"));
} }
else else
{ {
if (_enumsByName.TryGetValue(input.ToLower(), out enumValue)) if (_enumsByName.TryGetValue(input.ToLower(), out enumValue))
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); return Task.FromResult(TypeReaderResult.FromSuccess(enumValue));
else else
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}"));
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}"));
} }
} }
} }


+ 4
- 5
src/Discord.Net.Commands/Readers/MessageTypeReader.cs View File

@@ -7,18 +7,17 @@ namespace Discord.Commands
{ {
public override Task<TypeReaderResult> Read(IMessage context, string input) public override Task<TypeReaderResult> Read(IMessage context, string input)
{ {
//By Id
ulong id; ulong id;

//By Id (1.0)
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
{ {
var msg = context.Channel.GetCachedMessage(id); var msg = context.Channel.GetCachedMessage(id);
if (msg == null)
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found."));
else
if (msg != null)
return Task.FromResult(TypeReaderResult.FromSuccess(msg)); return Task.FromResult(TypeReaderResult.FromSuccess(msg));
} }


return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Message Id."));
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found."));
} }
} }
} }

+ 29
- 19
src/Discord.Net.Commands/Readers/RoleTypeReader.cs View File

@@ -1,36 +1,46 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;


namespace Discord.Commands namespace Discord.Commands
{ {
internal class RoleTypeReader : TypeReader
internal class RoleTypeReader<T> : TypeReader
where T : class, IRole
{ {
public override Task<TypeReaderResult> Read(IMessage context, string input) public override Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IGuildChannel guildChannel = context.Channel as IGuildChannel;
var guild = (context.Channel as IGuildChannel)?.Guild;
ulong id;


if (guildChannel != null)
if (guild != null)
{ {
//By Id
ulong id;
if (MentionUtils.TryParseRole(input, out id) || ulong.TryParse(input, out id))
{
var channel = guildChannel.Guild.GetRole(id);
if (channel != null)
return Task.FromResult(TypeReaderResult.FromSuccess(channel));
}
var results = new Dictionary<ulong, TypeReaderValue>();
var roles = guild.Roles;


//By Name
var roles = guildChannel.Guild.Roles;
var filteredRoles = roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray();
if (filteredRoles.Length > 1)
return Task.FromResult(TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple roles found."));
else if (filteredRoles.Length == 1)
return Task.FromResult(TypeReaderResult.FromSuccess(filteredRoles[0]));
//By Mention (1.0)
if (MentionUtils.TryParseRole(input, out id))
AddResult(results, guild.GetRole(id) as T, 1.00f);

//By Id (0.9)
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
AddResult(results, guild.GetRole(id) as T, 0.90f);

//By Name (0.7-0.8)
foreach (var role in roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)))
AddResult(results, role as T, role.Name == input ? 0.80f : 0.70f);

if (results.Count > 0)
return Task.FromResult(TypeReaderResult.FromSuccess(results));
} }
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found.")); return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found."));
} }

private void AddResult(Dictionary<ulong, TypeReaderValue> results, T role, float score)
{
if (role != null && !results.ContainsKey(role.Id))
results.Add(role.Id, new TypeReaderValue(role, score));
}
} }
} }

+ 1
- 2
src/Discord.Net.Commands/Readers/SimpleTypeReader.cs View File

@@ -16,8 +16,7 @@ namespace Discord.Commands
T value; T value;
if (_tryParse(input, out value)) if (_tryParse(input, out value))
return Task.FromResult(TypeReaderResult.FromSuccess(value)); return Task.FromResult(TypeReaderResult.FromSuccess(value));
else
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}"));
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}"));
} }
} }
} }

+ 61
- 35
src/Discord.Net.Commands/Readers/UserTypeReader.cs View File

@@ -1,4 +1,6 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;


@@ -9,54 +11,78 @@ namespace Discord.Commands
{ {
public override async Task<TypeReaderResult> Read(IMessage context, string input) public override async Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IUser result = null;
//By Id
var results = new Dictionary<ulong, TypeReaderValue>();
var guild = (context.Channel as IGuildChannel)?.Guild;
IReadOnlyCollection<IUser> channelUsers = await context.Channel.GetUsersAsync().ConfigureAwait(false);
IReadOnlyCollection<IGuildUser> guildUsers = null;
ulong id; ulong id;
if (MentionUtils.TryParseUser(input, out id) || ulong.TryParse(input, out id))

if (guild != null)
guildUsers = await guild.GetUsersAsync().ConfigureAwait(false);

//By Mention (1.0)
if (MentionUtils.TryParseUser(input, out id))
{ {
var user = await context.Channel.GetUserAsync(id).ConfigureAwait(false);
if (user != null)
result = user;
if (guild != null)
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);
else
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);
} }


//By Username + Discriminator
if (result == null)
//By Id (0.9)
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
{ {
int index = input.LastIndexOf('#');
if (index >= 0)
if (guild != null)
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f);
else
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f);
}

//By Username + Discriminator (0.7-0.85)
int index = input.LastIndexOf('#');
if (index >= 0)
{
string username = input.Substring(0, index);
ushort discriminator;
if (ushort.TryParse(input.Substring(index + 1), out discriminator))
{ {
string username = input.Substring(0, index);
ushort discriminator;
if (ushort.TryParse(input.Substring(index + 1), out discriminator))
{
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false);
result = users.Where(x =>
x.DiscriminatorValue == discriminator &&
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
}
var channelUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator &&
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
AddResult(results, channelUser as T, channelUser.Username == username ? 0.85f : 0.75f);

var guildUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator &&
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
AddResult(results, guildUser as T, guildUser.Username == username ? 0.80f : 0.70f);
} }
} }


//By Username
if (result == null)
//By Username (0.5-0.6)
{ {
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false);
var filteredUsers = users.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)).ToArray();
if (filteredUsers.Length > 1)
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple users found.");
else if (filteredUsers.Length == 1)
result = filteredUsers[0];
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f);

foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f);
} }


if (result == null)
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found.");
//By Nickname (0.5-0.6)
{
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f);

foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f);
}


T castResult = result as T;
if (castResult == null)
return TypeReaderResult.FromError(CommandError.CastFailed, $"User is not a {typeof(T).Name}.");
else
return TypeReaderResult.FromSuccess(castResult);
if (results.Count > 0)
return TypeReaderResult.FromSuccess(results.Values.ToArray());
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found.");
}

private void AddResult(Dictionary<ulong, TypeReaderValue> results, T user, float score)
{
if (user != null && !results.ContainsKey(user.Id))
results.Add(user.Id, new TypeReaderValue(user, score));
} }
} }
} }

+ 33
- 8
src/Discord.Net.Commands/Results/ParseResult.cs View File

@@ -6,28 +6,53 @@ namespace Discord.Commands
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct ParseResult : IResult public struct ParseResult : IResult
{ {
public IReadOnlyList<object> Values { get; }
public IReadOnlyList<TypeReaderResult> ArgValues { get; }
public IReadOnlyList<TypeReaderResult> ParamValues { get; }


public CommandError? Error { get; } public CommandError? Error { get; }
public string ErrorReason { get; } public string ErrorReason { get; }


public bool IsSuccess => !Error.HasValue; public bool IsSuccess => !Error.HasValue;


private ParseResult(IReadOnlyList<object> values, CommandError? error, string errorReason)
private ParseResult(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValue, CommandError? error, string errorReason)
{ {
Values = values;
ArgValues = argValues;
ParamValues = paramValue;
Error = error; Error = error;
ErrorReason = errorReason; ErrorReason = errorReason;
} }


public static ParseResult FromSuccess(IReadOnlyList<object> values)
=> new ParseResult(values, null, null);
public static ParseResult FromSuccess(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValues)
{
for (int i = 0; i < argValues.Count; i++)
{
if (argValues[i].Values.Count > 1)
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found.");
}
for (int i = 0; i < paramValues.Count; i++)
{
if (paramValues[i].Values.Count > 1)
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found.");
}
return new ParseResult(argValues, paramValues, null, null);
}
public static ParseResult FromSuccess(IReadOnlyList<TypeReaderValue> argValues, IReadOnlyList<TypeReaderValue> paramValues)
{
var argList = new TypeReaderResult[argValues.Count];
for (int i = 0; i < argValues.Count; i++)
argList[i] = TypeReaderResult.FromSuccess(argValues[i]);
var paramList = new TypeReaderResult[paramValues.Count];
for (int i = 0; i < paramValues.Count; i++)
paramList[i] = TypeReaderResult.FromSuccess(paramValues[i]);
return new ParseResult(argList, paramList, null, null);
}

public static ParseResult FromError(CommandError error, string reason) public static ParseResult FromError(CommandError error, string reason)
=> new ParseResult(null, error, reason);
=> new ParseResult(null, null, error, reason);
public static ParseResult FromError(IResult result) public static ParseResult FromError(IResult result)
=> new ParseResult(null, result.Error, result.ErrorReason);
=> new ParseResult(null, null, result.Error, result.ErrorReason);


public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({Values.Count} Values)" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({ArgValues.Count}{(ParamValues != null ? $" +{ParamValues.Count} Values" : "")})" : $"{Error}: {ErrorReason}";
} }
} }

+ 30
- 6
src/Discord.Net.Commands/Results/TypeReaderResult.cs View File

@@ -1,32 +1,56 @@
using System.Diagnostics;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;


namespace Discord.Commands namespace Discord.Commands
{ {
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct TypeReaderResult : IResult
public struct TypeReaderValue
{ {
public object Value { get; } public object Value { get; }
public float Score { get; }

public TypeReaderValue(object value, float score)
{
Value = value;
Score = score;
}

public override string ToString() => Value?.ToString();
private string DebuggerDisplay => $"[{Value}, {Math.Round(Score, 2)}]";
}

[DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct TypeReaderResult : IResult
{
public IReadOnlyCollection<TypeReaderValue> Values { get; }


public CommandError? Error { get; } public CommandError? Error { get; }
public string ErrorReason { get; } public string ErrorReason { get; }


public bool IsSuccess => !Error.HasValue; public bool IsSuccess => !Error.HasValue;


private TypeReaderResult(object value, CommandError? error, string errorReason)
private TypeReaderResult(IReadOnlyCollection<TypeReaderValue> values, CommandError? error, string errorReason)
{ {
Value = value;
Values = values;
Error = error; Error = error;
ErrorReason = errorReason; ErrorReason = errorReason;
} }


public static TypeReaderResult FromSuccess(object value) public static TypeReaderResult FromSuccess(object value)
=> new TypeReaderResult(value, null, null);
=> new TypeReaderResult(ImmutableArray.Create(new TypeReaderValue(value, 1.0f)), null, null);
public static TypeReaderResult FromSuccess(TypeReaderValue value)
=> new TypeReaderResult(ImmutableArray.Create(value), null, null);
public static TypeReaderResult FromSuccess(IReadOnlyCollection<TypeReaderValue> values)
=> new TypeReaderResult(values, null, null);
public static TypeReaderResult FromError(CommandError error, string reason) public static TypeReaderResult FromError(CommandError error, string reason)
=> new TypeReaderResult(null, error, reason); => new TypeReaderResult(null, error, reason);
public static TypeReaderResult FromError(IResult result) public static TypeReaderResult FromError(IResult result)
=> new TypeReaderResult(null, result.Error, result.ErrorReason); => new TypeReaderResult(null, result.Error, result.ErrorReason);


public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({Value})" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({string.Join(", ", Values)})" : $"{Error}: {ErrorReason}";
} }
} }

Loading…
Cancel
Save