diff --git a/src/Discord.Net/DiscordClient.Users.cs b/src/Discord.Net/DiscordClient.Users.cs index 67fe28a07..ba379acc8 100644 --- a/src/Discord.Net/DiscordClient.Users.cs +++ b/src/Discord.Net/DiscordClient.Users.cs @@ -218,16 +218,22 @@ namespace Discord return query; } - public Task EditUser(User user, bool? mute = null, bool? deaf = null, IEnumerable roles = null) + public Task EditUser(User user, bool? mute = null, bool? deaf = null, IEnumerable roles = null, EditMode rolesMode = EditMode.Set) { if (user == null) throw new ArgumentNullException(nameof(user)); if (user.IsPrivate) throw new InvalidOperationException("Unable to edit users in a private channel"); CheckReady(); + //Modify the roles collection and filter out the everyone role + IEnumerable roleIds = roles == null ? null : user.Roles + .Modify(roles, rolesMode) + .Where(x => !x.IsEveryone) + .Select(x => x.Id); + var serverId = user.Server.Id; return _api.EditUser(serverId, user.Id, mute: mute, deaf: deaf, - roleIds: roles.Select(x => x.Id).Where(x => x != serverId)); + roleIds: roleIds); } public Task KickUser(User user) diff --git a/src/Discord.Net/Helpers/Extensions.cs b/src/Discord.Net/Helpers/Extensions.cs index 45aad0c5b..d11ae6f65 100644 --- a/src/Discord.Net/Helpers/Extensions.cs +++ b/src/Discord.Net/Helpers/Extensions.cs @@ -1,35 +1,44 @@ using System; +using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; namespace Discord { - internal static class Extensions + public enum EditMode : byte { - public static async Task Timeout(this Task self, int milliseconds) + Set, + Add, + Remove + } + + internal static class Extensions + { + public static async Task Timeout(this Task task, int milliseconds) { Task timeoutTask = Task.Delay(milliseconds); - Task finishedTask = await Task.WhenAny(self, timeoutTask).ConfigureAwait(false); + Task finishedTask = await Task.WhenAny(task, timeoutTask).ConfigureAwait(false); if (finishedTask == timeoutTask) throw new TimeoutException(); else - await self.ConfigureAwait(false); + await task.ConfigureAwait(false); } - public static async Task Timeout(this Task self, int milliseconds) + public static async Task Timeout(this Task task, int milliseconds) { Task timeoutTask = Task.Delay(milliseconds); - Task finishedTask = await Task.WhenAny(self, timeoutTask).ConfigureAwait(false); + Task finishedTask = await Task.WhenAny(task, timeoutTask).ConfigureAwait(false); if (finishedTask == timeoutTask) throw new TimeoutException(); else - return await self.ConfigureAwait(false); + return await task.ConfigureAwait(false); } - public static async Task Timeout(this Task self, int milliseconds, CancellationTokenSource timeoutToken) + public static async Task Timeout(this Task task, int milliseconds, CancellationTokenSource timeoutToken) { try { timeoutToken.CancelAfter(milliseconds); - await self.ConfigureAwait(false); + await task.ConfigureAwait(false); } catch (OperationCanceledException) { @@ -38,12 +47,12 @@ namespace Discord throw; } } - public static async Task Timeout(this Task self, int milliseconds, CancellationTokenSource timeoutToken) + public static async Task Timeout(this Task task, int milliseconds, CancellationTokenSource timeoutToken) { try { timeoutToken.CancelAfter(milliseconds); - return await self.ConfigureAwait(false); + return await task.ConfigureAwait(false); } catch (OperationCanceledException) { @@ -64,5 +73,20 @@ namespace Discord try { await Task.Delay(-1, token).ConfigureAwait(false); } catch (OperationCanceledException) { } //Expected } - } + + public static IEnumerable Modify(this IEnumerable original, IEnumerable modified, EditMode mode) + { + if (original == null) return null; + switch (mode) + { + case EditMode.Set: + default: + return modified; + case EditMode.Add: + return original.Concat(modified); + case EditMode.Remove: + return original.Except(modified); + } + } + } }