diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index bd05fcecf..7a40ad9f3 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -1,11 +1,15 @@ using Discord.API; +using Discord.Net; using Discord.Net.WebSockets; using Newtonsoft.Json; using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Reflection; using System.Runtime.ExceptionServices; +using System.Security.Cryptography; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -255,59 +259,43 @@ namespace Discord /// Returns a token for future connections. public async Task Connect(string email, string password) { - if (!_sentInitialLog) - SendInitialLog(); - - if (State != ConnectionState.Disconnected) - await Disconnect().ConfigureAwait(false); - - var response = await _api.Login(email, password).ConfigureAwait(false); - _token = response.Token; - _api.Token = response.Token; - if (_config.LogLevel >= LogSeverity.Verbose) - _logger.Verbose( "Login successful, got token."); - - await BeginConnect().ConfigureAwait(false); - return response.Token; + if (email == null) throw new ArgumentNullException(email); + if (password == null) throw new ArgumentNullException(password); + + await BeginConnect(email, password, null).ConfigureAwait(false); + return _token; } /// Connects to the Discord server with the provided token. public async Task Connect(string token) { - if (!_sentInitialLog) - SendInitialLog(); - - if (State != ConnectionState.Disconnected) - await Disconnect().ConfigureAwait(false); + if (token == null) throw new ArgumentNullException(token); - _token = token; - _api.Token = token; - await BeginConnect().ConfigureAwait(false); + await BeginConnect(null, null, token).ConfigureAwait(false); } - private async Task BeginConnect() + private async Task BeginConnect(string email, string password, string token = null) { try { _lock.WaitOne(); try { + if (!_sentInitialLog) + SendInitialLog(); + + if (State != ConnectionState.Disconnected) + await Disconnect().ConfigureAwait(false); await _taskManager.Stop().ConfigureAwait(false); _taskManager.ClearException(); _state = ConnectionState.Connecting; - - var gatewayResponse = await _api.Gateway().ConfigureAwait(false); - string gateway = gatewayResponse.Url; - if (_config.LogLevel >= LogSeverity.Verbose) - _logger.Verbose( $"Websocket endpoint: {gateway}"); - _disconnectedEvent.Reset(); - _gateway = gateway; - _cancelTokenSource = new CancellationTokenSource(); _cancelToken = _cancelTokenSource.Token; - _webSocket.Host = gateway; + await Login(email, password, token); + + _webSocket.Host = _gateway; _webSocket.ParentCancelToken = _cancelToken; await _webSocket.Connect().ConfigureAwait(false); @@ -341,7 +329,58 @@ namespace Discord throw; } } - private void EndConnect() + private async Task Login(string email, string password, string token) + { + bool useCache = _config.CacheToken; + while (true) + { + //Get Token + if (token == null) + { + if (useCache) + { + Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(password, + new byte[] { 0x5A, 0x2A, 0xF8, 0xCF, 0x78, 0xD3, 0x7D, 0x0D }); + byte[] key = deriveBytes.GetBytes(16); + + string tokenPath = GetTokenCachePath(email); + token = LoadToken(tokenPath, key); + if (token == null) + { + var response = await _api.Login(email, password).ConfigureAwait(false); + token = response.Token; + SaveToken(tokenPath, key, token); + useCache = false; + } + } + else + { + var response = await _api.Login(email, password).ConfigureAwait(false); + token = response.Token; + } + } + _token = token; + _api.Token = token; + + //Get gateway and check token + try + { + var gatewayResponse = await _api.Gateway().ConfigureAwait(false); + var gateway = gatewayResponse.Url; + _gateway = gateway; + if (_config.LogLevel >= LogSeverity.Verbose) + _logger.Verbose($"Login successful, gateway: {gateway}"); + } + catch (HttpException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Unauthorized && useCache) + { + useCache = false; //Cached token is bad, retry without cache + token = null; + continue; + } + break; + } + } + private void EndConnect() { _state = ConnectionState.Connected; _connectedEvent.Set(); @@ -839,5 +878,71 @@ namespace Discord messageCount = _messages.Count; roleCount = _roles.Count; } + + private string GetTokenCachePath(string email) + { + using (var md5 = MD5.Create()) + { + byte[] data = md5.ComputeHash(Encoding.UTF8.GetBytes(email.ToLowerInvariant())); + StringBuilder filenameBuilder = new StringBuilder(); + for (int i = 0; i < data.Length; i++) + filenameBuilder.Append(data[i].ToString("x2")); + return Path.Combine(Path.GetTempPath(), _config.AppName ?? "Discord.Net", filenameBuilder.ToString()); + } + } + private string LoadToken(string path, byte[] key) + { + if (File.Exists(path)) + { + try + { + using (var fileStream = File.Open(path, FileMode.Open)) + using (var aes = Aes.Create()) + { + byte[] iv = new byte[aes.BlockSize / 8]; + fileStream.Read(iv, 0, iv.Length); + aes.IV = iv; + aes.Key = key; + using (var cryptoStream = new CryptoStream(fileStream, aes.CreateDecryptor(), CryptoStreamMode.Read)) + { + byte[] tokenBuffer = new byte[64]; + int length = cryptoStream.Read(tokenBuffer, 0, tokenBuffer.Length); + return Encoding.UTF8.GetString(tokenBuffer, 0, length); + } + } + } + catch (Exception ex) + { + _logger.Warning("Failed to load cached token. Wrong/changed password?", ex); + } + } + return null; + } + private void SaveToken(string path, byte[] key, string token) + { + byte[] tokenBytes = Encoding.UTF8.GetBytes(token); + try + { + string parentDir = Path.GetDirectoryName(path); + if (!Directory.Exists(parentDir)) + Directory.CreateDirectory(parentDir); + + using (var fileStream = File.Open(path, FileMode.Create)) + using (var aes = Aes.Create()) + { + aes.GenerateIV(); + aes.Key = key; + using (var cryptoStream = new CryptoStream(fileStream, aes.CreateEncryptor(), CryptoStreamMode.Write)) + { + fileStream.Write(aes.IV, 0, aes.IV.Length); + cryptoStream.Write(tokenBytes, 0, tokenBytes.Length); + } + } + } + catch (Exception ex) + { + _logger.Warning("Failed to cache token", ex); + } + } } } \ No newline at end of file diff --git a/src/Discord.Net/DiscordConfig.cs b/src/Discord.Net/DiscordConfig.cs index 620a3fab6..6da162c1c 100644 --- a/src/Discord.Net/DiscordConfig.cs +++ b/src/Discord.Net/DiscordConfig.cs @@ -47,27 +47,10 @@ namespace Discord /// Version of your application. public string AppVersion { get { return _appVersion; } set { SetValue(ref _appVersion, value); UpdateUserAgent(); } } private string _appVersion = null; - /// User Agent string to use when connecting to Discord. [JsonIgnore] public string UserAgent { get { return _userAgent; } } private string _userAgent; - private void UpdateUserAgent() - { - StringBuilder builder = new StringBuilder(); - if (!string.IsNullOrEmpty(_appName)) - { - builder.Append(_appName); - if (!string.IsNullOrEmpty(_appVersion)) - { - builder.Append('/'); - builder.Append(_appVersion); - } - builder.Append(' '); - } - builder.Append($"DiscordBot (https://github.com/RogueException/Discord.Net, v{DiscordClient.Version})"); - _userAgent = builder.ToString(); - } //Rest @@ -100,6 +83,9 @@ namespace Discord //Performance + /// Cache an encrypted login token to temp dir after success login. + public bool CacheToken { get { return _cacheToken; } set { SetValue(ref _cacheToken, value); } } + private bool _cacheToken = true; /// Instructs Discord to not send send information about offline users, for servers with more than 50 users. public bool UseLargeThreshold { get { return _useLargeThreshold; } set { SetValue(ref _useLargeThreshold, value); } } private bool _useLargeThreshold = false; @@ -114,5 +100,22 @@ namespace Discord { UpdateUserAgent(); } - } + + private void UpdateUserAgent() + { + StringBuilder builder = new StringBuilder(); + if (!string.IsNullOrEmpty(_appName)) + { + builder.Append(_appName); + if (!string.IsNullOrEmpty(_appVersion)) + { + builder.Append('/'); + builder.Append(_appVersion); + } + builder.Append(' '); + } + builder.Append($"DiscordBot (https://github.com/RogueException/Discord.Net, v{DiscordClient.Version})"); + _userAgent = builder.ToString(); + } + } } diff --git a/src/Discord.Net/project.json b/src/Discord.Net/project.json index 8be5cb8c9..5fedb7b7a 100644 --- a/src/Discord.Net/project.json +++ b/src/Discord.Net/project.json @@ -53,6 +53,7 @@ "System.Net.Requests": "4.0.11-beta-23516", "System.Net.WebSockets.Client": "4.0.0-beta-23516", "System.Runtime.InteropServices": "4.0.21-beta-23516", + "System.Security.Cryptography.Algorithms": "4.0.0-beta-23516", "System.Text.RegularExpressions": "4.0.11-beta-23516", "System.Threading": "4.0.11-beta-23516", "System.Threading.Thread": "4.0.0-beta-23516"