From 5760e94d81e9644769d3ba4c58132b1976872db9 Mon Sep 17 00:00:00 2001 From: RogueException Date: Fri, 11 Dec 2015 19:07:55 -0400 Subject: [PATCH] Reworked internal task engine for DiscordClient and WebSocket. Several other minor async fixes. --- .../API/Messages/GatewaySocket.cs | 2 +- src/Discord.Net.Audio/AudioService.cs | 16 +- src/Discord.Net.Audio/DiscordAudioClient.cs | 39 +- .../Net/WebSockets/VoiceWebSocket.cs | 82 ++-- src/Discord.Net.Commands/Command.cs | 2 +- src/Discord.Net.Net45/Discord.Net.csproj | 13 +- src/Discord.Net.Shared/TaskHelper.cs | 22 +- ...er.cs => LongStringCollectionConverter.cs} | 4 +- .../{AvatarImageType.cs => ImageType.cs} | 0 src/Discord.Net/API/Messages/Members.cs | 6 +- src/Discord.Net/API/Messages/Messages.cs | 4 +- src/Discord.Net/DiscordClient.Messages.cs | 9 +- src/Discord.Net/DiscordClient.Roles.cs | 2 +- src/Discord.Net/DiscordClient.Servers.cs | 3 +- src/Discord.Net/DiscordClient.Users.cs | 6 +- src/Discord.Net/DiscordClient.cs | 383 ++++++++---------- src/Discord.Net/Helpers/Mention.cs | 2 +- src/Discord.Net/Helpers/TaskManager.cs | 148 +++++++ src/Discord.Net/Net/Rest/SharpRestEngine.cs | 2 +- .../Net/WebSockets/GatewaySocket.cs | 49 +-- .../Net/WebSockets/WS4NetEngine.cs | 10 +- src/Discord.Net/Net/WebSockets/WebSocket.cs | 204 +++------- test/Discord.Net.Tests/Tests.cs | 6 +- 23 files changed, 524 insertions(+), 490 deletions(-) rename src/Discord.Net/API/Converters/{LongCollectionConverter.cs => LongStringCollectionConverter.cs} (94%) rename src/Discord.Net/API/Enums/{AvatarImageType.cs => ImageType.cs} (100%) create mode 100644 src/Discord.Net/Helpers/TaskManager.cs diff --git a/src/Discord.Net.Audio/API/Messages/GatewaySocket.cs b/src/Discord.Net.Audio/API/Messages/GatewaySocket.cs index 5415318c2..072521155 100644 --- a/src/Discord.Net.Audio/API/Messages/GatewaySocket.cs +++ b/src/Discord.Net.Audio/API/Messages/GatewaySocket.cs @@ -5,7 +5,7 @@ using Discord.API.Converters; using Newtonsoft.Json; -namespace Discord.Audio.API +namespace Discord.API { internal sealed class VoiceServerUpdateEvent { diff --git a/src/Discord.Net.Audio/AudioService.cs b/src/Discord.Net.Audio/AudioService.cs index 74bc67c33..a2762f860 100644 --- a/src/Discord.Net.Audio/AudioService.cs +++ b/src/Discord.Net.Audio/AudioService.cs @@ -95,8 +95,7 @@ namespace Discord.Audio else { var logger = Client.Log().CreateLogger("Voice"); - var voiceSocket = new VoiceWebSocket(Client.Config, _config, logger); - _defaultClient = new DiscordAudioClient(this, 0, logger, _client.WebSocket, voiceSocket); + _defaultClient = new DiscordAudioClient(this, 0, logger, _client.WebSocket); } _talkingUsers = new ConcurrentDictionary(); @@ -145,27 +144,26 @@ namespace Discord.Audio return Task.FromResult(_defaultClient); } - var client = _voiceClients.GetOrAdd(server.Id, (Func)(_ => + var client = _voiceClients.GetOrAdd(server.Id, _ => { int id = unchecked(++_nextClientId); var logger = Client.Log().CreateLogger($"Voice #{id}"); - Net.WebSockets.GatewaySocket gatewaySocket = null; - var voiceSocket = new VoiceWebSocket(Client.Config, _config, logger); - var voiceClient = new DiscordAudioClient((AudioService)(this), (int)id, (Logger)logger, (Net.WebSockets.GatewaySocket)gatewaySocket, (VoiceWebSocket)voiceSocket); + GatewaySocket gatewaySocket = null; + var voiceClient = new DiscordAudioClient(this, id, logger, gatewaySocket); voiceClient.SetServerId(server.Id); - voiceSocket.OnPacket += (s, e) => + voiceClient.VoiceSocket.OnPacket += (s, e) => { RaiseOnPacket(e); }; - voiceSocket.IsSpeaking += (s, e) => + voiceClient.VoiceSocket.IsSpeaking += (s, e) => { var user = Client.GetUser(server, e.UserId); RaiseUserIsSpeakingUpdated(user, e.IsSpeaking); }; return voiceClient; - })); + }); //await client.Connect(gatewaySocket.Host, _client.Token).ConfigureAwait(false); return Task.FromResult(client); } diff --git a/src/Discord.Net.Audio/DiscordAudioClient.cs b/src/Discord.Net.Audio/DiscordAudioClient.cs index f8a97a46f..ccb9d6f9c 100644 --- a/src/Discord.Net.Audio/DiscordAudioClient.cs +++ b/src/Discord.Net.Audio/DiscordAudioClient.cs @@ -1,4 +1,5 @@ -using Discord.Net.WebSockets; +using Discord.API; +using Discord.Net.WebSockets; using System; using System.Threading.Tasks; @@ -10,22 +11,29 @@ namespace Discord.Audio public int Id => _id; private readonly AudioService _service; - private readonly GatewaySocket _gatewaySocket; - private readonly VoiceWebSocket _voiceSocket; private readonly Logger _logger; - public long? ServerId => _voiceSocket.ServerId; - public long? ChannelId => _voiceSocket.ChannelId; + public GatewaySocket GatewaySocket => _gatewaySocket; + private readonly GatewaySocket _gatewaySocket; - public DiscordAudioClient(AudioService service, int id, Logger logger, GatewaySocket gatewaySocket, VoiceWebSocket voiceSocket) + public VoiceWebSocket VoiceSocket => _voiceSocket; + private readonly VoiceWebSocket _voiceSocket; + + public string Token => _token; + private string _token; + + public long? ServerId => _voiceSocket.ServerId; + public long? ChannelId => _voiceSocket.ChannelId; + + public DiscordAudioClient(AudioService service, int id, Logger logger, GatewaySocket gatewaySocket) { _service = service; _id = id; _logger = logger; _gatewaySocket = gatewaySocket; - _voiceSocket = voiceSocket; + _voiceSocket = new VoiceWebSocket(service.Client, this, logger); - /*_voiceSocket.Connected += (s, e) => RaiseVoiceConnected(); + /*_voiceSocket.Connected += (s, e) => RaiseVoiceConnected(); _voiceSocket.Disconnected += async (s, e) => { _voiceSocket.CurrentServerId; @@ -37,7 +45,7 @@ namespace Discord.Audio await socket.Reconnect().ConfigureAwait(false); };*/ - /*_voiceSocket.IsSpeaking += (s, e) => + /*_voiceSocket.IsSpeaking += (s, e) => { if (_voiceSocket.State == WebSocketState.Connected) { @@ -54,27 +62,28 @@ namespace Discord.Audio } };*/ - /*this.Connected += (s, e) => + /*this.Connected += (s, e) => { _voiceSocket.ParentCancelToken = _cancelToken; };*/ - _gatewaySocket.ReceivedDispatch += async (s, e) => + _gatewaySocket.ReceivedDispatch += async (s, e) => { try { switch (e.Type) { case "VOICE_SERVER_UPDATE": - { - long serverId = IdConvert.ToLong(e.Payload.Value("guild_id")); + { + var data = e.Payload.ToObject(_gatewaySocket.Serializer); + long serverId = data.ServerId; if (serverId == ServerId) { var client = _service.Client; - string token = e.Payload.Value("token"); + _token = data.Token; _voiceSocket.Host = "wss://" + e.Payload.Value("endpoint").Split(':')[0]; - await _voiceSocket.Connect(client.CurrentUser.Id, _gatewaySocket.SessionId, token/*, client.CancelToken*/).ConfigureAwait(false); + await _voiceSocket.Connect().ConfigureAwait(false); } } break; diff --git a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs index 3e8dba909..da03bff35 100644 --- a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs @@ -26,7 +26,8 @@ namespace Discord.Net.WebSockets //private readonly Random _rand; private readonly int _targetAudioBufferLength; private readonly ConcurrentDictionary _decoders; - private readonly AudioServiceConfig _audioConfig; + private readonly DiscordAudioClient _audioClient; + private readonly AudioServiceConfig _config; private OpusEncoder _encoder; private uint _ssrc; private ConcurrentDictionary _ssrcMapping; @@ -37,8 +38,8 @@ namespace Discord.Net.WebSockets private bool _isEncrypted; private byte[] _secretKey, _encodingBuffer; private ushort _sequence; - private long? _serverId, _channelId, _userId; - private string _sessionId, _token, _encryptionMode; + private long? _serverId, _channelId; + private string _encryptionMode; private int _ping; private Thread _sendThread, _receiveThread; @@ -48,24 +49,21 @@ namespace Discord.Net.WebSockets public int Ping => _ping; internal VoiceBuffer OutputBuffer => _sendBuffer; - public VoiceWebSocket(DiscordConfig config, AudioServiceConfig audioConfig, Logger logger) - : base(config, logger) + public VoiceWebSocket(DiscordClient client, DiscordAudioClient audioClient, Logger logger) + : base(client, logger) { - _audioConfig = audioConfig; - _decoders = new ConcurrentDictionary(); - _targetAudioBufferLength = _audioConfig.BufferLength / 20; //20 ms frames + _audioClient = audioClient; + _config = client.Audio().Config; + _decoders = new ConcurrentDictionary(); + _targetAudioBufferLength = _config.BufferLength / 20; //20 ms frames _encodingBuffer = new byte[MaxOpusSize]; _ssrcMapping = new ConcurrentDictionary(); - _encoder = new OpusEncoder(48000, _audioConfig.Channels, 20, _audioConfig.Bitrate, OpusApplication.Audio); - _sendBuffer = new VoiceBuffer((int)Math.Ceiling(_audioConfig.BufferLength / (double)_encoder.FrameLength), _encoder.FrameSize); + _encoder = new OpusEncoder(48000, _config.Channels, 20, _config.Bitrate, OpusApplication.Audio); + _sendBuffer = new VoiceBuffer((int)Math.Ceiling(_config.BufferLength / (double)_encoder.FrameLength), _encoder.FrameSize); } - public async Task Connect(long userId, string sessionId, string token) - { - _userId = userId; - _sessionId = sessionId; - _token = token; - + public async Task Connect() + { await BeginConnect().ConfigureAwait(false); } public async Task Reconnect() @@ -73,12 +71,12 @@ namespace Discord.Net.WebSockets try { var cancelToken = ParentCancelToken.Value; - await Task.Delay(_config.ReconnectDelay, cancelToken).ConfigureAwait(false); + await Task.Delay(_client.Config.ReconnectDelay, cancelToken).ConfigureAwait(false); while (!cancelToken.IsCancellationRequested) { try { - await Connect(_userId.Value, _sessionId, _token).ConfigureAwait(false); + await Connect().ConfigureAwait(false); break; } catch (OperationCanceledException) { throw; } @@ -86,29 +84,26 @@ namespace Discord.Net.WebSockets { _logger.Error("Reconnect failed", ex); //Net is down? We can keep trying to reconnect until the user runs Disconnect() - await Task.Delay(_config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); + await Task.Delay(_client.Config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); } } } catch (OperationCanceledException) { } } - public Task Disconnect() - { - return SignalDisconnect(wait: true); - } + public Task Disconnect() => _taskManager.Stop(); protected override async Task Run() { _udp = new UdpClient(new IPEndPoint(IPAddress.Any, 0)); List tasks = new List(); - if ((_audioConfig.Mode & AudioMode.Outgoing) != 0) + if ((_config.Mode & AudioMode.Outgoing) != 0) { _sendThread = new Thread(new ThreadStart(() => SendVoiceAsync(_cancelToken))); _sendThread.IsBackground = true; _sendThread.Start(); } - if ((_audioConfig.Mode & AudioMode.Incoming) != 0) + if ((_config.Mode & AudioMode.Incoming) != 0) { _receiveThread = new Thread(new ThreadStart(() => ReceiveVoiceAsync(_cancelToken))); _receiveThread.IsBackground = true; @@ -120,9 +115,9 @@ namespace Discord.Net.WebSockets #if !DOTNET5_4 tasks.Add(WatcherAsync()); #endif - await RunTasks(tasks.ToArray()); - - await Cleanup(); + tasks.AddRange(_engine.GetTasks(_cancelToken)); + tasks.Add(HeartbeatAsync(_cancelToken)); + await _taskManager.Start(tasks, _cancelTokenSource).ConfigureAwait(false); } protected override Task Cleanup() { @@ -141,12 +136,6 @@ namespace Discord.Net.WebSockets } ClearPCMFrames(); - if (!_wasDisconnectUnexpected) - { - _userId = null; - _sessionId = null; - _token = null; - } _udp = null; return base.Cleanup(); @@ -161,7 +150,7 @@ namespace Discord.Net.WebSockets int packetLength, resultOffset, resultLength; IPEndPoint endpoint = new IPEndPoint(IPAddress.Any, 0); - if ((_audioConfig.Mode & AudioMode.Incoming) != 0) + if ((_config.Mode & AudioMode.Incoming) != 0) { decodingBuffer = new byte[MaxOpusSize]; nonce = new byte[24]; @@ -188,7 +177,7 @@ namespace Discord.Net.WebSockets if (packetLength > 0 && endpoint.Equals(_endpoint)) { - if (_state != (int)WebSocketState.Connected) + if (_state != (int)ConnectionState.Connected) { if (packetLength != 70) return; @@ -197,8 +186,8 @@ namespace Discord.Net.WebSockets int port = packet[68] | packet[69] << 8; SendSelectProtocol(ip, port); - if ((_audioConfig.Mode & AudioMode.Incoming) == 0) - return; + if ((_config.Mode & AudioMode.Incoming) == 0) + return; //We dont need this thread anymore } else { @@ -258,7 +247,7 @@ namespace Discord.Net.WebSockets { try { - while (!cancelToken.IsCancellationRequested && _state != (int)WebSocketState.Connected) + while (!cancelToken.IsCancellationRequested && _state != (int)ConnectionState.Connected) Thread.Sleep(1); if (cancelToken.IsCancellationRequested) @@ -410,14 +399,15 @@ namespace Discord.Net.WebSockets { case VoiceOpCodes.Ready: { - if (_state != (int)WebSocketState.Connected) + if (_state != (int)ConnectionState.Connected) { var payload = (msg.Payload as JToken).ToObject(_serializer); _heartbeatInterval = payload.HeartbeatInterval; _ssrc = payload.SSRC; - _endpoint = new IPEndPoint((await Dns.GetHostAddressesAsync(Host.Replace("wss://", "")).ConfigureAwait(false)).FirstOrDefault(), payload.Port); + var address = (await Dns.GetHostAddressesAsync(Host.Replace("wss://", "")).ConfigureAwait(false)).FirstOrDefault(); + _endpoint = new IPEndPoint(address, payload.Port); - if (_audioConfig.EnableEncryption) + if (_config.EnableEncryption) { if (payload.Modes.Contains(EncryptedMode)) { @@ -458,7 +448,7 @@ namespace Discord.Net.WebSockets var payload = (msg.Payload as JToken).ToObject(_serializer); _secretKey = payload.SecretKey; SendIsTalking(true); - await EndConnect(); + EndConnect(); } break; case VoiceOpCodes.Speaking: @@ -507,9 +497,9 @@ namespace Discord.Net.WebSockets { var msg = new IdentifyCommand(); msg.Payload.ServerId = _serverId.Value; - msg.Payload.SessionId = _sessionId; - msg.Payload.Token = _token; - msg.Payload.UserId = _userId.Value; + msg.Payload.SessionId = _client.SessionId; + msg.Payload.Token = _audioClient.Token; + msg.Payload.UserId = _client.UserId.Value; QueueMessage(msg); } diff --git a/src/Discord.Net.Commands/Command.cs b/src/Discord.Net.Commands/Command.cs index e8d259408..28f414485 100644 --- a/src/Discord.Net.Commands/Command.cs +++ b/src/Discord.Net.Commands/Command.cs @@ -93,7 +93,7 @@ namespace Discord.Commands } internal void SetRunFunc(Action func) { - _runFunc = e => { func(e); return TaskHelper.CompletedTask; }; + _runFunc = TaskHelper.ToAsync(func); } internal Task Run(CommandEventArgs args) { diff --git a/src/Discord.Net.Net45/Discord.Net.csproj b/src/Discord.Net.Net45/Discord.Net.csproj index bba5dcb78..de69c937a 100644 --- a/src/Discord.Net.Net45/Discord.Net.csproj +++ b/src/Discord.Net.Net45/Discord.Net.csproj @@ -68,8 +68,8 @@ - - API\Converters\LongCollectionConverter.cs + + API\Converters\LongStringCollectionConverter.cs API\Converters\LongStringConverter.cs @@ -77,12 +77,12 @@ API\Endpoints.cs - - API\Enums\AvatarImageType.cs - API\Enums\ChannelType.cs + + API\Enums\ImageType.cs + API\Enums\PermissionTarget.cs @@ -176,6 +176,9 @@ Helpers\Reference.cs + + Helpers\TaskManager.cs + Models\Channel.cs diff --git a/src/Discord.Net.Shared/TaskHelper.cs b/src/Discord.Net.Shared/TaskHelper.cs index 4d0fb0ae5..82c70f500 100644 --- a/src/Discord.Net.Shared/TaskHelper.cs +++ b/src/Discord.Net.Shared/TaskHelper.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System; +using System.Threading.Tasks; namespace Discord { @@ -7,7 +8,26 @@ namespace Discord public static Task CompletedTask { get; } static TaskHelper() { +#if DOTNET54 + CompletedTask = Task.CompletedTask; +#else CompletedTask = Task.Delay(0); +#endif + } + + public static Func ToAsync(Action action) + { + return () => + { + action(); return CompletedTask; + }; + } + public static Func ToAsync(Action action) + { + return x => + { + action(x); return CompletedTask; + }; } } } diff --git a/src/Discord.Net/API/Converters/LongCollectionConverter.cs b/src/Discord.Net/API/Converters/LongStringCollectionConverter.cs similarity index 94% rename from src/Discord.Net/API/Converters/LongCollectionConverter.cs rename to src/Discord.Net/API/Converters/LongStringCollectionConverter.cs index 0d7d7bc83..253d964fe 100644 --- a/src/Discord.Net/API/Converters/LongCollectionConverter.cs +++ b/src/Discord.Net/API/Converters/LongStringCollectionConverter.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; namespace Discord.API.Converters { - public class EnumerableLongStringConverter : JsonConverter + public class LongStringEnumerableConverter : JsonConverter { public override bool CanConvert(Type objectType) { @@ -38,7 +38,7 @@ namespace Discord.API.Converters } } - internal class LongArrayStringConverter : JsonConverter + internal class LongStringArrayConverter : JsonConverter { public override bool CanConvert(Type objectType) { diff --git a/src/Discord.Net/API/Enums/AvatarImageType.cs b/src/Discord.Net/API/Enums/ImageType.cs similarity index 100% rename from src/Discord.Net/API/Enums/AvatarImageType.cs rename to src/Discord.Net/API/Enums/ImageType.cs diff --git a/src/Discord.Net/API/Messages/Members.cs b/src/Discord.Net/API/Messages/Members.cs index d82591b71..75b9f5267 100644 --- a/src/Discord.Net/API/Messages/Members.cs +++ b/src/Discord.Net/API/Messages/Members.cs @@ -36,7 +36,7 @@ namespace Discord.API [JsonProperty("joined_at")] public DateTime? JoinedAt; [JsonProperty("roles")] - [JsonConverter(typeof(LongArrayStringConverter))] + [JsonConverter(typeof(LongStringArrayConverter))] public long[] Roles; } public class ExtendedMemberInfo : MemberInfo @@ -53,7 +53,7 @@ namespace Discord.API [JsonProperty("status")] public string Status; [JsonProperty("roles")] //TODO: Might be temporary - [JsonConverter(typeof(LongArrayStringConverter))] + [JsonConverter(typeof(LongStringArrayConverter))] public long[] Roles; } public class VoiceMemberInfo : MemberReference @@ -88,7 +88,7 @@ namespace Discord.API [JsonConverter(typeof(NullableLongStringConverter))] public long? ChannelId; [JsonProperty("roles", NullValueHandling = NullValueHandling.Ignore)] - [JsonConverter(typeof(EnumerableLongStringConverter))] + [JsonConverter(typeof(LongStringEnumerableConverter))] public IEnumerable Roles; } diff --git a/src/Discord.Net/API/Messages/Messages.cs b/src/Discord.Net/API/Messages/Messages.cs index 928847aa5..6dd4a54ef 100644 --- a/src/Discord.Net/API/Messages/Messages.cs +++ b/src/Discord.Net/API/Messages/Messages.cs @@ -108,7 +108,7 @@ namespace Discord.API [JsonProperty("content")] public string Content; [JsonProperty("mentions")] - [JsonConverter(typeof(EnumerableLongStringConverter))] + [JsonConverter(typeof(LongStringEnumerableConverter))] public IEnumerable Mentions; [JsonProperty("nonce", NullValueHandling = NullValueHandling.Ignore)] public string Nonce; @@ -123,7 +123,7 @@ namespace Discord.API [JsonProperty("content", NullValueHandling = NullValueHandling.Ignore)] public string Content; [JsonProperty("mentions", NullValueHandling = NullValueHandling.Ignore)] - [JsonConverter(typeof(EnumerableLongStringConverter))] + [JsonConverter(typeof(LongStringEnumerableConverter))] public IEnumerable Mentions; } public sealed class EditMessageResponse : MessageInfo { } diff --git a/src/Discord.Net/DiscordClient.Messages.cs b/src/Discord.Net/DiscordClient.Messages.cs index a21fcbddc..d109c8f17 100644 --- a/src/Discord.Net/DiscordClient.Messages.cs +++ b/src/Discord.Net/DiscordClient.Messages.cs @@ -380,10 +380,11 @@ namespace Discord else { await _api.EditMessage( - msg.Id, - msg.Channel.Id, - queuedMessage.Text, - queuedMessage.MentionedUsers); + msg.Id, + msg.Channel.Id, + queuedMessage.Text, + queuedMessage.MentionedUsers) + .ConfigureAwait(false); } } catch (WebException) { break; } diff --git a/src/Discord.Net/DiscordClient.Roles.cs b/src/Discord.Net/DiscordClient.Roles.cs index 7ecb66b9a..7554d288f 100644 --- a/src/Discord.Net/DiscordClient.Roles.cs +++ b/src/Discord.Net/DiscordClient.Roles.cs @@ -140,7 +140,7 @@ namespace Discord if (role == null) throw new ArgumentNullException(nameof(role)); CheckReady(); - try { await _api.DeleteRole(role.Server.Id, role.Id); } + try { await _api.DeleteRole(role.Server.Id, role.Id).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { } } diff --git a/src/Discord.Net/DiscordClient.Servers.cs b/src/Discord.Net/DiscordClient.Servers.cs index 9a0d34699..160cb3509 100644 --- a/src/Discord.Net/DiscordClient.Servers.cs +++ b/src/Discord.Net/DiscordClient.Servers.cs @@ -118,7 +118,8 @@ namespace Discord { CheckReady(); - return (await _api.GetVoiceRegions()).Select(x => new Region { Id = x.Id, Name = x.Name, Hostname = x.Hostname, Port = x.Port }); + var regions = await _api.GetVoiceRegions().ConfigureAwait(false); + return regions.Select(x => new Region { Id = x.Id, Name = x.Name, Hostname = x.Hostname, Port = x.Port }); } } } \ No newline at end of file diff --git a/src/Discord.Net/DiscordClient.Users.cs b/src/Discord.Net/DiscordClient.Users.cs index 4373ba91e..d221028a3 100644 --- a/src/Discord.Net/DiscordClient.Users.cs +++ b/src/Discord.Net/DiscordClient.Users.cs @@ -254,7 +254,7 @@ namespace Discord if (days <= 0) throw new ArgumentOutOfRangeException(nameof(days)); CheckReady(); - var response = await _api.PruneUsers(server.Id, days, simulate); + var response = await _api.PruneUsers(server.Id, days, simulate).ConfigureAwait(false); return response.Pruned ?? 0; } @@ -275,11 +275,11 @@ namespace Discord await _api.EditProfile(currentPassword: currentPassword, username: username ?? _privateUser?.Name, email: email ?? _privateUser?.Global.Email, password: password, - avatar: avatar, avatarType: avatarType, existingAvatar: _privateUser?.AvatarId); + avatar: avatar, avatarType: avatarType, existingAvatar: _privateUser?.AvatarId).ConfigureAwait(false); if (password != null) { - var loginResponse = await _api.Login(_privateUser.Global.Email, password); + var loginResponse = await _api.Login(_privateUser.Global.Email, password).ConfigureAwait(false); _api.Token = loginResponse.Token; } } diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index f7797dd57..5fb6450d4 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -11,15 +11,15 @@ using System.Threading.Tasks; namespace Discord { - public enum DiscordClientState : byte - { - Disconnected, - Connecting, - Connected, - Disconnecting - } + public enum ConnectionState : byte + { + Disconnected, + Connecting, + Connected, + Disconnecting + } - public class DisconnectedEventArgs : EventArgs + public class DisconnectedEventArgs : EventArgs { public readonly bool WasUnexpected; public readonly Exception Error; @@ -51,25 +51,24 @@ namespace Discord { public static readonly string Version = typeof(DiscordClient).GetTypeInfo().Assembly.GetName().Version.ToString(3); - private readonly ManualResetEvent _disconnectedEvent; + private readonly LogService _log; + private readonly Logger _logger, _restLogger, _cacheLogger; + private readonly Dictionary _singletons; + private readonly object _cacheLock; + private readonly Semaphore _lock; + private readonly ManualResetEvent _disconnectedEvent; private readonly ManualResetEventSlim _connectedEvent; - private readonly Dictionary _singletons; - private readonly LogService _log; - private readonly object _cacheLock; - private Logger _logger, _restLogger, _cacheLogger; + private readonly TaskManager _taskManager; private bool _sentInitialLog; private UserStatus _status; private int? _gameId; - private Task _runTask; - private ExceptionDispatchInfo _disconnectReason; - private bool _wasDisconnectUnexpected; /// Returns the configuration object used to make this client. Note that this object cannot be edited directly - to change the configuration of this client, use the DiscordClient(DiscordClientConfig config) constructor. public DiscordConfig Config => _config; private readonly DiscordConfig _config; /// Returns the current connection state of this client. - public DiscordClientState State => (DiscordClientState)_state; + public ConnectionState State => (ConnectionState)_state; private int _state; /// Gives direct access to the underlying DiscordAPIClient. This can be used to modify objects not in cache. @@ -82,12 +81,17 @@ namespace Discord public string GatewayUrl => _gateway; private string _gateway; - + public string Token => _token; private string _token; - /// Returns a cancellation token that triggers when the client is manually disconnected. - public CancellationToken CancelToken => _cancelToken; + public string SessionId => _sessionId; + private string _sessionId; + + public long? UserId => _privateUser?.Id; + + /// Returns a cancellation token that triggers when the client is manually disconnected. + public CancellationToken CancelToken => _cancelToken; private CancellationTokenSource _cancelTokenSource; private CancellationToken _cancelToken; @@ -111,35 +115,37 @@ namespace Discord _config.Lock(); _nonceRand = new Random(); - _state = (int)DiscordClientState.Disconnected; + _state = (int)ConnectionState.Disconnected; _status = UserStatus.Online; //Services _singletons = new Dictionary(); _log = AddService(new LogService()); - CreateMainLogger(); + _logger = CreateMainLogger(); - //Async + //Async + _lock = new Semaphore(1, 1); + _taskManager = new TaskManager(Cleanup); _cancelToken = new CancellationToken(true); _disconnectedEvent = new ManualResetEvent(true); _connectedEvent = new ManualResetEventSlim(false); - - //Cache - _cacheLock = new object(); + + //Cache + _cacheLock = new object(); _channels = new Channels(this, _cacheLock); _users = new Users(this, _cacheLock); _messages = new Messages(this, _cacheLock, Config.MessageCacheSize > 0); _roles = new Roles(this, _cacheLock); _servers = new Servers(this, _cacheLock); _globalUsers = new GlobalUsers(this, _cacheLock); - CreateCacheLogger(); + _cacheLogger = CreateCacheLogger(); - //Networking - _webSocket = new GatewaySocket(_config, _log.CreateLogger("WebSocket")); + //Networking + _webSocket = new GatewaySocket(this, _log.CreateLogger("WebSocket")); var settings = new JsonSerializerSettings(); _webSocket.Connected += (s, e) => { - if (_state == (int)DiscordClientState.Connecting) + if (_state == (int)ConnectionState.Connecting) EndConnect(); }; _webSocket.Disconnected += (s, e) => @@ -157,88 +163,94 @@ namespace Discord _api.CancelToken = _cancelToken; await SendStatus().ConfigureAwait(false); }; - CreateRestLogger(); + _restLogger = CreateRestLogger(); //Import/Export _messageImporter = new JsonSerializer(); _messageImporter.ContractResolver = new Message.ImportResolver(); } - private void CreateMainLogger() - { - _logger = _log.CreateLogger("Client"); - if (_logger.Level >= LogSeverity.Info) - { - JoinedServer += (s, e) => _logger.Info($"Server Created: {e.Server?.Name ?? "[Private]"}"); - LeftServer += (s, e) => _logger.Info($"Server Destroyed: {e.Server?.Name ?? "[Private]"}"); - ServerUpdated += (s, e) => _logger.Info($"Server Updated: {e.Server?.Name ?? "[Private]"}"); - ServerAvailable += (s, e) => _logger.Info($"Server Available: {e.Server?.Name ?? "[Private]"}"); - ServerUnavailable += (s, e) => _logger.Info($"Server Unavailable: {e.Server?.Name ?? "[Private]"}"); - ChannelCreated += (s, e) => _logger.Info($"Channel Created: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); - ChannelDestroyed += (s, e) => _logger.Info($"Channel Destroyed: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); - ChannelUpdated += (s, e) => _logger.Info($"Channel Updated: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); - MessageReceived += (s, e) => _logger.Info($"Message Received: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); - MessageDeleted += (s, e) => _logger.Info($"Message Deleted: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); - MessageUpdated += (s, e) => _logger.Info($"Message Update: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); - RoleCreated += (s, e) => _logger.Info($"Role Created: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); - RoleUpdated += (s, e) => _logger.Info($"Role Updated: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); - RoleDeleted += (s, e) => _logger.Info($"Role Deleted: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); - UserBanned += (s, e) => _logger.Info($"Banned User: {e.Server?.Name ?? "[Private]" }/{e.UserId}"); - UserUnbanned += (s, e) => _logger.Info($"Unbanned User: {e.Server?.Name ?? "[Private]"}/{e.UserId}"); - UserJoined += (s, e) => _logger.Info($"User Joined: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); - UserLeft += (s, e) => _logger.Info($"User Left: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); - UserUpdated += (s, e) => _logger.Info($"User Updated: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); - UserVoiceStateUpdated += (s, e) => _logger.Info($"Voice Updated: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); - ProfileUpdated += (s, e) => _logger.Info("Profile Updated"); + private Logger CreateMainLogger() + { + Logger logger = null; + if (_log.Level >= LogSeverity.Info) + { + logger = _log.CreateLogger("Client"); + JoinedServer += (s, e) => logger.Info($"Server Created: {e.Server?.Name ?? "[Private]"}"); + LeftServer += (s, e) => logger.Info($"Server Destroyed: {e.Server?.Name ?? "[Private]"}"); + ServerUpdated += (s, e) => logger.Info($"Server Updated: {e.Server?.Name ?? "[Private]"}"); + ServerAvailable += (s, e) => logger.Info($"Server Available: {e.Server?.Name ?? "[Private]"}"); + ServerUnavailable += (s, e) => logger.Info($"Server Unavailable: {e.Server?.Name ?? "[Private]"}"); + ChannelCreated += (s, e) => logger.Info($"Channel Created: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); + ChannelDestroyed += (s, e) => logger.Info($"Channel Destroyed: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); + ChannelUpdated += (s, e) => logger.Info($"Channel Updated: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}"); + MessageReceived += (s, e) => logger.Info($"Message Received: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); + MessageDeleted += (s, e) => logger.Info($"Message Deleted: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); + MessageUpdated += (s, e) => logger.Info($"Message Update: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); + RoleCreated += (s, e) => logger.Info($"Role Created: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); + RoleUpdated += (s, e) => logger.Info($"Role Updated: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); + RoleDeleted += (s, e) => logger.Info($"Role Deleted: {e.Server?.Name ?? "[Private]"}/{e.Role?.Name}"); + UserBanned += (s, e) => logger.Info($"Banned User: {e.Server?.Name ?? "[Private]" }/{e.UserId}"); + UserUnbanned += (s, e) => logger.Info($"Unbanned User: {e.Server?.Name ?? "[Private]"}/{e.UserId}"); + UserJoined += (s, e) => logger.Info($"User Joined: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); + UserLeft += (s, e) => logger.Info($"User Left: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); + UserUpdated += (s, e) => logger.Info($"User Updated: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); + UserVoiceStateUpdated += (s, e) => logger.Info($"Voice Updated: {e.Server?.Name ?? "[Private]"}/{e.User.Name}"); + ProfileUpdated += (s, e) => logger.Info("Profile Updated"); } if (_log.Level >= LogSeverity.Verbose) { - UserIsTypingUpdated += (s, e) => _logger.Verbose($"Is Typing: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.User?.Name}"); - MessageAcknowledged += (s, e) => _logger.Verbose($"Ack Message: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); - MessageSent += (s, e) => _logger.Verbose($"Sent Message: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); - UserPresenceUpdated += (s, e) => _logger.Verbose($"Presence Updated: {e.Server?.Name ?? "[Private]"}/{e.User?.Name}"); + UserIsTypingUpdated += (s, e) => logger.Verbose($"Is Typing: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.User?.Name}"); + MessageAcknowledged += (s, e) => logger.Verbose($"Ack Message: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); + MessageSent += (s, e) => logger.Verbose($"Sent Message: {e.Server?.Name ?? "[Private]"}/{e.Channel?.Name}/{e.Message?.Id}"); + UserPresenceUpdated += (s, e) => logger.Verbose($"Presence Updated: {e.Server?.Name ?? "[Private]"}/{e.User?.Name}"); } + return logger; } - private void CreateRestLogger() - { - _restLogger = _log.CreateLogger("Rest"); - if (_log.Level >= LogSeverity.Verbose) - { - _api.RestClient.OnRequest += (s, e) => + private Logger CreateRestLogger() + { + Logger logger = null; + if (_log.Level >= LogSeverity.Verbose) + { + logger = _log.CreateLogger("Rest"); + _api.RestClient.OnRequest += (s, e) => { if (e.Payload != null) - _restLogger.Verbose( $"{e.Method} {e.Path}: {Math.Round(e.ElapsedMilliseconds, 2)} ms ({e.Payload})"); + logger.Verbose( $"{e.Method} {e.Path}: {Math.Round(e.ElapsedMilliseconds, 2)} ms ({e.Payload})"); else - _restLogger.Verbose( $"{e.Method} {e.Path}: {Math.Round(e.ElapsedMilliseconds, 2)} ms"); + logger.Verbose( $"{e.Method} {e.Path}: {Math.Round(e.ElapsedMilliseconds, 2)} ms"); }; - } - } - private void CreateCacheLogger() - { - _cacheLogger = _log.CreateLogger("Cache"); - if (_log.Level >= LogSeverity.Debug) - { - _channels.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created Channel {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _channels.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed Channel {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _channels.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Channels"); - _users.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created User {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _users.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed User {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _users.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Users"); - _messages.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/{e.Item.Id}"); - _messages.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/{e.Item.Id}"); - _messages.ItemRemapped += (s, e) => _cacheLogger.Debug( $"Remapped Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/[{e.OldId} -> {e.NewId}]"); - _messages.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Messages"); - _roles.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created Role {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _roles.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed Role {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); - _roles.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Roles"); - _servers.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created Server {e.Item.Id}"); - _servers.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed Server {e.Item.Id}"); - _servers.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Servers"); - _globalUsers.ItemCreated += (s, e) => _cacheLogger.Debug( $"Created User {e.Item.Id}"); - _globalUsers.ItemDestroyed += (s, e) => _cacheLogger.Debug( $"Destroyed User {e.Item.Id}"); - _globalUsers.Cleared += (s, e) => _cacheLogger.Debug( $"Cleared Users"); - } - } + } + return logger; + } + private Logger CreateCacheLogger() + { + Logger logger = null; + if (_log.Level >= LogSeverity.Debug) + { + logger = _log.CreateLogger("Cache"); + _channels.ItemCreated += (s, e) => logger.Debug( $"Created Channel {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _channels.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed Channel {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _channels.Cleared += (s, e) => logger.Debug( $"Cleared Channels"); + _users.ItemCreated += (s, e) => logger.Debug( $"Created User {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _users.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed User {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _users.Cleared += (s, e) => logger.Debug( $"Cleared Users"); + _messages.ItemCreated += (s, e) => logger.Debug( $"Created Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/{e.Item.Id}"); + _messages.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/{e.Item.Id}"); + _messages.ItemRemapped += (s, e) => logger.Debug( $"Remapped Message {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Channel.Id}/[{e.OldId} -> {e.NewId}]"); + _messages.Cleared += (s, e) => logger.Debug( $"Cleared Messages"); + _roles.ItemCreated += (s, e) => logger.Debug( $"Created Role {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _roles.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed Role {IdConvert.ToString(e.Item.Server?.Id) ?? "[Private]"}/{e.Item.Id}"); + _roles.Cleared += (s, e) => logger.Debug( $"Cleared Roles"); + _servers.ItemCreated += (s, e) => logger.Debug( $"Created Server {e.Item.Id}"); + _servers.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed Server {e.Item.Id}"); + _servers.Cleared += (s, e) => logger.Debug( $"Cleared Servers"); + _globalUsers.ItemCreated += (s, e) => logger.Debug( $"Created User {e.Item.Id}"); + _globalUsers.ItemDestroyed += (s, e) => logger.Debug( $"Destroyed User {e.Item.Id}"); + _globalUsers.Cleared += (s, e) => logger.Debug( $"Cleared Users"); + } + return logger; + } /// Connects to the Discord server with the provided email and password. /// Returns a token for future connections. @@ -247,17 +259,16 @@ namespace Discord if (!_sentInitialLog) SendInitialLog(); - if (State != DiscordClientState.Disconnected) + if (State != ConnectionState.Disconnected) await Disconnect().ConfigureAwait(false); - var response = await _api.Login(email, password) - .ConfigureAwait(false); + var response = await _api.Login(email, password).ConfigureAwait(false); _token = response.Token; _api.Token = response.Token; if (_config.LogLevel >= LogSeverity.Verbose) _logger.Verbose( "Login successful, got token."); - await BeginConnect(); + await BeginConnect().ConfigureAwait(false); return response.Token; } /// Connects to the Discord server with the provided token. @@ -266,48 +277,62 @@ namespace Discord if (!_sentInitialLog) SendInitialLog(); - if (State != (int)DiscordClientState.Disconnected) + if (State != (int)ConnectionState.Disconnected) await Disconnect().ConfigureAwait(false); _token = token; _api.Token = token; - await BeginConnect(); + await BeginConnect().ConfigureAwait(false); } private async Task BeginConnect() { try { - _state = (int)DiscordClientState.Connecting; - - var gatewayResponse = await _api.Gateway().ConfigureAwait(false); - string gateway = gatewayResponse.Url; - if (_config.LogLevel >= LogSeverity.Verbose) - _logger.Verbose( $"Websocket endpoint: {gateway}"); - - _disconnectedEvent.Reset(); - - _gateway = gateway; - - _cancelTokenSource = new CancellationTokenSource(); - _cancelToken = _cancelTokenSource.Token; - - _webSocket.Host = gateway; - _webSocket.ParentCancelToken = _cancelToken; - await _webSocket.Connect(_token).ConfigureAwait(false); - - _runTask = RunTasks(); - + _lock.WaitOne(); try { - //Cancel if either Disconnect is called, data socket errors or timeout is reached - var cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken, _webSocket.CancelToken).Token; - _connectedEvent.Wait(cancelToken); + await _taskManager.Stop().ConfigureAwait(false); + _state = (int)ConnectionState.Connecting; + + var gatewayResponse = await _api.Gateway().ConfigureAwait(false); + string gateway = gatewayResponse.Url; + if (_config.LogLevel >= LogSeverity.Verbose) + _logger.Verbose( $"Websocket endpoint: {gateway}"); + + _disconnectedEvent.Reset(); + + _gateway = gateway; + + _cancelTokenSource = new CancellationTokenSource(); + _cancelToken = _cancelTokenSource.Token; + + _webSocket.Host = gateway; + _webSocket.ParentCancelToken = _cancelToken; + await _webSocket.Connect().ConfigureAwait(false); + + List tasks = new List(); + tasks.Add(_cancelToken.Wait()); + if (_config.UseMessageQueue) + tasks.Add(MessageQueueAsync()); + + await _taskManager.Start(tasks, _cancelTokenSource).ConfigureAwait(false); + + try + { + //Cancel if either Disconnect is called, data socket errors or timeout is reached + var cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken, _webSocket.CancelToken).Token; + _connectedEvent.Wait(cancelToken); + } + catch (OperationCanceledException) + { + _webSocket.ThrowError(); //Throws data socket's internal error if any occured + throw; + } } - catch (OperationCanceledException) + finally { - _webSocket.ThrowError(); //Throws data socket's internal error if any occured - throw; + _lock.Release(); } } catch @@ -318,88 +343,15 @@ namespace Discord } private void EndConnect() { - _state = (int)DiscordClientState.Connected; + _state = (int)ConnectionState.Connected; _connectedEvent.Set(); RaiseConnected(); } - /// Disconnects from the Discord server, canceling any pending requests. - public Task Disconnect() => SignalDisconnect(new Exception("Disconnect was requested by user."), isUnexpected: false); - private async Task SignalDisconnect(Exception ex = null, bool isUnexpected = true, bool wait = false) - { - int oldState; - bool hasWriterLock; - - //If in either connecting or connected state, get a lock by being the first to switch to disconnecting - oldState = Interlocked.CompareExchange(ref _state, (int)DiscordClientState.Disconnecting, (int)DiscordClientState.Connecting); - if (oldState == (int)DiscordClientState.Disconnected) return; //Already disconnected - hasWriterLock = oldState == (int)DiscordClientState.Connecting; //Caused state change - if (!hasWriterLock) - { - oldState = Interlocked.CompareExchange(ref _state, (int)DiscordClientState.Disconnecting, (int)DiscordClientState.Connected); - if (oldState == (int)DiscordClientState.Disconnected) return; //Already disconnected - hasWriterLock = oldState == (int)DiscordClientState.Connected; //Caused state change - } - - if (hasWriterLock) - { - _wasDisconnectUnexpected = isUnexpected; - _disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null; - - _cancelTokenSource.Cancel(); - /*if (_disconnectState == DiscordClientState.Connecting) //_runTask was never made - await Cleanup().ConfigureAwait(false);*/ - } - - if (wait) - { - Task task = _runTask; - if (_runTask != null) - await task.ConfigureAwait(false); - } - } - - private async Task RunTasks() - { - List tasks = new List(); - tasks.Add(_cancelToken.Wait()); - if (_config.UseMessageQueue) - tasks.Add(MessageQueueAsync()); - - Task[] tasksArray = tasks.ToArray(); - Task firstTask = Task.WhenAny(tasksArray); - Task allTasks = Task.WhenAll(tasksArray); - - //Wait until the first task ends/errors and capture the error - try { await firstTask.ConfigureAwait(false); } - catch (Exception ex) { await SignalDisconnect(ex: ex, wait: true).ConfigureAwait(false); } - - //Ensure all other tasks are signaled to end. - await SignalDisconnect(wait: true).ConfigureAwait(false); - - //Wait for the remaining tasks to complete - try { await allTasks.ConfigureAwait(false); } - catch { } - - //Start cleanup - var wasDisconnectUnexpected = _wasDisconnectUnexpected; - _wasDisconnectUnexpected = false; - - await _webSocket.SignalDisconnect().ConfigureAwait(false); - - _privateUser = null; - _gateway = null; - _token = null; - - if (!wasDisconnectUnexpected) - { - _state = (int)DiscordClientState.Disconnected; - _disconnectedEvent.Set(); - } - _connectedEvent.Reset(); - _runTask = null; - } - private async Task Stop() + /// Disconnects from the Discord server, canceling any pending requests. + public Task Disconnect() => _taskManager.Stop(); + + private async Task Cleanup() { if (Config.UseMessageQueue) { @@ -417,7 +369,13 @@ namespace Discord _globalUsers.Clear(); _privateUser = null; - } + _gateway = null; + _token = null; + + _state = (int)ConnectionState.Disconnected; + _disconnectedEvent.Set(); + _connectedEvent.Reset(); + } private void OnReceivedEvent(WebSocketEventEventArgs e) { @@ -429,7 +387,8 @@ namespace Discord case "READY": //Resync { var data = e.Payload.ToObject(_webSocket.Serializer); - _privateUser = _users.GetOrAdd(data.User.Id, null); + _sessionId = data.SessionId; + _privateUser = _users.GetOrAdd(data.User.Id, null); _privateUser.Update(data.User); _privateUser.Global.Update(data.User); foreach (var model in data.Guilds) @@ -863,11 +822,11 @@ namespace Discord { switch (_state) { - case (int)DiscordClientState.Disconnecting: + case (int)ConnectionState.Disconnecting: throw new InvalidOperationException("The client is disconnecting."); - case (int)DiscordClientState.Disconnected: + case (int)ConnectionState.Disconnected: throw new InvalidOperationException("The client is not connected to Discord"); - case (int)DiscordClientState.Connecting: + case (int)ConnectionState.Connecting: throw new InvalidOperationException("The client is connecting."); } } diff --git a/src/Discord.Net/Helpers/Mention.cs b/src/Discord.Net/Helpers/Mention.cs index 12037f4ce..a3bdfd5df 100644 --- a/src/Discord.Net/Helpers/Mention.cs +++ b/src/Discord.Net/Helpers/Mention.cs @@ -23,7 +23,7 @@ namespace Discord public static string Channel(Channel channel) => $"<#{channel.Id}>"; /// Returns the string used to create a mention to everyone in a channel. - [Obsolete("Use Role.Mention instead")] + [Obsolete("Use Server.EveryoneRole.Mention instead")] public static string Everyone() => $"@everyone"; diff --git a/src/Discord.Net/Helpers/TaskManager.cs b/src/Discord.Net/Helpers/TaskManager.cs new file mode 100644 index 000000000..d21c4e207 --- /dev/null +++ b/src/Discord.Net/Helpers/TaskManager.cs @@ -0,0 +1,148 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Discord +{ + /// Helper class used to manage several tasks and keep them in sync. If any single task errors or stops, all other tasks will also be stopped. + public class TaskManager + { + private readonly object _lock; + private readonly Func _stopAction; + + private CancellationTokenSource _cancelSource; + private Task _task; + + public bool WasUnexpected => _wasUnexpected; + private bool _wasUnexpected; + + public Exception Exception => _stopReason.SourceException; + private ExceptionDispatchInfo _stopReason; + + public TaskManager() + { + _lock = new object(); + } + public TaskManager(Action stopAction) + : this() + { + _stopAction = TaskHelper.ToAsync(stopAction); + } + public TaskManager(Func stopAction) + : this() + { + _stopAction = stopAction; + } + + public async Task Start(IEnumerable tasks, CancellationTokenSource cancelSource) + { + while (true) + { + var task = _task; + if (task != null) + await Stop().ConfigureAwait(false); + + lock (_lock) + { + _cancelSource = new CancellationTokenSource(); + + if (_task != null) + continue; //Another thread sneaked in and started this manager before we got a lock, loop and try again + + _stopReason = null; + _wasUnexpected = false; + + Task[] tasksArray = tasks.ToArray(); + Task anyTask = Task.WhenAny(tasksArray); + Task allTasks = Task.WhenAll(tasksArray); + + _task = Task.Run(async () => + { + //Wait for the first task to stop or error + Task firstTask = await anyTask.ConfigureAwait(false); + + //Signal the rest of the tasks to stop + if (firstTask.Exception != null) + SignalError(firstTask.Exception.GetBaseException(), true); + else + SignalStop(); + + //Wait for the other tasks; + await allTasks.ConfigureAwait(false); + + //Run the cleanup function within our lock + await _stopAction().ConfigureAwait(false); + _task = null; + }); + return; + } + } + } + + public void SignalStop() + { + lock (_lock) + { + if (_task == null) return; //Are we running? + if (_cancelSource.IsCancellationRequested) return; + + _cancelSource.Cancel(); + } + } + public Task Stop() + { + Task task; + lock (_lock) + { + //Cache the task so we still have something to await if Cleanup is run really quickly + task = _task; + if (task == null) return TaskHelper.CompletedTask; //Are we running? + if (_cancelSource.IsCancellationRequested) return task; + + _cancelSource.Cancel(); + } + return task; + } + + public void SignalError(Exception ex, bool isUnexpected = true) + { + lock (_lock) + { + if (_task == null) return; //Are we running? + + _cancelSource.Cancel(); + _stopReason = ExceptionDispatchInfo.Capture(ex); + _wasUnexpected = isUnexpected; + } + } + public Task Error(Exception ex, bool isUnexpected = true) + { + Task task; + lock (_lock) + { + //Cache the task so we still have something to await if Cleanup is run really quickly + task = _task; + if (task == null) return TaskHelper.CompletedTask; //Are we running? + if (_cancelSource.IsCancellationRequested) return task; + + _cancelSource.Cancel(); + _stopReason = ExceptionDispatchInfo.Capture(ex); + _wasUnexpected = isUnexpected; + } + return task; + } + + /// Throws an exception if one was captured. + public void Throw() + { + lock (_lock) + { + if (_stopReason != null) + _stopReason.Throw(); + } + } + } +} diff --git a/src/Discord.Net/Net/Rest/SharpRestEngine.cs b/src/Discord.Net/Net/Rest/SharpRestEngine.cs index e5369dce7..af95f9fea 100644 --- a/src/Discord.Net/Net/Rest/SharpRestEngine.cs +++ b/src/Discord.Net/Net/Rest/SharpRestEngine.cs @@ -77,7 +77,7 @@ namespace Discord.Net.Rest .FirstOrDefault(x => x.Name.Equals("Retry-After", StringComparison.OrdinalIgnoreCase)); if (retryAfter != null) { - await Task.Delay((int)retryAfter.Value); + await Task.Delay((int)retryAfter.Value).ConfigureAwait(false); continue; } throw new HttpException(response.StatusCode); diff --git a/src/Discord.Net/Net/WebSockets/GatewaySocket.cs b/src/Discord.Net/Net/WebSockets/GatewaySocket.cs index d2a3ffe67..d9e7a64e1 100644 --- a/src/Discord.Net/Net/WebSockets/GatewaySocket.cs +++ b/src/Discord.Net/Net/WebSockets/GatewaySocket.cs @@ -2,6 +2,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; +using System.Collections.Generic; using System.Threading.Tasks; namespace Discord.Net.WebSockets @@ -11,14 +12,11 @@ namespace Discord.Net.WebSockets public int LastSequence => _lastSeq; private int _lastSeq; - public string Token => _token; - private string _token; - public string SessionId => _sessionId; private string _sessionId; - public GatewaySocket(DiscordConfig config, Logger logger) - : base(config, logger) + public GatewaySocket(DiscordClient client, Logger logger) + : base(client, logger) { Disconnected += async (s, e) => { @@ -27,18 +25,13 @@ namespace Discord.Net.WebSockets }; } - public async Task Connect(string token) + public async Task Connect() { - await SignalDisconnect(wait: true).ConfigureAwait(false); - - _token = token; await BeginConnect().ConfigureAwait(false); - SendIdentify(token); + SendIdentify(); } - private async Task Redirect(string server) + private async Task Redirect() { - await SignalDisconnect(wait: true).ConfigureAwait(false); - await BeginConnect().ConfigureAwait(false); SendResume(); } @@ -47,12 +40,12 @@ namespace Discord.Net.WebSockets try { var cancelToken = ParentCancelToken.Value; - await Task.Delay(_config.ReconnectDelay, cancelToken).ConfigureAwait(false); + await Task.Delay(_client.Config.ReconnectDelay, cancelToken).ConfigureAwait(false); while (!cancelToken.IsCancellationRequested) { try { - await Connect(_token).ConfigureAwait(false); + await Connect().ConfigureAwait(false); break; } catch (OperationCanceledException) { throw; } @@ -60,21 +53,21 @@ namespace Discord.Net.WebSockets { _logger.Log(LogSeverity.Error, $"Reconnect failed", ex); //Net is down? We can keep trying to reconnect until the user runs Disconnect() - await Task.Delay(_config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); + await Task.Delay(_client.Config.FailedReconnectDelay, cancelToken).ConfigureAwait(false); } } } catch (OperationCanceledException) { } } - public Task Disconnect() - { - return SignalDisconnect(wait: true); - } + public Task Disconnect() => TaskManager.Stop(); protected override async Task Run() - { - await RunTasks(); - } + { + List tasks = new List(); + tasks.AddRange(_engine.GetTasks(_cancelToken)); + tasks.Add(HeartbeatAsync(_cancelToken)); + await _taskManager.Start(tasks, _cancelTokenSource).ConfigureAwait(false); + } protected override async Task ProcessMessage(string json) { @@ -102,7 +95,7 @@ namespace Discord.Net.WebSockets } RaiseReceivedDispatch(msg.Type, token); if (msg.Type == "READY" || msg.Type == "RESUMED") - await EndConnect(); //Complete the connect + EndConnect(); //Complete the connect } break; case GatewayOpCodes.Redirect: @@ -113,7 +106,7 @@ namespace Discord.Net.WebSockets Host = payload.Url; if (_logger.Level >= LogSeverity.Info) _logger.Info("Redirected to " + payload.Url); - await Redirect(payload.Url).ConfigureAwait(false); + await Redirect().ConfigureAwait(false); } } break; @@ -124,12 +117,12 @@ namespace Discord.Net.WebSockets } } - public void SendIdentify(string token) + public void SendIdentify() { var msg = new IdentifyCommand(); - msg.Payload.Token = token; + msg.Payload.Token = _client.Token; msg.Payload.Properties["$device"] = "Discord.Net"; - if (_config.UseLargeThreshold) + if (_client.Config.UseLargeThreshold) msg.Payload.LargeThreshold = 100; msg.Payload.Compress = true; QueueMessage(msg); diff --git a/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs b/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs index 41f3328e2..a8e625191 100644 --- a/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs +++ b/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs @@ -46,7 +46,7 @@ namespace Discord.Net.WebSockets _webSocket = new WS4NetWebSocket(host); _webSocket.EnableAutoSendPing = false; _webSocket.NoDelay = true; - _webSocket.Proxy = null; //Disable + _webSocket.Proxy = null; _webSocket.DataReceived += OnWebSocketBinary; _webSocket.MessageReceived += OnWebSocketText; @@ -81,15 +81,15 @@ namespace Discord.Net.WebSockets return TaskHelper.CompletedTask; } - private async void OnWebSocketError(object sender, ErrorEventArgs e) + private void OnWebSocketError(object sender, ErrorEventArgs e) { - await _parent.SignalDisconnect(e.Exception, isUnexpected: true).ConfigureAwait(false); + _parent.TaskManager.SignalError(e.Exception); _waitUntilConnect.Set(); } - private async void OnWebSocketClosed(object sender, EventArgs e) + private void OnWebSocketClosed(object sender, EventArgs e) { var ex = new Exception($"Connection lost or close message received."); - await _parent.SignalDisconnect(ex, isUnexpected: false/*true*/).ConfigureAwait(false); + _parent.TaskManager.SignalError(ex, isUnexpected: true); _waitUntilConnect.Set(); } private void OnWebSocketOpened(object sender, EventArgs e) diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index 38ddbdee2..f932e343b 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -1,52 +1,40 @@ using Newtonsoft.Json; using System; -using System.Collections.Generic; using System.IO; using System.IO.Compression; -using System.Linq; -using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; namespace Discord.Net.WebSockets { - public enum WebSocketState : byte - { - Disconnected, - Connecting, - Connected, - Disconnecting - } - public abstract partial class WebSocket - { - protected readonly IWebSocketEngine _engine; - protected readonly DiscordConfig _config; + { + private readonly Semaphore _lock; + protected readonly IWebSocketEngine _engine; + protected readonly DiscordClient _client; protected readonly ManualResetEventSlim _connectedEvent; - protected ExceptionDispatchInfo _disconnectReason; - protected bool _wasDisconnectUnexpected; - protected WebSocketState _disconnectState; - protected int _heartbeatInterval; private DateTime _lastHeartbeat; - private Task _runTask; public CancellationToken? ParentCancelToken { get; set; } public CancellationToken CancelToken => _cancelToken; - private CancellationTokenSource _cancelTokenSource; + protected CancellationTokenSource _cancelTokenSource; protected CancellationToken _cancelToken; - internal JsonSerializer Serializer => _serializer; + public JsonSerializer Serializer => _serializer; protected JsonSerializer _serializer; - public Logger Logger => _logger; + internal TaskManager TaskManager => _taskManager; + protected readonly TaskManager _taskManager; + + public Logger Logger => _logger; protected readonly Logger _logger; public string Host { get { return _host; } set { _host = value; } } private string _host; - public WebSocketState State => (WebSocketState)_state; + public ConnectionState State => (ConnectionState)_state; protected int _state; public event EventHandler Connected; @@ -66,20 +54,22 @@ namespace Discord.Net.WebSockets Disconnected(this, new DisconnectedEventArgs(wasUnexpected, error)); } - public WebSocket(DiscordConfig config, Logger logger) + public WebSocket(DiscordClient client, Logger logger) { - _config = config; + _client = client; _logger = logger; - - _cancelToken = new CancellationToken(true); + + _lock = new Semaphore(1, 1); + _taskManager = new TaskManager(Cleanup); + _cancelToken = new CancellationToken(true); _connectedEvent = new ManualResetEventSlim(false); #if !DOTNET5_4 - _engine = new WS4NetEngine(this, _config, _logger); + _engine = new WS4NetEngine(this, client.Config, _logger); #else - //_engine = new BuiltInWebSocketEngine(this, _config, _logger); + //_engine = new BuiltInWebSocketEngine(this, client.Config, _logger); #endif - _engine.BinaryMessage += (s, e) => + _engine.BinaryMessage += (s, e) => { using (var compressed = new MemoryStream(e.Data, 2, e.Data.Length - 2)) using (var decompressed = new MemoryStream()) @@ -91,10 +81,7 @@ namespace Discord.Net.WebSockets ProcessMessage(reader.ReadToEnd()).Wait(); } }; - _engine.TextMessage += (s, e) => - { - /*await*/ ProcessMessage(e.Message).Wait(); - }; + _engine.TextMessage += (s, e) => ProcessMessage(e.Message).Wait(); _serializer = new JsonSerializer(); _serializer.DateTimeZoneHandling = DateTimeZoneHandling.Utc; @@ -112,127 +99,59 @@ namespace Discord.Net.WebSockets protected async Task BeginConnect() { - try - { - _state = (int)WebSocketState.Connecting; - - if (ParentCancelToken == null) - throw new InvalidOperationException("Parent cancel token was never set."); - _cancelTokenSource = new CancellationTokenSource(); - _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; - - if (_state != (int)WebSocketState.Connecting) - throw new InvalidOperationException("Socket is in the wrong state."); - - _lastHeartbeat = DateTime.UtcNow; - await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); - - _runTask = Run(); + try + { + _lock.WaitOne(); + try + { + await _taskManager.Stop().ConfigureAwait(false); + _state = (int)ConnectionState.Connecting; + + _cancelTokenSource = new CancellationTokenSource(); + _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; + _lastHeartbeat = DateTime.UtcNow; + + await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); + await Run().ConfigureAwait(false); + } + finally + { + _lock.Release(); + } } catch (Exception ex) { - await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false); - throw; + _taskManager.SignalError(ex, true); } } - protected async Task EndConnect() + protected void EndConnect() { try - { - _state = (int)WebSocketState.Connected; + { + _state = (int)ConnectionState.Connected; _connectedEvent.Set(); RaiseConnected(); } catch (Exception ex) - { - await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false); - throw; - } - } - - protected internal async Task SignalDisconnect(Exception ex = null, bool isUnexpected = false, bool wait = false) - { - //If in either connecting or connected state, get a lock by being the first to switch to disconnecting - int oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); - if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected - bool hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change - if (!hasWriterLock) - { - oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected); - if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected - hasWriterLock = oldState == (int)WebSocketState.Connected; //Caused state change - } - - if (hasWriterLock) { - if (ex != null) - _logger.Log(LogSeverity.Error, "Error", ex); - CaptureError(ex ?? new Exception("Disconnect was requested."), isUnexpected); - _cancelTokenSource.Cancel(); - if (_disconnectState == WebSocketState.Connecting) //_runTask was never made - await Cleanup().ConfigureAwait(false); - } - - if (wait) - { - Task task = _runTask; - if (_runTask != null) - await task.ConfigureAwait(false); - } - } - private void CaptureError(Exception ex, bool isUnexpected) - { - _disconnectReason = ExceptionDispatchInfo.Capture(ex); - _wasDisconnectUnexpected = isUnexpected; + _taskManager.SignalError(ex, true); + } } protected abstract Task Run(); - protected async Task RunTasks(params Task[] tasks) - { - //Get all async tasks - tasks = tasks - .Concat(_engine.GetTasks(_cancelToken)) - .Concat(new Task[] { HeartbeatAsync(_cancelToken) }) - .ToArray(); - - //Create group tasks - Task firstTask = Task.WhenAny(tasks); - Task allTasks = Task.WhenAll(tasks); - - //Wait until the first task ends/errors and capture the error - Exception ex = null; - try { await firstTask.ConfigureAwait(false); } - catch (Exception ex2) { ex = ex2; } - - //Ensure all other tasks are signaled to end. - await SignalDisconnect(ex, ex != null, true).ConfigureAwait(false); - - //Wait for the remaining tasks to complete - try { await allTasks.ConfigureAwait(false); } - catch { } - - //Start cleanup - await Cleanup().ConfigureAwait(false); - } - protected virtual async Task Cleanup() { - var disconnectState = _disconnectState; - _disconnectState = WebSocketState.Disconnected; - var wasDisconnectUnexpected = _wasDisconnectUnexpected; - _wasDisconnectUnexpected = false; - //Dont reset disconnectReason, we may called ThrowError() later - await _engine.Disconnect().ConfigureAwait(false); _cancelTokenSource = null; var oldState = _state; - _state = (int)WebSocketState.Disconnected; - _runTask = null; _connectedEvent.Reset(); - if (disconnectState == WebSocketState.Connected) - RaiseDisconnected(wasDisconnectUnexpected, _disconnectReason?.SourceException); + if (oldState == (int)ConnectionState.Connected) + { + _state = (int)ConnectionState.Disconnected; + RaiseDisconnected(_taskManager.WasUnexpected, _taskManager.Exception); + } } protected virtual Task ProcessMessage(string json) @@ -240,8 +159,7 @@ namespace Discord.Net.WebSockets if (_logger.Level >= LogSeverity.Debug) _logger.Debug( $"In: {json}"); return TaskHelper.CompletedTask; - } - + } protected void QueueMessage(object message) { string json = JsonConvert.SerializeObject(message); @@ -250,7 +168,7 @@ namespace Discord.Net.WebSockets _engine.QueueMessage(json); } - private Task HeartbeatAsync(CancellationToken cancelToken) + protected Task HeartbeatAsync(CancellationToken cancelToken) { return Task.Run(async () => { @@ -258,7 +176,7 @@ namespace Discord.Net.WebSockets { while (!cancelToken.IsCancellationRequested) { - if (_state == (int)WebSocketState.Connected) + if (_state == (int)ConnectionState.Connected) { SendHeartbeat(); await Task.Delay(_heartbeatInterval, cancelToken).ConfigureAwait(false); @@ -269,18 +187,12 @@ namespace Discord.Net.WebSockets } catch (OperationCanceledException) { } }); - } + } + public abstract void SendHeartbeat(); - protected internal void ThrowError() + protected internal void ThrowError() { - if (_wasDisconnectUnexpected) - { - var reason = _disconnectReason; - _disconnectReason = null; - reason.Throw(); - } + _taskManager.Throw(); } - - public abstract void SendHeartbeat(); } } diff --git a/test/Discord.Net.Tests/Tests.cs b/test/Discord.Net.Tests/Tests.cs index 237810778..409d788ab 100644 --- a/test/Discord.Net.Tests/Tests.cs +++ b/test/Discord.Net.Tests/Tests.cs @@ -112,9 +112,9 @@ namespace Discord.Tests public static void Cleanup() { WaitMany( - _hostClient.State == DiscordClientState.Connected ? _hostClient.AllServers.Select(x => _hostClient.LeaveServer(x)) : null, - _targetBot.State == DiscordClientState.Connected ? _targetBot.AllServers.Select(x => _targetBot.LeaveServer(x)) : null, - _observerBot.State == DiscordClientState.Connected ? _observerBot.AllServers.Select(x => _observerBot.LeaveServer(x)) : null); + _hostClient.State == ConnectionState.Connected ? _hostClient.AllServers.Select(x => _hostClient.LeaveServer(x)) : null, + _targetBot.State == ConnectionState.Connected ? _targetBot.AllServers.Select(x => _targetBot.LeaveServer(x)) : null, + _observerBot.State == ConnectionState.Connected ? _observerBot.AllServers.Select(x => _observerBot.LeaveServer(x)) : null); WaitAll( _hostClient.Disconnect(),