diff --git a/src/Discord.Net.Rest/DiscordRestApiClient.cs b/src/Discord.Net.Rest/DiscordRestApiClient.cs index 8eec66ec2..065a25b55 100644 --- a/src/Discord.Net.Rest/DiscordRestApiClient.cs +++ b/src/Discord.Net.Rest/DiscordRestApiClient.cs @@ -35,7 +35,6 @@ namespace Discord.API protected bool _isDisposed; private CancellationTokenSource _loginCancelToken; - private bool _fetchCurrentUser; public RetryMode DefaultRetryMode { get; } public string UserAgent { get; } @@ -45,18 +44,15 @@ namespace Discord.API public TokenType AuthTokenType { get; private set; } internal string AuthToken { get; private set; } internal IRestClient RestClient { get; private set; } - internal User CurrentUser { get; private set; } - - public ulong? CurrentUserId => CurrentUser?.Id; + internal ulong? CurrentUserId { get; set;} public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, - JsonSerializer serializer = null, bool fetchCurrentUser = true) + JsonSerializer serializer = null) { RestClientProvider = restClientProvider; UserAgent = userAgent; DefaultRetryMode = defaultRetryMode; _serializer = serializer ?? new JsonSerializer { DateFormatString = "yyyy-MM-ddTHH:mm:ssZ", ContractResolver = new DiscordContractResolver() }; - _fetchCurrentUser = fetchCurrentUser; RequestQueue = new RequestQueue(); _stateLock = new SemaphoreSlim(1, 1); @@ -126,9 +122,6 @@ namespace Discord.API AuthToken = token; RestClient.SetHeader("authorization", GetPrefixedToken(AuthTokenType, AuthToken)); - if (_fetchCurrentUser) - CurrentUser = await GetMyUserAsync(new RequestOptions { IgnoreState = true, RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); - LoginState = LoginState.LoggedIn; } catch (Exception) @@ -162,7 +155,7 @@ namespace Discord.API await RequestQueue.SetCancelTokenAsync(CancellationToken.None).ConfigureAwait(false); RestClient.SetCancelToken(CancellationToken.None); - CurrentUser = null; + CurrentUserId = null; LoginState = LoginState.LoggedOut; } @@ -949,7 +942,7 @@ namespace Discord.API Preconditions.NotNull(args, nameof(args)); options = RequestOptions.CreateOrClone(options); - bool isCurrentUser = userId == CurrentUser.Id; + bool isCurrentUser = userId == CurrentUserId; if (isCurrentUser && args.Nickname.IsSpecified) { diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index 384e43821..f5f2cb8b0 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -17,10 +17,11 @@ namespace Discord.Rest private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent); - protected override Task OnLoginAsync(TokenType tokenType, string token) + protected override async Task OnLoginAsync(TokenType tokenType, string token) { - base.CurrentUser = RestSelfUser.Create(this, ApiClient.CurrentUser); - return Task.Delay(0); + var user = await ApiClient.GetMyUserAsync(new RequestOptions { RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); + ApiClient.CurrentUserId = user.Id; + base.CurrentUser = RestSelfUser.Create(this, user); } protected override Task OnLogoutAsync() { diff --git a/src/Discord.Net.Rpc/DiscordRpcApiClient.cs b/src/Discord.Net.Rpc/DiscordRpcApiClient.cs index b1ac7121a..8c83d24d6 100644 --- a/src/Discord.Net.Rpc/DiscordRpcApiClient.cs +++ b/src/Discord.Net.Rpc/DiscordRpcApiClient.cs @@ -69,7 +69,7 @@ namespace Discord.API public DiscordRpcApiClient(string clientId, string userAgent, string origin, RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null) - : base(restClientProvider, userAgent, defaultRetryMode, serializer, false) + : base(restClientProvider, userAgent, defaultRetryMode, serializer) { _connectionLock = new SemaphoreSlim(1, 1); _clientId = clientId; diff --git a/src/Discord.Net.Rpc/DiscordRpcClient.cs b/src/Discord.Net.Rpc/DiscordRpcClient.cs index e47cbf30c..845ba97c6 100644 --- a/src/Discord.Net.Rpc/DiscordRpcClient.cs +++ b/src/Discord.Net.Rpc/DiscordRpcClient.cs @@ -435,6 +435,7 @@ namespace Discord.Rpc { var response = await ApiClient.SendAuthenticateAsync(options).ConfigureAwait(false); CurrentUser = RestSelfUser.Create(this, response.User); + ApiClient.CurrentUserId = CurrentUser.Id; ApplicationInfo = RestApplication.Create(this, response.Application); Scopes = response.Scopes; TokenExpiresAt = response.Expires; diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index a32c46f10..3a8f90990 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -64,7 +64,7 @@ namespace Discord.WebSocket _shardIdsToIndex.Add(_shardIds[i], i); var newConfig = config.Clone(); newConfig.ShardId = _shardIds[i]; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i]); } } @@ -86,7 +86,7 @@ namespace Discord.WebSocket var newConfig = _baseConfig.Clone(); newConfig.ShardId = _shardIds[i]; newConfig.TotalShards = _totalShards; - _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock); + _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); RegisterEvents(_shards[i]); } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 90f7f5c67..fcfa76653 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -35,7 +35,7 @@ namespace Discord.API public DiscordSocketApiClient(RestClientProvider restClientProvider, string userAgent, WebSocketProvider webSocketProvider, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null) - : base(restClientProvider, userAgent, defaultRetryMode, serializer, true) + : base(restClientProvider, userAgent, defaultRetryMode, serializer) { WebSocketClient = webSocketProvider(); //WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 7591717b6..5ecd48632 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -27,6 +27,7 @@ namespace Discord.WebSocket private readonly Logger _gatewayLogger; private readonly JsonSerializer _serializer; private readonly SemaphoreSlim _connectionGroupLock; + private readonly DiscordSocketClient _parentClient; private string _sessionId; private int _lastSeq; @@ -71,9 +72,9 @@ namespace Discord.WebSocket /// Creates a new REST/WebSocket discord client. public DiscordSocketClient() : this(new DiscordSocketConfig()) { } /// Creates a new REST/WebSocket discord client. - public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null) { } - internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock) : this(config, CreateApiClient(config), groupLock) { } - private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock) + public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { } + internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { } + private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : base(config, client) { ShardId = config.ShardId ?? 0; @@ -90,6 +91,7 @@ namespace Discord.WebSocket _nextAudioId = 1; _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId); _connectionGroupLock = groupLock; + _parentClient = parentClient; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer.Error += (s, e) => @@ -134,8 +136,13 @@ namespace Discord.WebSocket protected override async Task OnLoginAsync(TokenType tokenType, string token) { - var voiceRegions = await ApiClient.GetVoiceRegionsAsync(new RequestOptions { IgnoreState = true, RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); - _voiceRegions = voiceRegions.Select(x => RestVoiceRegion.Create(this, x)).ToImmutableDictionary(x => x.Id); + if (_parentClient == null) + { + var voiceRegions = await ApiClient.GetVoiceRegionsAsync(new RequestOptions { IgnoreState = true, RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); + _voiceRegions = voiceRegions.Select(x => RestVoiceRegion.Create(this, x)).ToImmutableDictionary(x => x.Id); + } + else + _voiceRegions = _parentClient._voiceRegions; } protected override async Task OnLogoutAsync() { @@ -603,6 +610,7 @@ namespace Discord.WebSocket var state = new ClientState(data.Guilds.Length, data.PrivateChannels.Length); var currentUser = SocketSelfUser.Create(this, state, data.User); + ApiClient.CurrentUserId = currentUser.Id; int unavailableGuilds = 0; for (int i = 0; i < data.Guilds.Length; i++) {